✨ feat: add Midjourney (#138)
* 🚧 stash * ✨ feat: add Midjourney * 📝 doc: update readme
This commit is contained in:
parent
87bfecf3e9
commit
c1fc32add7
10
README.md
10
README.md
@ -58,6 +58,16 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开
|
|||||||
|
|
||||||
请查看[文档](https://github.com/MartialBE/one-api/wiki)
|
请查看[文档](https://github.com/MartialBE/one-api/wiki)
|
||||||
|
|
||||||
|
## 感谢
|
||||||
|
|
||||||
|
- 本程序使用了以下开源项目
|
||||||
|
- [one-api](https://github.com/songquanpeng/one-api)为本项目的基础
|
||||||
|
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)为本项目的前端界面
|
||||||
|
- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react),使用了其中的部分样式
|
||||||
|
- [new api](https://github.com/Calcium-Ion/new-api),Midjourney 模块的代码来源于此
|
||||||
|
|
||||||
|
感谢以上项目的作者和贡献者
|
||||||
|
|
||||||
## 其他
|
## 其他
|
||||||
|
|
||||||
<a href="https://next.ossinsight.io/widgets/official/analyze-repo-stars-history?repo_id=689214770" target="_blank" style="display: block" align="center">
|
<a href="https://next.ossinsight.io/widgets/official/analyze-repo-stars-history?repo_id=689214770" target="_blank" style="display: block" align="center">
|
||||||
|
@ -37,6 +37,9 @@ var WeChatAuthEnabled = false
|
|||||||
var TurnstileCheckEnabled = false
|
var TurnstileCheckEnabled = false
|
||||||
var RegisterEnabled = true
|
var RegisterEnabled = true
|
||||||
|
|
||||||
|
// mj
|
||||||
|
var MjNotifyEnabled = false
|
||||||
|
|
||||||
var EmailDomainRestrictionEnabled = false
|
var EmailDomainRestrictionEnabled = false
|
||||||
var EmailDomainWhitelist = []string{
|
var EmailDomainWhitelist = []string{
|
||||||
"gmail.com",
|
"gmail.com",
|
||||||
@ -161,6 +164,7 @@ const (
|
|||||||
ChannelTypeGroq = 31
|
ChannelTypeGroq = 31
|
||||||
ChannelTypeBedrock = 32
|
ChannelTypeBedrock = 32
|
||||||
ChannelTypeLingyi = 33
|
ChannelTypeLingyi = 33
|
||||||
|
ChannelTypeMidjourney = 34
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@ -198,6 +202,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.groq.com/openai", //31
|
"https://api.groq.com/openai", //31
|
||||||
"", //32
|
"", //32
|
||||||
"https://api.lingyiwanwu.com", //33
|
"https://api.lingyiwanwu.com", //33
|
||||||
|
"", //34
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
32
common/go-channel.go
Normal file
32
common/go-channel.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SafeGoroutine(f func()) {
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
f()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SafeSend(ch chan bool, value bool) (closed bool) {
|
||||||
|
defer func() {
|
||||||
|
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||||
|
if recover() != nil {
|
||||||
|
closed = true
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// This will panic if the channel is closed.
|
||||||
|
ch <- value
|
||||||
|
|
||||||
|
// If the code reaches here, then the channel was not closed.
|
||||||
|
return false
|
||||||
|
}
|
@ -106,7 +106,10 @@ func logHelper(ctx context.Context, level string, msg string) {
|
|||||||
if level == loggerINFO {
|
if level == loggerINFO {
|
||||||
writer = gin.DefaultWriter
|
writer = gin.DefaultWriter
|
||||||
}
|
}
|
||||||
id := ctx.Value(RequestIdKey)
|
id, ok := ctx.Value(RequestIdKey).(string)
|
||||||
|
if !ok {
|
||||||
|
id = "unknown"
|
||||||
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
logCount++ // we don't need accurate count, so no lock here
|
logCount++ // we don't need accurate count, so no lock here
|
||||||
|
@ -23,6 +23,7 @@ type HTTPRequester struct {
|
|||||||
CreateFormBuilder func(io.Writer) FormBuilder
|
CreateFormBuilder func(io.Writer) FormBuilder
|
||||||
ErrorHandler HttpErrorHandler
|
ErrorHandler HttpErrorHandler
|
||||||
proxyAddr string
|
proxyAddr string
|
||||||
|
Context context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
|
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
|
||||||
@ -37,6 +38,7 @@ func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequ
|
|||||||
},
|
},
|
||||||
ErrorHandler: errorHandler,
|
ErrorHandler: errorHandler,
|
||||||
proxyAddr: proxyAddr,
|
proxyAddr: proxyAddr,
|
||||||
|
Context: context.Background(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,18 +49,18 @@ type requestOptions struct {
|
|||||||
|
|
||||||
type requestOption func(*requestOptions)
|
type requestOption func(*requestOptions)
|
||||||
|
|
||||||
func (r *HTTPRequester) getContext() context.Context {
|
func (r *HTTPRequester) setProxy() context.Context {
|
||||||
if r.proxyAddr == "" {
|
if r.proxyAddr == "" {
|
||||||
return context.Background()
|
return r.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
||||||
if strings.HasPrefix(r.proxyAddr, "socks5://") {
|
if strings.HasPrefix(r.proxyAddr, "socks5://") {
|
||||||
return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr)
|
return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 否则使用 http 代理
|
// 否则使用 http 代理
|
||||||
return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr)
|
return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,7 +73,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
|
|||||||
for _, setter := range setters {
|
for _, setter := range setters {
|
||||||
setter(args)
|
setter(args)
|
||||||
}
|
}
|
||||||
req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header)
|
req, err := r.requestBuilder.Build(r.setProxy(), method, url, args.body, args.header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
285
controller/midjourney.go
Normal file
285
controller/midjourney.go
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: controller/midjourney.go
|
||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/model"
|
||||||
|
provider "one-api/providers/midjourney"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var activeMidjourneyTask = make(chan bool, 1)
|
||||||
|
|
||||||
|
func InitMidjourneyTask() {
|
||||||
|
common.SafeGoroutine(func() {
|
||||||
|
midjourneyTask()
|
||||||
|
})
|
||||||
|
|
||||||
|
ActivateUpdateMidjourneyTaskBulk()
|
||||||
|
}
|
||||||
|
|
||||||
|
func midjourneyTask() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-activeMidjourneyTask:
|
||||||
|
UpdateMidjourneyTaskBulk()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ActivateUpdateMidjourneyTaskBulk() {
|
||||||
|
if len(activeMidjourneyTask) == 0 {
|
||||||
|
activeMidjourneyTask <- true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateMidjourneyTaskBulk() {
|
||||||
|
ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask")
|
||||||
|
for {
|
||||||
|
common.LogInfo(ctx, "running")
|
||||||
|
|
||||||
|
tasks := model.GetAllUnFinishTasks()
|
||||||
|
|
||||||
|
// 如果没有未完成的任务,则等待
|
||||||
|
if len(tasks) == 0 {
|
||||||
|
for len(activeMidjourneyTask) > 0 {
|
||||||
|
<-activeMidjourneyTask
|
||||||
|
}
|
||||||
|
common.LogInfo(ctx, "no tasks, waiting...")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||||
|
taskChannelM := make(map[int][]string)
|
||||||
|
taskM := make(map[string]*model.Midjourney)
|
||||||
|
nullTaskIds := make([]int, 0)
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task.MjId == "" {
|
||||||
|
// 统计失败的未完成任务
|
||||||
|
nullTaskIds = append(nullTaskIds, task.Id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskM[task.MjId] = task
|
||||||
|
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
|
||||||
|
}
|
||||||
|
if len(nullTaskIds) > 0 {
|
||||||
|
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||||
|
} else {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(taskChannelM) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for channelId, taskIds := range taskChannelM {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
midjourneyChannel := model.ChannelGroup.GetChannel(channelId)
|
||||||
|
if midjourneyChannel == nil {
|
||||||
|
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||||
|
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"ids": taskIds,
|
||||||
|
})
|
||||||
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 设置超时时间
|
||||||
|
timeout := time.Second * 5
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
// 使用带有超时的 context 创建新的请求
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||||
|
resp, err := requester.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var responseItems []provider.MidjourneyDto
|
||||||
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
req.Body.Close()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
for _, responseItem := range responseItems {
|
||||||
|
task := taskM[responseItem.MjId]
|
||||||
|
|
||||||
|
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
|
||||||
|
// 如果时间超过一小时,且进度不是100%,则认为任务失败
|
||||||
|
if useTime > 3600000 && task.Progress != "100%" {
|
||||||
|
responseItem.FailReason = "上游任务超时(超过1小时)"
|
||||||
|
responseItem.Status = "FAILURE"
|
||||||
|
}
|
||||||
|
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
task.Code = 1
|
||||||
|
task.Progress = responseItem.Progress
|
||||||
|
task.PromptEn = responseItem.PromptEn
|
||||||
|
task.State = responseItem.State
|
||||||
|
task.SubmitTime = responseItem.SubmitTime
|
||||||
|
task.StartTime = responseItem.StartTime
|
||||||
|
task.FinishTime = responseItem.FinishTime
|
||||||
|
task.ImageUrl = responseItem.ImageUrl
|
||||||
|
task.Status = responseItem.Status
|
||||||
|
task.FailReason = responseItem.FailReason
|
||||||
|
if responseItem.Properties != nil {
|
||||||
|
propertiesStr, _ := json.Marshal(responseItem.Properties)
|
||||||
|
task.Properties = string(propertiesStr)
|
||||||
|
}
|
||||||
|
if responseItem.Buttons != nil {
|
||||||
|
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||||
|
task.Buttons = string(buttonStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||||
|
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||||
|
task.Progress = "100%"
|
||||||
|
err = model.CacheUpdateUserQuota(task.UserId)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
|
} else {
|
||||||
|
quota := task.Quota
|
||||||
|
if quota != 0 {
|
||||||
|
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = task.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(15) * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask provider.MidjourneyDto) bool {
|
||||||
|
if oldTask.Code != 1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.Progress != newTask.Progress {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.PromptEn != newTask.PromptEn {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.State != newTask.State {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.StartTime != newTask.StartTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.FinishTime != newTask.FinishTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.ImageUrl != newTask.ImageUrl {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.Status != newTask.Status {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.FailReason != newTask.FailReason {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.FinishTime != newTask.FinishTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllMidjourney(c *gin.Context) {
|
||||||
|
var params model.TaskQueryParams
|
||||||
|
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||||
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
midjourneys, err := model.GetAllTasks(¶ms)
|
||||||
|
if err != nil {
|
||||||
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": midjourneys,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserMidjourney(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
var params model.TaskQueryParams
|
||||||
|
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||||
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
midjourneys, err := model.GetAllUserTask(userId, ¶ms)
|
||||||
|
if err != nil {
|
||||||
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": midjourneys,
|
||||||
|
})
|
||||||
|
}
|
@ -40,6 +40,7 @@ func GetStatus(c *gin.Context) {
|
|||||||
"quota_per_unit": common.QuotaPerUnit,
|
"quota_per_unit": common.QuotaPerUnit,
|
||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
"telegram_bot": telegram_bot,
|
"telegram_bot": telegram_bot,
|
||||||
|
"mj_notify_enabled": common.MjNotifyEnabled,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
2
main.go
2
main.go
@ -45,6 +45,8 @@ func main() {
|
|||||||
// Initialize Telegram bot
|
// Initialize Telegram bot
|
||||||
telegram.InitTelegramBot()
|
telegram.InitTelegramBot()
|
||||||
|
|
||||||
|
controller.InitMidjourneyTask()
|
||||||
|
|
||||||
initHttpServer()
|
initHttpServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,43 +83,54 @@ func RootAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TokenAuth() func(c *gin.Context) {
|
func tokenAuth(c *gin.Context, key string) {
|
||||||
return func(c *gin.Context) {
|
key = strings.TrimPrefix(key, "Bearer ")
|
||||||
key := c.Request.Header.Get("Authorization")
|
key = strings.TrimPrefix(key, "sk-")
|
||||||
key = strings.TrimPrefix(key, "Bearer ")
|
parts := strings.Split(key, "-")
|
||||||
key = strings.TrimPrefix(key, "sk-")
|
key = parts[0]
|
||||||
parts := strings.Split(key, "-")
|
token, err := model.ValidateUserToken(key)
|
||||||
key = parts[0]
|
if err != nil {
|
||||||
token, err := model.ValidateUserToken(key)
|
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
if err != nil {
|
return
|
||||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
}
|
||||||
return
|
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||||
}
|
if err != nil {
|
||||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
if err != nil {
|
return
|
||||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
}
|
||||||
return
|
if !userEnabled {
|
||||||
}
|
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
if !userEnabled {
|
return
|
||||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
}
|
||||||
return
|
c.Set("id", token.UserId)
|
||||||
}
|
c.Set("token_id", token.Id)
|
||||||
c.Set("id", token.UserId)
|
c.Set("token_name", token.Name)
|
||||||
c.Set("token_id", token.Id)
|
if len(parts) > 1 {
|
||||||
c.Set("token_name", token.Name)
|
if model.IsAdmin(token.UserId) {
|
||||||
if len(parts) > 1 {
|
channelId := common.String2Int(parts[1])
|
||||||
if model.IsAdmin(token.UserId) {
|
if channelId == 0 {
|
||||||
channelId := common.String2Int(parts[1])
|
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
||||||
if channelId == 0 {
|
|
||||||
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("specific_channel_id", channelId)
|
|
||||||
} else {
|
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Set("specific_channel_id", channelId)
|
||||||
|
} else {
|
||||||
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenaiAuth() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
key := c.Request.Header.Get("Authorization")
|
||||||
|
tokenAuth(c, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MjAuth() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
key := c.Request.Header.Get("mj-api-secret")
|
||||||
|
tokenAuth(c, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -114,6 +114,17 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
|
|||||||
return models, nil
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cc *ChannelsChooser) GetChannel(channelId int) *Channel {
|
||||||
|
cc.RLock()
|
||||||
|
defer cc.RUnlock()
|
||||||
|
|
||||||
|
if choice, ok := cc.Channels[channelId]; ok {
|
||||||
|
return choice.Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var ChannelGroup = ChannelsChooser{}
|
var ChannelGroup = ChannelsChooser{}
|
||||||
|
|
||||||
func (cc *ChannelsChooser) Load() {
|
func (cc *ChannelsChooser) Load() {
|
||||||
|
@ -139,6 +139,10 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err = db.AutoMigrate(&Midjourney{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
common.SysLog("database migrated")
|
common.SysLog("database migrated")
|
||||||
err = createRootAccountIfNeed()
|
err = createRootAccountIfNeed()
|
||||||
return err
|
return err
|
||||||
|
182
model/midjourney.go
Normal file
182
model/midjourney.go
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
// Copyright (c) 2024 Calcium-Ion
|
||||||
|
//
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to deal
|
||||||
|
// in the Software without restriction, including without limitation the rights
|
||||||
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
// copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
//
|
||||||
|
// The above copyright notice and this permission notice shall be included in all
|
||||||
|
// copies or substantial portions of the Software.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
// SOFTWARE.
|
||||||
|
//
|
||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
type Midjourney struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
|
Action string `json:"action" gorm:"type:varchar(40);index"`
|
||||||
|
MjId string `json:"mj_id" gorm:"index"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"prompt_en"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||||
|
StartTime int64 `json:"start_time" gorm:"index"`
|
||||||
|
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||||
|
ImageUrl string `json:"image_url"`
|
||||||
|
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||||
|
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||||
|
FailReason string `json:"fail_reason"`
|
||||||
|
ChannelId int `json:"channel_id"`
|
||||||
|
Quota int `json:"quota"`
|
||||||
|
Buttons string `json:"buttons"`
|
||||||
|
Properties string `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||||
|
type TaskQueryParams struct {
|
||||||
|
ChannelID int `form:"channel_id"`
|
||||||
|
MjID string `form:"mj_id"`
|
||||||
|
StartTimestamp int `form:"start_timestamp"`
|
||||||
|
EndTimestamp int `form:"end_timestamp"`
|
||||||
|
PaginationParams
|
||||||
|
}
|
||||||
|
|
||||||
|
var allowedMidjourneyOrderFields = map[string]bool{
|
||||||
|
"id": true,
|
||||||
|
"user_id": true,
|
||||||
|
"code": true,
|
||||||
|
"action": true,
|
||||||
|
"mj_id": true,
|
||||||
|
"submit_time": true,
|
||||||
|
"start_time": true,
|
||||||
|
"finish_time": true,
|
||||||
|
"status": true,
|
||||||
|
"channel_id": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllUserTask(userId int, params *TaskQueryParams) (*DataResult[Midjourney], error) {
|
||||||
|
var tasks []*Midjourney
|
||||||
|
|
||||||
|
// 初始化查询构建器
|
||||||
|
query := DB.Where("user_id = ?", userId)
|
||||||
|
|
||||||
|
if params.MjID != "" {
|
||||||
|
query = query.Where("mj_id = ?", params.MjID)
|
||||||
|
}
|
||||||
|
if params.StartTimestamp != 0 {
|
||||||
|
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
|
||||||
|
query = query.Where("submit_time >= ?", params.StartTimestamp)
|
||||||
|
}
|
||||||
|
if params.EndTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time <= ?", params.EndTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllTasks(params *TaskQueryParams) (*DataResult[Midjourney], error) {
|
||||||
|
var tasks []*Midjourney
|
||||||
|
|
||||||
|
// 初始化查询构建器
|
||||||
|
query := DB
|
||||||
|
|
||||||
|
// 添加过滤条件
|
||||||
|
if params.ChannelID != 0 {
|
||||||
|
query = query.Where("channel_id = ?", params.ChannelID)
|
||||||
|
}
|
||||||
|
if params.MjID != "" {
|
||||||
|
query = query.Where("mj_id = ?", params.MjID)
|
||||||
|
}
|
||||||
|
if params.StartTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time >= ?", params.StartTimestamp)
|
||||||
|
}
|
||||||
|
if params.EndTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time <= ?", params.EndTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllUnFinishTasks() []*Midjourney {
|
||||||
|
var tasks []*Midjourney
|
||||||
|
// get all tasks progress is not 100%
|
||||||
|
err := DB.Where("progress != ?", "100%").Find(&tasks).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tasks
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetByOnlyMJId(mjId string) *Midjourney {
|
||||||
|
var mj *Midjourney
|
||||||
|
err := DB.Where("mj_id = ?", mjId).First(&mj).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return mj
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetByMJId(userId int, mjId string) *Midjourney {
|
||||||
|
var mj *Midjourney
|
||||||
|
err := DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return mj
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
|
||||||
|
var mj []*Midjourney
|
||||||
|
err := DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return mj
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMjByuId(id int) *Midjourney {
|
||||||
|
var mj *Midjourney
|
||||||
|
err := DB.Where("id = ?", id).First(&mj).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return mj
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateProgress(id int, progress string) error {
|
||||||
|
return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (midjourney *Midjourney) Insert() error {
|
||||||
|
return DB.Create(midjourney).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (midjourney *Midjourney) Update() error {
|
||||||
|
return DB.Save(midjourney).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func MjBulkUpdate(mjIds []string, params map[string]any) error {
|
||||||
|
return DB.Model(&Midjourney{}).
|
||||||
|
Where("mj_id in (?)", mjIds).
|
||||||
|
Updates(params).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
|
||||||
|
return DB.Model(&Midjourney{}).
|
||||||
|
Where("id in (?)", taskIDs).
|
||||||
|
Updates(params).Error
|
||||||
|
}
|
@ -74,6 +74,8 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||||
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
|
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
|
||||||
|
|
||||||
|
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
|
||||||
|
|
||||||
common.OptionMapRWMutex.Unlock()
|
common.OptionMapRWMutex.Unlock()
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
}
|
}
|
||||||
@ -138,6 +140,7 @@ var optionBoolMap = map[string]*bool{
|
|||||||
"LogConsumeEnabled": &common.LogConsumeEnabled,
|
"LogConsumeEnabled": &common.LogConsumeEnabled,
|
||||||
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
|
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
|
||||||
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
|
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
|
||||||
|
"MjNotifyEnabled": &common.MjNotifyEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
var optionStringMap = map[string]*string{
|
var optionStringMap = map[string]*string{
|
||||||
|
@ -301,5 +301,33 @@ func GetDefaultPrice() []*Price {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var DefaultMJPrice = map[string]float64{
|
||||||
|
"mj_imagine": 50,
|
||||||
|
"mj_variation": 50,
|
||||||
|
"mj_reroll": 50,
|
||||||
|
"mj_blend": 50,
|
||||||
|
"mj_modal": 50,
|
||||||
|
"mj_zoom": 50,
|
||||||
|
"mj_shorten": 50,
|
||||||
|
"mj_high_variation": 50,
|
||||||
|
"mj_low_variation": 50,
|
||||||
|
"mj_pan": 50,
|
||||||
|
"mj_inpaint": 0,
|
||||||
|
"mj_custom_zoom": 0,
|
||||||
|
"mj_describe": 25,
|
||||||
|
"mj_upscale": 25,
|
||||||
|
"swap_face": 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
for model, mjPrice := range DefaultMJPrice {
|
||||||
|
prices = append(prices, &Price{
|
||||||
|
Model: model,
|
||||||
|
Type: TimesPriceType,
|
||||||
|
ChannelType: common.ChannelTypeMidjourney,
|
||||||
|
Input: mjPrice,
|
||||||
|
Output: mjPrice,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return prices
|
return prices
|
||||||
}
|
}
|
||||||
|
121
providers/midjourney/base.go
Normal file
121
providers/midjourney/base.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers/base"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 定义供应商工厂
|
||||||
|
type MidjourneyProviderFactory struct{}
|
||||||
|
|
||||||
|
// 创建 MidjourneyProvider
|
||||||
|
func (f MidjourneyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||||
|
return &MidjourneyProvider{
|
||||||
|
BaseProvider: base.BaseProvider{
|
||||||
|
Config: getConfig(),
|
||||||
|
Channel: channel,
|
||||||
|
Requester: requester.NewHTTPRequester(*channel.Proxy, nil),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConfig() base.ProviderConfig {
|
||||||
|
return base.ProviderConfig{
|
||||||
|
BaseURL: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyProvider struct {
|
||||||
|
base.BaseProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyResponseWithStatusCode, []byte, error) {
|
||||||
|
var nullBytes []byte
|
||||||
|
var mapResult map[string]interface{}
|
||||||
|
if p.Context.Request.Method != "GET" {
|
||||||
|
err := json.NewDecoder(p.Context.Request.Body).Decode(&mapResult)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
delete(mapResult, "accountFilter")
|
||||||
|
if !common.MjNotifyEnabled {
|
||||||
|
delete(mapResult, "notifyHook")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, err := json.Marshal(mapResult)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := p.GetFullRequestURL(requestURL, "")
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
p.Requester.Context, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := p.Requester.NewRequest(p.Context.Request.Method, fullRequestURL, p.Requester.WithBody(bytes.NewBuffer(reqBody)), p.Requester.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, errWith := p.Requester.SendRequestRaw(req)
|
||||||
|
if errWith != nil {
|
||||||
|
common.SysError("do request failed: " + errWith.Error())
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
statusCode := resp.StatusCode
|
||||||
|
err = req.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||||
|
}
|
||||||
|
err = p.Context.Request.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||||
|
}
|
||||||
|
var midjResponse MidjourneyResponse
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||||
|
}
|
||||||
|
respStr := string(responseBody)
|
||||||
|
log.Printf("responseBody: %s", respStr)
|
||||||
|
if respStr == "" {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
||||||
|
} else {
|
||||||
|
err = json.Unmarshal(responseBody, &midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MidjourneyResponseWithStatusCode{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Response: midjResponse,
|
||||||
|
}, responseBody, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MidjourneyProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
headers["mj-api-secret"] = p.Channel.Key
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
69
providers/midjourney/constant.go
Normal file
69
providers/midjourney/constant.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: relay/constant/relay_mode.go
|
||||||
|
package midjourney
|
||||||
|
|
||||||
|
const (
|
||||||
|
RelayModeUnknown = iota
|
||||||
|
RelayModeMidjourneyImagine
|
||||||
|
RelayModeMidjourneyDescribe
|
||||||
|
RelayModeMidjourneyBlend
|
||||||
|
RelayModeMidjourneyChange
|
||||||
|
RelayModeMidjourneySimpleChange
|
||||||
|
RelayModeMidjourneyNotify
|
||||||
|
RelayModeMidjourneyTaskFetch
|
||||||
|
RelayModeMidjourneyTaskImageSeed
|
||||||
|
RelayModeMidjourneyTaskFetchByCondition
|
||||||
|
RelayModeAudioSpeech
|
||||||
|
RelayModeAudioTranscription
|
||||||
|
RelayModeAudioTranslation
|
||||||
|
RelayModeMidjourneyAction
|
||||||
|
RelayModeMidjourneyModal
|
||||||
|
RelayModeMidjourneyShorten
|
||||||
|
RelayModeMidjourneySwapFace
|
||||||
|
)
|
||||||
|
|
||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: constant/midjourney.go
|
||||||
|
|
||||||
|
const (
|
||||||
|
MjErrorUnknown = 5
|
||||||
|
MjRequestError = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MjActionImagine = "IMAGINE"
|
||||||
|
MjActionDescribe = "DESCRIBE"
|
||||||
|
MjActionBlend = "BLEND"
|
||||||
|
MjActionUpscale = "UPSCALE"
|
||||||
|
MjActionVariation = "VARIATION"
|
||||||
|
MjActionReRoll = "REROLL"
|
||||||
|
MjActionInPaint = "INPAINT"
|
||||||
|
MjActionModal = "MODAL"
|
||||||
|
MjActionZoom = "ZOOM"
|
||||||
|
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||||
|
MjActionShorten = "SHORTEN"
|
||||||
|
MjActionHighVariation = "HIGH_VARIATION"
|
||||||
|
MjActionLowVariation = "LOW_VARIATION"
|
||||||
|
MjActionPan = "PAN"
|
||||||
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
|
)
|
||||||
|
|
||||||
|
var MidjourneyModel2Action = map[string]string{
|
||||||
|
"mj_imagine": MjActionImagine,
|
||||||
|
"mj_describe": MjActionDescribe,
|
||||||
|
"mj_blend": MjActionBlend,
|
||||||
|
"mj_upscale": MjActionUpscale,
|
||||||
|
"mj_variation": MjActionVariation,
|
||||||
|
"mj_reroll": MjActionReRoll,
|
||||||
|
"mj_modal": MjActionModal,
|
||||||
|
"mj_inpaint": MjActionInPaint,
|
||||||
|
"mj_zoom": MjActionZoom,
|
||||||
|
"mj_custom_zoom": MjActionCustomZoom,
|
||||||
|
"mj_shorten": MjActionShorten,
|
||||||
|
"mj_high_variation": MjActionHighVariation,
|
||||||
|
"mj_low_variation": MjActionLowVariation,
|
||||||
|
"mj_pan": MjActionPan,
|
||||||
|
"swap_face": MjActionSwapFace,
|
||||||
|
}
|
18
providers/midjourney/error.go
Normal file
18
providers/midjourney/error.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: service/error.go
|
||||||
|
package midjourney
|
||||||
|
|
||||||
|
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *MidjourneyResponseWithStatusCode {
|
||||||
|
return &MidjourneyResponseWithStatusCode{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Response: *MidjourneyErrorWrapper(code, desc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MidjourneyErrorWrapper(code int, desc string) *MidjourneyResponse {
|
||||||
|
return &MidjourneyResponse{
|
||||||
|
Code: code,
|
||||||
|
Description: desc,
|
||||||
|
}
|
||||||
|
}
|
92
providers/midjourney/type.go
Normal file
92
providers/midjourney/type.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: dto/midjourney.go
|
||||||
|
package midjourney
|
||||||
|
|
||||||
|
type SwapFaceRequest struct {
|
||||||
|
SourceBase64 string `json:"sourceBase64"`
|
||||||
|
TargetBase64 string `json:"targetBase64"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
CustomId string `json:"customId"`
|
||||||
|
BotType string `json:"botType"`
|
||||||
|
NotifyHook string `json:"notifyHook"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
State string `json:"state"`
|
||||||
|
TaskId string `json:"taskId"`
|
||||||
|
Base64Array []string `json:"base64Array"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
MaskBase64 string `json:"maskBase64"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Properties interface{} `json:"properties"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyResponseWithStatusCode struct {
|
||||||
|
StatusCode int `json:"statusCode"`
|
||||||
|
Response MidjourneyResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyDto struct {
|
||||||
|
MjId string `json:"id"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
CustomId string `json:"customId"`
|
||||||
|
BotType string `json:"botType"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"promptEn"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
SubmitTime int64 `json:"submitTime"`
|
||||||
|
StartTime int64 `json:"startTime"`
|
||||||
|
FinishTime int64 `json:"finishTime"`
|
||||||
|
ImageUrl string `json:"imageUrl"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
FailReason string `json:"failReason"`
|
||||||
|
Buttons any `json:"buttons"`
|
||||||
|
MaskBase64 string `json:"maskBase64"`
|
||||||
|
Properties *Properties `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyStatus struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
type MidjourneyWithoutStatus struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
MjId string `json:"mj_id" gorm:"index"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"prompt_en"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
SubmitTime int64 `json:"submit_time"`
|
||||||
|
StartTime int64 `json:"start_time"`
|
||||||
|
FinishTime int64 `json:"finish_time"`
|
||||||
|
ImageUrl string `json:"image_url"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
FailReason string `json:"fail_reason"`
|
||||||
|
ChannelId int `json:"channel_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ActionButton struct {
|
||||||
|
CustomId any `json:"customId"`
|
||||||
|
Emoji any `json:"emoji"`
|
||||||
|
Label any `json:"label"`
|
||||||
|
Type any `json:"type"`
|
||||||
|
Style any `json:"style"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Properties struct {
|
||||||
|
FinalPrompt string `json:"finalPrompt"`
|
||||||
|
FinalZhPrompt string `json:"finalZhPrompt"`
|
||||||
|
}
|
@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/providers/deepseek"
|
"one-api/providers/deepseek"
|
||||||
"one-api/providers/gemini"
|
"one-api/providers/gemini"
|
||||||
"one-api/providers/groq"
|
"one-api/providers/groq"
|
||||||
|
"one-api/providers/midjourney"
|
||||||
"one-api/providers/minimax"
|
"one-api/providers/minimax"
|
||||||
"one-api/providers/mistral"
|
"one-api/providers/mistral"
|
||||||
"one-api/providers/openai"
|
"one-api/providers/openai"
|
||||||
@ -52,6 +53,7 @@ func init() {
|
|||||||
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
||||||
|
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ type RelayBaseInterface interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *relayBase) setProvider(modelName string) error {
|
func (r *relayBase) setProvider(modelName string) error {
|
||||||
provider, modelName, fail := getProvider(r.c, modelName)
|
provider, modelName, fail := GetProvider(r.c, modelName)
|
||||||
if fail != nil {
|
if fail != nil {
|
||||||
return fail
|
return fail
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
||||||
channel, fail := fetchChannel(c, modeName)
|
channel, fail := fetchChannel(c, modeName)
|
||||||
if fail != nil {
|
if fail != nil {
|
||||||
return
|
return
|
||||||
|
19
relay/midjourney/LICENSE
Normal file
19
relay/midjourney/LICENSE
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
Copyright (c) 2024 Calcium-Ion
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
578
relay/midjourney/relay-mj.go
Normal file
578
relay/midjourney/relay-mj.go
Normal file
@ -0,0 +1,578 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: relay/relay-mj.go
|
||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/controller"
|
||||||
|
"one-api/model"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
provider "one-api/providers/midjourney"
|
||||||
|
"one-api/relay"
|
||||||
|
"one-api/relay/util"
|
||||||
|
"one-api/types"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RelayMidjourneyImage(c *gin.Context) {
|
||||||
|
taskId := c.Param("id")
|
||||||
|
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||||
|
if midjourneyTask == nil {
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "midjourney_task_not_found",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := http.Get(midjourneyTask.ImageUrl)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"error": "http_get_image_failed",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
responseBody, _ := io.ReadAll(resp.Body)
|
||||||
|
c.JSON(resp.StatusCode, gin.H{
|
||||||
|
"error": string(responseBody),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 从Content-Type头获取MIME类型
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType == "" {
|
||||||
|
// 如果无法确定内容类型,则默认为jpeg
|
||||||
|
contentType = "image/jpeg"
|
||||||
|
}
|
||||||
|
// 设置响应的内容类型
|
||||||
|
c.Writer.Header().Set("Content-Type", contentType)
|
||||||
|
// 将图片流式传输到响应体
|
||||||
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to stream image:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayMidjourneyNotify(c *gin.Context) *provider.MidjourneyResponse {
|
||||||
|
var midjRequest provider.MidjourneyDto
|
||||||
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "bind_request_body_failed",
|
||||||
|
Properties: nil,
|
||||||
|
Result: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||||||
|
if midjourneyTask == nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "midjourney_task_not_found",
|
||||||
|
Properties: nil,
|
||||||
|
Result: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
midjourneyTask.Progress = midjRequest.Progress
|
||||||
|
midjourneyTask.PromptEn = midjRequest.PromptEn
|
||||||
|
midjourneyTask.State = midjRequest.State
|
||||||
|
midjourneyTask.SubmitTime = midjRequest.SubmitTime
|
||||||
|
midjourneyTask.StartTime = midjRequest.StartTime
|
||||||
|
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||||||
|
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||||||
|
midjourneyTask.Status = midjRequest.Status
|
||||||
|
midjourneyTask.FailReason = midjRequest.FailReason
|
||||||
|
err = midjourneyTask.Update()
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "update_midjourney_task_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provider.MidjourneyDto) {
|
||||||
|
midjourneyTask.MjId = originTask.MjId
|
||||||
|
midjourneyTask.Progress = originTask.Progress
|
||||||
|
midjourneyTask.PromptEn = originTask.PromptEn
|
||||||
|
midjourneyTask.State = originTask.State
|
||||||
|
midjourneyTask.SubmitTime = originTask.SubmitTime
|
||||||
|
midjourneyTask.StartTime = originTask.StartTime
|
||||||
|
midjourneyTask.FinishTime = originTask.FinishTime
|
||||||
|
midjourneyTask.ImageUrl = ""
|
||||||
|
if originTask.ImageUrl != "" {
|
||||||
|
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||||
|
if originTask.Status != "SUCCESS" {
|
||||||
|
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
midjourneyTask.Status = originTask.Status
|
||||||
|
midjourneyTask.FailReason = originTask.FailReason
|
||||||
|
midjourneyTask.Action = originTask.Action
|
||||||
|
midjourneyTask.Description = originTask.Description
|
||||||
|
midjourneyTask.Prompt = originTask.Prompt
|
||||||
|
if originTask.Buttons != "" {
|
||||||
|
var buttons []provider.ActionButton
|
||||||
|
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
|
||||||
|
if err == nil {
|
||||||
|
midjourneyTask.Buttons = buttons
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if originTask.Properties != "" {
|
||||||
|
var properties provider.Properties
|
||||||
|
err := json.Unmarshal([]byte(originTask.Properties), &properties)
|
||||||
|
if err == nil {
|
||||||
|
midjourneyTask.Properties = &properties
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
||||||
|
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
|
||||||
|
if errWithMJ != nil {
|
||||||
|
return errWithMJ
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
var swapFaceRequest provider.SwapFaceRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
||||||
|
}
|
||||||
|
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||||
|
}
|
||||||
|
|
||||||
|
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
||||||
|
if errWithOA != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: errWithOA.Message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||||
|
mjResp, _, err := mjProvider.Send(60, requestURL)
|
||||||
|
if err != nil {
|
||||||
|
quotaInstance.Undo(c)
|
||||||
|
return &mjResp.Response
|
||||||
|
}
|
||||||
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||||
|
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1000, TotalTokens: 1000})
|
||||||
|
} else {
|
||||||
|
quotaInstance.Undo(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
quota := int(quotaInstance.GetInputRatio() * 1000)
|
||||||
|
|
||||||
|
midjResponse := &mjResp.Response
|
||||||
|
midjourneyTask := &model.Midjourney{
|
||||||
|
UserId: userId,
|
||||||
|
Code: midjResponse.Code,
|
||||||
|
Action: provider.MjActionSwapFace,
|
||||||
|
MjId: midjResponse.Result,
|
||||||
|
Prompt: "InsightFace",
|
||||||
|
PromptEn: "",
|
||||||
|
Description: midjResponse.Description,
|
||||||
|
State: "",
|
||||||
|
SubmitTime: startTime,
|
||||||
|
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||||
|
FinishTime: 0,
|
||||||
|
ImageUrl: "",
|
||||||
|
Status: "",
|
||||||
|
Progress: "0%",
|
||||||
|
FailReason: "",
|
||||||
|
ChannelId: c.GetInt("channel_id"),
|
||||||
|
Quota: quota,
|
||||||
|
}
|
||||||
|
err = midjourneyTask.Insert()
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "insert_midjourney_task_failed")
|
||||||
|
}
|
||||||
|
// 开始激活任务
|
||||||
|
controller.ActivateUpdateMidjourneyTaskBulk()
|
||||||
|
|
||||||
|
c.Writer.WriteHeader(mjResp.StatusCode)
|
||||||
|
respBody, err := json.Marshal(midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
|
||||||
|
}
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
|
||||||
|
taskId := c.Param("id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
originTask := model.GetByMJId(userId, taskId)
|
||||||
|
if originTask == nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
|
||||||
|
}
|
||||||
|
|
||||||
|
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
|
||||||
|
if errWithMJ != nil {
|
||||||
|
return errWithMJ
|
||||||
|
}
|
||||||
|
|
||||||
|
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||||
|
midjResponseWithStatus, _, err := mjProvider.Send(30, requestURL)
|
||||||
|
if err != nil {
|
||||||
|
return &midjResponseWithStatus.Response
|
||||||
|
}
|
||||||
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||||
|
respBody, err := json.Marshal(midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
|
||||||
|
}
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
var err error
|
||||||
|
var respBody []byte
|
||||||
|
switch relayMode {
|
||||||
|
case provider.RelayModeMidjourneyTaskFetch:
|
||||||
|
taskId := c.Param("id")
|
||||||
|
originTask := model.GetByMJId(userId, taskId)
|
||||||
|
if originTask == nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "task_no_found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
midjourneyTask := coverMidjourneyTaskDto(originTask)
|
||||||
|
respBody, err = json.Marshal(midjourneyTask)
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "unmarshal_response_body_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case provider.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
|
var condition = struct {
|
||||||
|
IDs []string `json:"ids"`
|
||||||
|
}{}
|
||||||
|
err = c.BindJSON(&condition)
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "do_request_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var tasks []provider.MidjourneyDto
|
||||||
|
if len(condition.IDs) != 0 {
|
||||||
|
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||||
|
for _, originTask := range originTasks {
|
||||||
|
midjourneyTask := coverMidjourneyTaskDto(originTask)
|
||||||
|
tasks = append(tasks, midjourneyTask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tasks == nil {
|
||||||
|
tasks = make([]provider.MidjourneyDto, 0)
|
||||||
|
}
|
||||||
|
respBody, err = json.Marshal(tasks)
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "unmarshal_response_body_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "copy_response_body_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
||||||
|
channelId := 0
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
consumeQuota := true
|
||||||
|
var midjRequest provider.MidjourneyRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
|
if err != nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||||
|
mjErr := CoverPlusActionToNormalAction(&midjRequest)
|
||||||
|
if mjErr != nil {
|
||||||
|
return mjErr
|
||||||
|
}
|
||||||
|
relayMode = provider.RelayModeMidjourneyChange
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayMode == provider.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||||
|
if midjRequest.Prompt == "" {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "prompt_is_required")
|
||||||
|
}
|
||||||
|
midjRequest.Action = provider.MjActionImagine
|
||||||
|
} else if relayMode == provider.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||||
|
midjRequest.Action = provider.MjActionDescribe
|
||||||
|
} else if relayMode == provider.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||||
|
midjRequest.Action = provider.MjActionShorten
|
||||||
|
} else if relayMode == provider.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||||
|
midjRequest.Action = provider.MjActionBlend
|
||||||
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||||
|
mjId := ""
|
||||||
|
if relayMode == provider.RelayModeMidjourneyChange {
|
||||||
|
if midjRequest.TaskId == "" {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_id_is_required")
|
||||||
|
} else if midjRequest.Action == "" {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "action_is_required")
|
||||||
|
} else if midjRequest.Index == 0 {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "index_is_required")
|
||||||
|
}
|
||||||
|
//action = midjRequest.Action
|
||||||
|
mjId = midjRequest.TaskId
|
||||||
|
} else if relayMode == provider.RelayModeMidjourneySimpleChange {
|
||||||
|
if midjRequest.Content == "" {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_is_required")
|
||||||
|
}
|
||||||
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
|
if params == nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_parse_failed")
|
||||||
|
}
|
||||||
|
mjId = params.TaskId
|
||||||
|
midjRequest.Action = params.Action
|
||||||
|
} else if relayMode == provider.RelayModeMidjourneyModal {
|
||||||
|
//if midjRequest.MaskBase64 == "" {
|
||||||
|
// return provider.MidjourneyErrorWrapper(provider.MjRequestError, "mask_base64_is_required")
|
||||||
|
//}
|
||||||
|
mjId = midjRequest.TaskId
|
||||||
|
midjRequest.Action = provider.MjActionModal
|
||||||
|
}
|
||||||
|
|
||||||
|
originTask := model.GetByMJId(userId, mjId)
|
||||||
|
if originTask == nil {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_not_found")
|
||||||
|
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
|
||||||
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
|
||||||
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||||
|
channelId = originTask.ChannelId
|
||||||
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId)
|
||||||
|
}
|
||||||
|
midjRequest.Prompt = originTask.Prompt
|
||||||
|
|
||||||
|
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||||||
|
// // plus
|
||||||
|
//} else {
|
||||||
|
// // 普通版渠道
|
||||||
|
//
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
|
||||||
|
if midjRequest.Action == provider.MjActionInPaint || midjRequest.Action == provider.MjActionCustomZoom {
|
||||||
|
consumeQuota = false
|
||||||
|
}
|
||||||
|
|
||||||
|
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
|
||||||
|
if errWithMJ != nil {
|
||||||
|
return errWithMJ
|
||||||
|
}
|
||||||
|
|
||||||
|
//baseURL := common.ChannelBaseURLs[channelType]
|
||||||
|
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||||
|
|
||||||
|
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||||
|
|
||||||
|
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
||||||
|
if errWithOA != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: errWithOA.Message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
midjResponseWithStatus, responseBody, err := mjProvider.Send(60, requestURL)
|
||||||
|
if err != nil {
|
||||||
|
quotaInstance.Undo(c)
|
||||||
|
return &midjResponseWithStatus.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||||
|
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
|
||||||
|
} else {
|
||||||
|
quotaInstance.Undo(c)
|
||||||
|
}
|
||||||
|
quota := int(quotaInstance.GetInputRatio() * 1000)
|
||||||
|
|
||||||
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
|
||||||
|
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||||
|
//1-提交成功
|
||||||
|
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||||||
|
// 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
|
||||||
|
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
|
||||||
|
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||||||
|
// other: 提交错误,description为错误描述
|
||||||
|
midjourneyTask := &model.Midjourney{
|
||||||
|
UserId: userId,
|
||||||
|
Code: midjResponse.Code,
|
||||||
|
Action: midjRequest.Action,
|
||||||
|
MjId: midjResponse.Result,
|
||||||
|
Prompt: midjRequest.Prompt,
|
||||||
|
PromptEn: "",
|
||||||
|
Description: midjResponse.Description,
|
||||||
|
State: "",
|
||||||
|
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||||
|
StartTime: 0,
|
||||||
|
FinishTime: 0,
|
||||||
|
ImageUrl: "",
|
||||||
|
Status: "",
|
||||||
|
Progress: "0%",
|
||||||
|
FailReason: "",
|
||||||
|
ChannelId: c.GetInt("channel_id"),
|
||||||
|
Quota: quota,
|
||||||
|
}
|
||||||
|
|
||||||
|
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||||
|
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
||||||
|
midjourneyTask.FailReason = midjResponse.Description
|
||||||
|
consumeQuota = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
|
||||||
|
// 将 properties 转换为一个 map
|
||||||
|
properties, ok := midjResponse.Properties.(map[string]interface{})
|
||||||
|
if ok {
|
||||||
|
imageUrl, ok1 := properties["imageUrl"].(string)
|
||||||
|
status, ok2 := properties["status"].(string)
|
||||||
|
if ok1 && ok2 {
|
||||||
|
midjourneyTask.ImageUrl = imageUrl
|
||||||
|
midjourneyTask.Status = status
|
||||||
|
if status == "SUCCESS" {
|
||||||
|
midjourneyTask.Progress = "100%"
|
||||||
|
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
midjResponse.Code = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//修改返回值
|
||||||
|
if midjRequest.Action != provider.MjActionInPaint && midjRequest.Action != provider.MjActionCustomZoom {
|
||||||
|
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||||
|
responseBody = []byte(newBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = midjourneyTask.Insert()
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "insert_midjourney_task_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 开始激活任务
|
||||||
|
controller.ActivateUpdateMidjourneyTaskBulk()
|
||||||
|
|
||||||
|
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
|
||||||
|
//修改返回值
|
||||||
|
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||||||
|
responseBody = []byte(newBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
|
//for k, v := range resp.Header {
|
||||||
|
// c.Writer.Header().Set(k, v[0])
|
||||||
|
//}
|
||||||
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||||
|
|
||||||
|
_, err = io.Copy(c.Writer, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "copy_response_body_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = bodyReader.Close()
|
||||||
|
if err != nil {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "close_response_body_failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMjRequestPath(path string) string {
|
||||||
|
requestURL := path
|
||||||
|
if strings.Contains(requestURL, "/mj-") {
|
||||||
|
urls := strings.Split(requestURL, "/mj/")
|
||||||
|
if len(urls) < 2 {
|
||||||
|
return requestURL
|
||||||
|
}
|
||||||
|
requestURL = "/mj/" + urls[1]
|
||||||
|
}
|
||||||
|
return requestURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
// modelName = CoverActionToModelName(modelName)
|
||||||
|
|
||||||
|
return util.NewQuota(c, modelName, 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||||
|
var baseProvider providersBase.ProviderInterface
|
||||||
|
modelName := ""
|
||||||
|
if channel_id > 0 {
|
||||||
|
c.Set("specific_channel_id", channel_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request != nil {
|
||||||
|
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
|
||||||
|
if mjErr != nil {
|
||||||
|
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
|
||||||
|
}
|
||||||
|
if midjourneyModel == "" {
|
||||||
|
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName = midjourneyModel
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
baseProvider, _, err = relay.GetProvider(c, modelName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
mjProvider, ok := baseProvider.(*provider.MidjourneyProvider)
|
||||||
|
if !ok {
|
||||||
|
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法获取midjourney provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mjProvider, nil
|
||||||
|
}
|
95
relay/midjourney/relay.go
Normal file
95
relay/midjourney/relay.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: controller/relay.go
|
||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
provider "one-api/providers/midjourney"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RelayMidjourney(c *gin.Context) {
|
||||||
|
relayMode := Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||||
|
var err *provider.MidjourneyResponse
|
||||||
|
switch relayMode {
|
||||||
|
case provider.RelayModeMidjourneyNotify:
|
||||||
|
err = RelayMidjourneyNotify(c)
|
||||||
|
case provider.RelayModeMidjourneyTaskFetch, provider.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
|
err = RelayMidjourneyTask(c, relayMode)
|
||||||
|
case provider.RelayModeMidjourneyTaskImageSeed:
|
||||||
|
err = RelayMidjourneyTaskImageSeed(c)
|
||||||
|
case provider.RelayModeMidjourneySwapFace:
|
||||||
|
err = RelaySwapFace(c)
|
||||||
|
default:
|
||||||
|
err = RelayMidjourneySubmit(c, relayMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
statusCode := http.StatusBadRequest
|
||||||
|
if err.Code == 30 {
|
||||||
|
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
|
statusCode = http.StatusTooManyRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
typeMsg := "upstream_error"
|
||||||
|
if err.Type != "" {
|
||||||
|
typeMsg = err.Type
|
||||||
|
}
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||||
|
"type": typeMsg,
|
||||||
|
"code": err.Code,
|
||||||
|
})
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MidjourneyErrorFromInternal(code int, description string) *provider.MidjourneyResponse {
|
||||||
|
return &provider.MidjourneyResponse{
|
||||||
|
Code: code,
|
||||||
|
Description: description,
|
||||||
|
Type: "internal_error",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Path2RelayModeMidjourney(path string) int {
|
||||||
|
relayMode := provider.RelayModeUnknown
|
||||||
|
if strings.HasSuffix(path, "/mj/submit/action") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = provider.RelayModeMidjourneyAction
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/modal") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = provider.RelayModeMidjourneyModal
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/shorten") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = provider.RelayModeMidjourneyShorten
|
||||||
|
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = provider.RelayModeMidjourneySwapFace
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyImagine
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyBlend
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/describe") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyDescribe
|
||||||
|
} else if strings.HasSuffix(path, "/mj/notify") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyNotify
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/change") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyChange
|
||||||
|
} else if strings.HasSuffix(path, "/mj/submit/simple-change") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyChange
|
||||||
|
} else if strings.HasSuffix(path, "/fetch") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyTaskFetch
|
||||||
|
} else if strings.HasSuffix(path, "/image-seed") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyTaskImageSeed
|
||||||
|
} else if strings.HasSuffix(path, "/list-by-condition") {
|
||||||
|
relayMode = provider.RelayModeMidjourneyTaskFetchByCondition
|
||||||
|
}
|
||||||
|
return relayMode
|
||||||
|
}
|
148
relay/midjourney/service.go
Normal file
148
relay/midjourney/service.go
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: service/midjourney.go
|
||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
mjProvider "one-api/providers/midjourney"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CoverActionToModelName(mjAction string) string {
|
||||||
|
modelName := "mj_" + strings.ToLower(mjAction)
|
||||||
|
if mjAction == mjProvider.MjActionSwapFace {
|
||||||
|
modelName = "swap_face"
|
||||||
|
}
|
||||||
|
return modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMjRequestModel(relayMode int, midjRequest *mjProvider.MidjourneyRequest) (string, *mjProvider.MidjourneyResponse, bool) {
|
||||||
|
action := ""
|
||||||
|
if relayMode == mjProvider.RelayModeMidjourneyAction {
|
||||||
|
// plus request
|
||||||
|
err := CoverPlusActionToNormalAction(midjRequest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err, false
|
||||||
|
}
|
||||||
|
action = midjRequest.Action
|
||||||
|
} else {
|
||||||
|
switch relayMode {
|
||||||
|
case mjProvider.RelayModeMidjourneyImagine:
|
||||||
|
action = mjProvider.MjActionImagine
|
||||||
|
case mjProvider.RelayModeMidjourneyDescribe:
|
||||||
|
action = mjProvider.MjActionDescribe
|
||||||
|
case mjProvider.RelayModeMidjourneyBlend:
|
||||||
|
action = mjProvider.MjActionBlend
|
||||||
|
case mjProvider.RelayModeMidjourneyShorten:
|
||||||
|
action = mjProvider.MjActionShorten
|
||||||
|
case mjProvider.RelayModeMidjourneyChange:
|
||||||
|
action = midjRequest.Action
|
||||||
|
case mjProvider.RelayModeMidjourneyModal:
|
||||||
|
action = mjProvider.MjActionModal
|
||||||
|
case mjProvider.RelayModeMidjourneySwapFace:
|
||||||
|
action = mjProvider.MjActionSwapFace
|
||||||
|
case mjProvider.RelayModeMidjourneySimpleChange:
|
||||||
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
|
if params == nil {
|
||||||
|
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "invalid_request"), false
|
||||||
|
}
|
||||||
|
action = params.Action
|
||||||
|
case mjProvider.RelayModeMidjourneyTaskFetch, mjProvider.RelayModeMidjourneyTaskFetchByCondition, mjProvider.RelayModeMidjourneyNotify:
|
||||||
|
return "", nil, true
|
||||||
|
default:
|
||||||
|
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_relay_action"), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := CoverActionToModelName(action)
|
||||||
|
return modelName, nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CoverPlusActionToNormalAction(midjRequest *mjProvider.MidjourneyRequest) *mjProvider.MidjourneyResponse {
|
||||||
|
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
||||||
|
customId := midjRequest.CustomId
|
||||||
|
if customId == "" {
|
||||||
|
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "custom_id_is_required")
|
||||||
|
}
|
||||||
|
splits := strings.Split(customId, "::")
|
||||||
|
var action string
|
||||||
|
if splits[1] == "JOB" {
|
||||||
|
action = splits[2]
|
||||||
|
} else {
|
||||||
|
action = splits[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if action == "" {
|
||||||
|
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action")
|
||||||
|
}
|
||||||
|
if strings.Contains(action, "upsample") {
|
||||||
|
index, err := strconv.Atoi(splits[3])
|
||||||
|
if err != nil {
|
||||||
|
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
|
||||||
|
}
|
||||||
|
midjRequest.Index = index
|
||||||
|
midjRequest.Action = mjProvider.MjActionUpscale
|
||||||
|
} else if strings.Contains(action, "variation") {
|
||||||
|
midjRequest.Index = 1
|
||||||
|
if action == "variation" {
|
||||||
|
index, err := strconv.Atoi(splits[3])
|
||||||
|
if err != nil {
|
||||||
|
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
|
||||||
|
}
|
||||||
|
midjRequest.Index = index
|
||||||
|
midjRequest.Action = mjProvider.MjActionVariation
|
||||||
|
} else if action == "low_variation" {
|
||||||
|
midjRequest.Action = mjProvider.MjActionLowVariation
|
||||||
|
} else if action == "high_variation" {
|
||||||
|
midjRequest.Action = mjProvider.MjActionHighVariation
|
||||||
|
}
|
||||||
|
} else if strings.Contains(action, "pan") {
|
||||||
|
midjRequest.Action = mjProvider.MjActionPan
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if strings.Contains(action, "reroll") {
|
||||||
|
midjRequest.Action = mjProvider.MjActionReRoll
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "Outpaint" {
|
||||||
|
midjRequest.Action = mjProvider.MjActionZoom
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "CustomZoom" {
|
||||||
|
midjRequest.Action = mjProvider.MjActionCustomZoom
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "Inpaint" {
|
||||||
|
midjRequest.Action = mjProvider.MjActionInPaint
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else {
|
||||||
|
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action:"+customId)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertSimpleChangeParams(content string) *mjProvider.MidjourneyRequest {
|
||||||
|
split := strings.Split(content, " ")
|
||||||
|
if len(split) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
action := strings.ToLower(split[1])
|
||||||
|
changeParams := &mjProvider.MidjourneyRequest{}
|
||||||
|
changeParams.TaskId = split[0]
|
||||||
|
|
||||||
|
if action[0] == 'u' {
|
||||||
|
changeParams.Action = "UPSCALE"
|
||||||
|
} else if action[0] == 'v' {
|
||||||
|
changeParams.Action = "VARIATION"
|
||||||
|
} else if action == "r" {
|
||||||
|
changeParams.Action = "REROLL"
|
||||||
|
return changeParams
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
index, err := strconv.Atoi(action[1:2])
|
||||||
|
if err != nil || index < 1 || index > 4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
changeParams.Index = index
|
||||||
|
return changeParams
|
||||||
|
}
|
@ -170,3 +170,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
|
|||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *Quota) GetInputRatio() float64 {
|
||||||
|
return q.inputRatio
|
||||||
|
}
|
||||||
|
@ -7,22 +7,23 @@ var ModelOwnedBy map[int]string
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
ModelOwnedBy = map[int]string{
|
ModelOwnedBy = map[int]string{
|
||||||
common.ChannelTypeOpenAI: "OpenAI",
|
common.ChannelTypeOpenAI: "OpenAI",
|
||||||
common.ChannelTypeAnthropic: "Anthropic",
|
common.ChannelTypeAnthropic: "Anthropic",
|
||||||
common.ChannelTypeBaidu: "Baidu",
|
common.ChannelTypeBaidu: "Baidu",
|
||||||
common.ChannelTypePaLM: "Google PaLM",
|
common.ChannelTypePaLM: "Google PaLM",
|
||||||
common.ChannelTypeGemini: "Google Gemini",
|
common.ChannelTypeGemini: "Google Gemini",
|
||||||
common.ChannelTypeZhipu: "Zhipu",
|
common.ChannelTypeZhipu: "Zhipu",
|
||||||
common.ChannelTypeAli: "Ali",
|
common.ChannelTypeAli: "Ali",
|
||||||
common.ChannelTypeXunfei: "Xunfei",
|
common.ChannelTypeXunfei: "Xunfei",
|
||||||
common.ChannelType360: "360",
|
common.ChannelType360: "360",
|
||||||
common.ChannelTypeTencent: "Tencent",
|
common.ChannelTypeTencent: "Tencent",
|
||||||
common.ChannelTypeBaichuan: "Baichuan",
|
common.ChannelTypeBaichuan: "Baichuan",
|
||||||
common.ChannelTypeMiniMax: "MiniMax",
|
common.ChannelTypeMiniMax: "MiniMax",
|
||||||
common.ChannelTypeDeepseek: "Deepseek",
|
common.ChannelTypeDeepseek: "Deepseek",
|
||||||
common.ChannelTypeMoonshot: "Moonshot",
|
common.ChannelTypeMoonshot: "Moonshot",
|
||||||
common.ChannelTypeMistral: "Mistral",
|
common.ChannelTypeMistral: "Mistral",
|
||||||
common.ChannelTypeGroq: "Groq",
|
common.ChannelTypeGroq: "Groq",
|
||||||
common.ChannelTypeLingyi: "Lingyiwanwu",
|
common.ChannelTypeLingyi: "Lingyiwanwu",
|
||||||
|
common.ChannelTypeMidjourney: "Midjourney",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -145,6 +145,9 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mjRoute := apiRouter.Group("/mj")
|
||||||
|
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
|
||||||
|
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-contrib/gzip"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/gzip"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetDashboardRouter(router *gin.Engine) {
|
func SetDashboardRouter(router *gin.Engine) {
|
||||||
@ -12,7 +13,7 @@ func SetDashboardRouter(router *gin.Engine) {
|
|||||||
apiRouter := router.Group("/")
|
apiRouter := router.Group("/")
|
||||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||||
apiRouter.Use(middleware.TokenAuth())
|
apiRouter.Use(middleware.OpenaiAuth())
|
||||||
{
|
{
|
||||||
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
|
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
|
||||||
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
|
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
|
"one-api/relay/midjourney"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@ -11,14 +12,19 @@ import (
|
|||||||
func SetRelayRouter(router *gin.Engine) {
|
func SetRelayRouter(router *gin.Engine) {
|
||||||
router.Use(middleware.CORS())
|
router.Use(middleware.CORS())
|
||||||
// https://platform.openai.com/docs/api-reference/introduction
|
// https://platform.openai.com/docs/api-reference/introduction
|
||||||
|
setOpenAIRouter(router)
|
||||||
|
setMJRouter(router)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setOpenAIRouter(router *gin.Engine) {
|
||||||
modelsRouter := router.Group("/v1/models")
|
modelsRouter := router.Group("/v1/models")
|
||||||
modelsRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
modelsRouter.Use(middleware.OpenaiAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
modelsRouter.GET("", relay.ListModels)
|
modelsRouter.GET("", relay.ListModels)
|
||||||
modelsRouter.GET("/:model", relay.RetrieveModel)
|
modelsRouter.GET("/:model", relay.RetrieveModel)
|
||||||
}
|
}
|
||||||
relayV1Router := router.Group("/v1")
|
relayV1Router := router.Group("/v1")
|
||||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.OpenaiAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
relayV1Router.POST("/completions", relay.Relay)
|
relayV1Router.POST("/completions", relay.Relay)
|
||||||
relayV1Router.POST("/chat/completions", relay.Relay)
|
relayV1Router.POST("/chat/completions", relay.Relay)
|
||||||
@ -71,3 +77,34 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
|
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setMJRouter(router *gin.Engine) {
|
||||||
|
relayMjRouter := router.Group("/mj")
|
||||||
|
registerMjRouterGroup(relayMjRouter)
|
||||||
|
|
||||||
|
relayMjModeRouter := router.Group("/:mode/mj")
|
||||||
|
registerMjRouterGroup(relayMjModeRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Author: Calcium-Ion
|
||||||
|
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||||
|
// Path: router/relay-router.go
|
||||||
|
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||||
|
relayMjRouter.GET("/image/:id", midjourney.RelayMidjourneyImage)
|
||||||
|
relayMjRouter.Use(middleware.MjAuth(), middleware.Distribute())
|
||||||
|
{
|
||||||
|
relayMjRouter.POST("/submit/action", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/shorten", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/modal", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/imagine", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/change", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/simple-change", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/describe", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/blend", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/notify", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.GET("/task/:id/fetch", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.GET("/task/:id/image-seed", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/task/list-by-condition", midjourney.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/insight-face/swap", midjourney.RelayMidjourney)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
使用了以下开源项目作为我们项目的一部分:
|
使用了以下开源项目作为我们项目的一部分:
|
||||||
|
|
||||||
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)
|
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)
|
||||||
- [minimal-ui-kit](minimal-ui-kit)
|
- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react)
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
|
@ -132,6 +132,13 @@ export const CHANNEL_OPTIONS = {
|
|||||||
color: 'primary',
|
color: 'primary',
|
||||||
url: 'https://platform.lingyiwanwu.com/details'
|
url: 'https://platform.lingyiwanwu.com/details'
|
||||||
},
|
},
|
||||||
|
34: {
|
||||||
|
key: 34,
|
||||||
|
text: 'Midjourney',
|
||||||
|
value: 34,
|
||||||
|
color: 'orange',
|
||||||
|
url: ''
|
||||||
|
},
|
||||||
24: {
|
24: {
|
||||||
key: 24,
|
key: 24,
|
||||||
text: 'Azure Speech',
|
text: 'Azure Speech',
|
||||||
|
@ -11,7 +11,8 @@ import {
|
|||||||
IconUserScan,
|
IconUserScan,
|
||||||
IconActivity,
|
IconActivity,
|
||||||
IconBrandTelegram,
|
IconBrandTelegram,
|
||||||
IconReceipt2
|
IconReceipt2,
|
||||||
|
IconBrush
|
||||||
} from '@tabler/icons-react';
|
} from '@tabler/icons-react';
|
||||||
|
|
||||||
// constant
|
// constant
|
||||||
@ -27,7 +28,8 @@ const icons = {
|
|||||||
IconUserScan,
|
IconUserScan,
|
||||||
IconActivity,
|
IconActivity,
|
||||||
IconBrandTelegram,
|
IconBrandTelegram,
|
||||||
IconReceipt2
|
IconReceipt2,
|
||||||
|
IconBrush
|
||||||
};
|
};
|
||||||
|
|
||||||
// ==============================|| DASHBOARD MENU ITEMS ||============================== //
|
// ==============================|| DASHBOARD MENU ITEMS ||============================== //
|
||||||
@ -96,6 +98,14 @@ const panel = {
|
|||||||
icon: icons.IconGardenCart,
|
icon: icons.IconGardenCart,
|
||||||
breadcrumbs: false
|
breadcrumbs: false
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: 'midjourney',
|
||||||
|
title: 'Midjourney',
|
||||||
|
type: 'item',
|
||||||
|
url: '/panel/midjourney',
|
||||||
|
icon: icons.IconBrush,
|
||||||
|
breadcrumbs: false
|
||||||
|
},
|
||||||
{
|
{
|
||||||
id: 'user',
|
id: 'user',
|
||||||
title: '用户',
|
title: '用户',
|
||||||
|
@ -16,6 +16,7 @@ const NotFoundView = Loadable(lazy(() => import('views/Error')));
|
|||||||
const Analytics = Loadable(lazy(() => import('views/Analytics')));
|
const Analytics = Loadable(lazy(() => import('views/Analytics')));
|
||||||
const Telegram = Loadable(lazy(() => import('views/Telegram')));
|
const Telegram = Loadable(lazy(() => import('views/Telegram')));
|
||||||
const Pricing = Loadable(lazy(() => import('views/Pricing')));
|
const Pricing = Loadable(lazy(() => import('views/Pricing')));
|
||||||
|
const Midjourney = Loadable(lazy(() => import('views/Midjourney')));
|
||||||
|
|
||||||
// dashboard routing
|
// dashboard routing
|
||||||
const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
|
const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
|
||||||
@ -81,6 +82,10 @@ const MainRoutes = {
|
|||||||
{
|
{
|
||||||
path: 'pricing',
|
path: 'pricing',
|
||||||
element: <Pricing />
|
element: <Pricing />
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: 'midjourney',
|
||||||
|
element: <Midjourney />
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
};
|
};
|
||||||
|
@ -12,15 +12,7 @@ export default function componentStyleOverrides(theme) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
MuiMenuItem: {
|
//MuiAutocomplete-popper MuiPopover-root
|
||||||
styleOverrides: {
|
|
||||||
root: {
|
|
||||||
'&:hover': {
|
|
||||||
backgroundColor: theme.colors?.grey100
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, //MuiAutocomplete-popper MuiPopover-root
|
|
||||||
MuiAutocomplete: {
|
MuiAutocomplete: {
|
||||||
styleOverrides: {
|
styleOverrides: {
|
||||||
popper: {
|
popper: {
|
||||||
@ -247,7 +239,7 @@ export default function componentStyleOverrides(theme) {
|
|||||||
MuiTooltip: {
|
MuiTooltip: {
|
||||||
styleOverrides: {
|
styleOverrides: {
|
||||||
tooltip: {
|
tooltip: {
|
||||||
color: theme.paper,
|
color: theme.colors.paper,
|
||||||
background: theme.colors?.grey700
|
background: theme.colors?.grey700
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -266,6 +258,9 @@ export default function componentStyleOverrides(theme) {
|
|||||||
.apexcharts-menu {
|
.apexcharts-menu {
|
||||||
background: ${theme.backgroundDefault} !important
|
background: ${theme.backgroundDefault} !important
|
||||||
}
|
}
|
||||||
|
.apexcharts-gridline, .apexcharts-xaxistooltip-background, .apexcharts-yaxistooltip-background {
|
||||||
|
stroke: ${theme.divider} !important;
|
||||||
|
}
|
||||||
`
|
`
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -19,14 +19,14 @@ const Footer = () => {
|
|||||||
{siteInfo.system_name} {process.env.REACT_APP_VERSION}{' '}
|
{siteInfo.system_name} {process.env.REACT_APP_VERSION}{' '}
|
||||||
</Link>
|
</Link>
|
||||||
由{' '}
|
由{' '}
|
||||||
<Link href="https://github.com/songquanpeng" target="_blank">
|
|
||||||
JustSong
|
|
||||||
</Link>{' '}
|
|
||||||
构建,
|
|
||||||
<Link href="https://github.com/MartialBE" target="_blank">
|
<Link href="https://github.com/MartialBE" target="_blank">
|
||||||
MartialBE
|
MartialBE
|
||||||
</Link>
|
</Link>
|
||||||
修改,源代码遵循
|
开发,基于
|
||||||
|
<Link href="https://github.com/songquanpeng" target="_blank">
|
||||||
|
JustSong
|
||||||
|
</Link>{' '}
|
||||||
|
One API,源代码遵循
|
||||||
<Link href="https://opensource.org/licenses/mit-license.php"> MIT 协议</Link>
|
<Link href="https://opensource.org/licenses/mit-license.php"> MIT 协议</Link>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
@ -234,6 +234,33 @@ const typeConfig = {
|
|||||||
test_model: 'yi-34b-chat-0205'
|
test_model: 'yi-34b-chat-0205'
|
||||||
},
|
},
|
||||||
modelGroup: 'Lingyiwanwu'
|
modelGroup: 'Lingyiwanwu'
|
||||||
|
},
|
||||||
|
34: {
|
||||||
|
input: {
|
||||||
|
models: [
|
||||||
|
'mj_imagine',
|
||||||
|
'mj_variation',
|
||||||
|
'mj_reroll',
|
||||||
|
'mj_blend',
|
||||||
|
'mj_modal',
|
||||||
|
'mj_zoom',
|
||||||
|
'mj_shorten',
|
||||||
|
'mj_high_variation',
|
||||||
|
'mj_low_variation',
|
||||||
|
'mj_pan',
|
||||||
|
'mj_inpaint',
|
||||||
|
'mj_custom_zoom',
|
||||||
|
'mj_describe',
|
||||||
|
'mj_upscale',
|
||||||
|
'swap_face'
|
||||||
|
]
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
key: '密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填',
|
||||||
|
base_url: '地址填写midjourney-proxy部署的地址',
|
||||||
|
test_model: ''
|
||||||
|
},
|
||||||
|
modelGroup: 'Midjourney'
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
174
web/src/views/Midjourney/component/TableRow.js
Normal file
174
web/src/views/Midjourney/component/TableRow.js
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
import PropTypes from 'prop-types';
|
||||||
|
|
||||||
|
import { useState } from 'react';
|
||||||
|
import {
|
||||||
|
TableRow,
|
||||||
|
TableCell,
|
||||||
|
Button,
|
||||||
|
Dialog,
|
||||||
|
DialogActions,
|
||||||
|
DialogContent,
|
||||||
|
ButtonGroup,
|
||||||
|
Popover,
|
||||||
|
MenuItem,
|
||||||
|
MenuList,
|
||||||
|
Tooltip
|
||||||
|
} from '@mui/material';
|
||||||
|
|
||||||
|
import { timestamp2string, copy } from 'utils/common';
|
||||||
|
import Label from 'ui-component/Label';
|
||||||
|
import { ACTION_TYPE, CODE_TYPE, STATUS_TYPE } from '../type/Type';
|
||||||
|
import { IconCaretDownFilled, IconCopy, IconDownload, IconExternalLink } from '@tabler/icons-react';
|
||||||
|
|
||||||
|
function renderType(types, type) {
|
||||||
|
const typeOption = types[type];
|
||||||
|
if (typeOption) {
|
||||||
|
return (
|
||||||
|
<Label variant="filled" color={typeOption.color}>
|
||||||
|
{' '}
|
||||||
|
{typeOption.text}{' '}
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<Label variant="filled" color="error">
|
||||||
|
{' '}
|
||||||
|
未知{' '}
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async function downloadImage(url, filename) {
|
||||||
|
const response = await fetch(url);
|
||||||
|
const blob = await response.blob();
|
||||||
|
const blobUrl = URL.createObjectURL(blob);
|
||||||
|
const link = document.createElement('a');
|
||||||
|
link.href = blobUrl;
|
||||||
|
link.download = filename;
|
||||||
|
link.click();
|
||||||
|
URL.revokeObjectURL(blobUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
function TruncatedText(text) {
|
||||||
|
const truncatedText = text.length > 30 ? text.substring(0, 100) + '...' : text;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
placement="top"
|
||||||
|
title={text}
|
||||||
|
onClick={() => {
|
||||||
|
copy(text, '');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<span>{truncatedText}</span>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function LogTableRow({ item, userIsAdmin }) {
|
||||||
|
const [open, setOpen] = useState(false);
|
||||||
|
const [menuOpen, setMenuOpen] = useState(null);
|
||||||
|
const handleClickOpen = () => {
|
||||||
|
setOpen(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleClose = () => {
|
||||||
|
setOpen(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleOpenMenu = (event) => {
|
||||||
|
setMenuOpen(event.currentTarget);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCloseMenu = () => {
|
||||||
|
setMenuOpen(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<TableRow tabIndex={item.id}>
|
||||||
|
<TableCell>{item.mj_id}</TableCell>
|
||||||
|
<TableCell>{timestamp2string(item.submit_time / 1000)}</TableCell>
|
||||||
|
|
||||||
|
{userIsAdmin && <TableCell>{item.channel_id || ''}</TableCell>}
|
||||||
|
{userIsAdmin && <TableCell>{item.user_id || ''}</TableCell>}
|
||||||
|
|
||||||
|
<TableCell>{renderType(ACTION_TYPE, item.action)}</TableCell>
|
||||||
|
{userIsAdmin && <TableCell>{renderType(CODE_TYPE, item.code)}</TableCell>}
|
||||||
|
{userIsAdmin && <TableCell>{renderType(STATUS_TYPE, item.status)}</TableCell>}
|
||||||
|
<TableCell>{item.progress}</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
{item.image_url == '' ? (
|
||||||
|
'无'
|
||||||
|
) : (
|
||||||
|
<ButtonGroup size="small" aria-label="split button">
|
||||||
|
<Button color="primary" onClick={handleClickOpen}>
|
||||||
|
显示
|
||||||
|
</Button>
|
||||||
|
<Button onClick={handleOpenMenu}>
|
||||||
|
<IconCaretDownFilled size={'16px'} />
|
||||||
|
</Button>
|
||||||
|
</ButtonGroup>
|
||||||
|
)}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>{TruncatedText(item.prompt)}</TableCell>
|
||||||
|
<TableCell>{TruncatedText(item.prompt_en)}</TableCell>
|
||||||
|
<TableCell>{TruncatedText(item.fail_reason)}</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
<Dialog open={open} onClose={handleClose}>
|
||||||
|
<DialogContent>
|
||||||
|
<img src={item.image_url} alt="item" style={{ maxWidth: '100%', maxHeight: '100%' }} />
|
||||||
|
</DialogContent>
|
||||||
|
<DialogActions>
|
||||||
|
<Button onClick={handleClose} color="primary">
|
||||||
|
关闭
|
||||||
|
</Button>
|
||||||
|
</DialogActions>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
|
<Popover
|
||||||
|
open={!!menuOpen}
|
||||||
|
anchorEl={menuOpen}
|
||||||
|
onClose={handleCloseMenu}
|
||||||
|
anchorOrigin={{ vertical: 'top', horizontal: 'left' }}
|
||||||
|
transformOrigin={{ vertical: 'top', horizontal: 'right' }}
|
||||||
|
PaperProps={{
|
||||||
|
sx: { width: 140 }
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<MenuList>
|
||||||
|
<MenuItem
|
||||||
|
onClick={() => {
|
||||||
|
handleCloseMenu();
|
||||||
|
copy(item.image_url, '图片地址');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IconCopy style={{ marginRight: '16px' }} />
|
||||||
|
复制地址
|
||||||
|
</MenuItem>
|
||||||
|
|
||||||
|
<MenuItem
|
||||||
|
onClick={async () => {
|
||||||
|
handleCloseMenu();
|
||||||
|
await downloadImage(item.image_url, item.mj_id + '.png');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IconDownload style={{ marginRight: '16px' }} /> 下载图片{' '}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem
|
||||||
|
onClick={() => {
|
||||||
|
handleCloseMenu();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IconExternalLink style={{ marginRight: '16px' }} /> 新窗口打开{' '}
|
||||||
|
</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
</Popover>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogTableRow.propTypes = {
|
||||||
|
item: PropTypes.object,
|
||||||
|
userIsAdmin: PropTypes.bool
|
||||||
|
};
|
113
web/src/views/Midjourney/component/TableToolBar.js
Normal file
113
web/src/views/Midjourney/component/TableToolBar.js
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import PropTypes from 'prop-types';
|
||||||
|
import { useTheme } from '@mui/material/styles';
|
||||||
|
import { IconBroadcast, IconCalendarEvent } from '@tabler/icons-react';
|
||||||
|
import { InputAdornment, OutlinedInput, Stack, FormControl, InputLabel } from '@mui/material';
|
||||||
|
import { LocalizationProvider, DateTimePicker } from '@mui/x-date-pickers';
|
||||||
|
import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs';
|
||||||
|
import dayjs from 'dayjs';
|
||||||
|
require('dayjs/locale/zh-cn');
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
export default function TableToolBar({ filterName, handleFilterName, userIsAdmin }) {
|
||||||
|
const theme = useTheme();
|
||||||
|
const grey500 = theme.palette.grey[500];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Stack direction={{ xs: 'column', sm: 'row' }} spacing={{ xs: 3, sm: 2, md: 4 }} padding={'24px'} paddingBottom={'0px'}>
|
||||||
|
{userIsAdmin && (
|
||||||
|
<FormControl>
|
||||||
|
<InputLabel htmlFor="channel-channel_id-label">渠道ID</InputLabel>
|
||||||
|
<OutlinedInput
|
||||||
|
id="channel_id"
|
||||||
|
name="channel_id"
|
||||||
|
sx={{
|
||||||
|
minWidth: '100%'
|
||||||
|
}}
|
||||||
|
label="渠道ID"
|
||||||
|
value={filterName.channel_id}
|
||||||
|
onChange={handleFilterName}
|
||||||
|
placeholder="渠道ID"
|
||||||
|
startAdornment={
|
||||||
|
<InputAdornment position="start">
|
||||||
|
<IconBroadcast stroke={1.5} size="20px" color={grey500} />
|
||||||
|
</InputAdornment>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
<FormControl>
|
||||||
|
<InputLabel htmlFor="channel-mj_id-label">任务ID</InputLabel>
|
||||||
|
<OutlinedInput
|
||||||
|
id="mj_id"
|
||||||
|
name="mj_id"
|
||||||
|
sx={{
|
||||||
|
minWidth: '100%'
|
||||||
|
}}
|
||||||
|
label="任务ID"
|
||||||
|
value={filterName.mj_id}
|
||||||
|
onChange={handleFilterName}
|
||||||
|
placeholder="任务ID"
|
||||||
|
startAdornment={
|
||||||
|
<InputAdornment position="start">
|
||||||
|
<IconCalendarEvent stroke={1.5} size="20px" color={grey500} />
|
||||||
|
</InputAdornment>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
<FormControl>
|
||||||
|
<LocalizationProvider dateAdapter={AdapterDayjs} adapterLocale={'zh-cn'}>
|
||||||
|
<DateTimePicker
|
||||||
|
label="起始时间"
|
||||||
|
ampm={false}
|
||||||
|
name="start_timestamp"
|
||||||
|
value={filterName.start_timestamp === 0 ? null : dayjs.unix(filterName.start_timestamp / 1000)}
|
||||||
|
onChange={(value) => {
|
||||||
|
if (value === null) {
|
||||||
|
handleFilterName({ target: { name: 'start_timestamp', value: 0 } });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleFilterName({ target: { name: 'start_timestamp', value: value.unix() * 1000 } });
|
||||||
|
}}
|
||||||
|
slotProps={{
|
||||||
|
actionBar: {
|
||||||
|
actions: ['clear', 'today', 'accept']
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</LocalizationProvider>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
<FormControl>
|
||||||
|
<LocalizationProvider dateAdapter={AdapterDayjs} adapterLocale={'zh-cn'}>
|
||||||
|
<DateTimePicker
|
||||||
|
label="结束时间"
|
||||||
|
name="end_timestamp"
|
||||||
|
ampm={false}
|
||||||
|
value={filterName.end_timestamp === 0 ? null : dayjs.unix(filterName.end_timestamp / 1000)}
|
||||||
|
onChange={(value) => {
|
||||||
|
if (value === null) {
|
||||||
|
handleFilterName({ target: { name: 'end_timestamp', value: 0 } });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleFilterName({ target: { name: 'end_timestamp', value: value.unix() * 1000 } });
|
||||||
|
}}
|
||||||
|
slotProps={{
|
||||||
|
actionBar: {
|
||||||
|
actions: ['clear', 'today', 'accept']
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</LocalizationProvider>
|
||||||
|
</FormControl>
|
||||||
|
</Stack>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
TableToolBar.propTypes = {
|
||||||
|
filterName: PropTypes.object,
|
||||||
|
handleFilterName: PropTypes.func,
|
||||||
|
userIsAdmin: PropTypes.bool
|
||||||
|
};
|
247
web/src/views/Midjourney/index.js
Normal file
247
web/src/views/Midjourney/index.js
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
import { useState, useEffect, useCallback } from 'react';
|
||||||
|
import { showError } from 'utils/common';
|
||||||
|
|
||||||
|
import Table from '@mui/material/Table';
|
||||||
|
import TableBody from '@mui/material/TableBody';
|
||||||
|
import TableContainer from '@mui/material/TableContainer';
|
||||||
|
import PerfectScrollbar from 'react-perfect-scrollbar';
|
||||||
|
import TablePagination from '@mui/material/TablePagination';
|
||||||
|
import LinearProgress from '@mui/material/LinearProgress';
|
||||||
|
import ButtonGroup from '@mui/material/ButtonGroup';
|
||||||
|
import Toolbar from '@mui/material/Toolbar';
|
||||||
|
|
||||||
|
import { Button, Card, Stack, Container, Typography, Box } from '@mui/material';
|
||||||
|
import LogTableRow from './component/TableRow';
|
||||||
|
import KeywordTableHead from 'ui-component/TableHead';
|
||||||
|
import TableToolBar from './component/TableToolBar';
|
||||||
|
import { API } from 'utils/api';
|
||||||
|
import { isAdmin } from 'utils/common';
|
||||||
|
import { ITEMS_PER_PAGE } from 'constants';
|
||||||
|
import { IconRefresh, IconSearch } from '@tabler/icons-react';
|
||||||
|
import dayjs from 'dayjs';
|
||||||
|
|
||||||
|
export default function Log() {
|
||||||
|
const originalKeyword = {
|
||||||
|
p: 0,
|
||||||
|
channel_id: '',
|
||||||
|
mj_id: '',
|
||||||
|
start_timestamp: 0,
|
||||||
|
end_timestamp: dayjs().unix() * 1000 + 3600
|
||||||
|
};
|
||||||
|
|
||||||
|
const [page, setPage] = useState(0);
|
||||||
|
const [order, setOrder] = useState('desc');
|
||||||
|
const [orderBy, setOrderBy] = useState('id');
|
||||||
|
const [rowsPerPage, setRowsPerPage] = useState(ITEMS_PER_PAGE);
|
||||||
|
const [listCount, setListCount] = useState(0);
|
||||||
|
const [searching, setSearching] = useState(false);
|
||||||
|
const [toolBarValue, setToolBarValue] = useState(originalKeyword);
|
||||||
|
const [searchKeyword, setSearchKeyword] = useState(originalKeyword);
|
||||||
|
const [refreshFlag, setRefreshFlag] = useState(false);
|
||||||
|
|
||||||
|
const [logs, setLogs] = useState([]);
|
||||||
|
const userIsAdmin = isAdmin();
|
||||||
|
|
||||||
|
const handleSort = (event, id) => {
|
||||||
|
const isAsc = orderBy === id && order === 'asc';
|
||||||
|
if (id !== '') {
|
||||||
|
setOrder(isAsc ? 'desc' : 'asc');
|
||||||
|
setOrderBy(id);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleChangePage = (event, newPage) => {
|
||||||
|
setPage(newPage);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleChangeRowsPerPage = (event) => {
|
||||||
|
setPage(0);
|
||||||
|
setRowsPerPage(parseInt(event.target.value, 10));
|
||||||
|
};
|
||||||
|
|
||||||
|
const searchLogs = async () => {
|
||||||
|
setPage(0);
|
||||||
|
setSearchKeyword(toolBarValue);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleToolBarValue = (event) => {
|
||||||
|
setToolBarValue({ ...toolBarValue, [event.target.name]: event.target.value });
|
||||||
|
};
|
||||||
|
|
||||||
|
const fetchData = useCallback(
|
||||||
|
async (page, rowsPerPage, keyword, order, orderBy) => {
|
||||||
|
setSearching(true);
|
||||||
|
try {
|
||||||
|
if (orderBy) {
|
||||||
|
orderBy = order === 'desc' ? '-' + orderBy : orderBy;
|
||||||
|
}
|
||||||
|
const url = userIsAdmin ? '/api/mj/' : '/api/mj/self/';
|
||||||
|
if (!userIsAdmin) {
|
||||||
|
delete keyword.channel_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await API.get(url, {
|
||||||
|
params: {
|
||||||
|
page: page + 1,
|
||||||
|
size: rowsPerPage,
|
||||||
|
order: orderBy,
|
||||||
|
...keyword
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
setListCount(data.total_count);
|
||||||
|
setLogs(data.data);
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(error);
|
||||||
|
}
|
||||||
|
setSearching(false);
|
||||||
|
},
|
||||||
|
[userIsAdmin]
|
||||||
|
);
|
||||||
|
|
||||||
|
// 处理刷新
|
||||||
|
const handleRefresh = async () => {
|
||||||
|
setOrderBy('id');
|
||||||
|
setOrder('desc');
|
||||||
|
setToolBarValue(originalKeyword);
|
||||||
|
setSearchKeyword(originalKeyword);
|
||||||
|
setRefreshFlag(!refreshFlag);
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchData(page, rowsPerPage, searchKeyword, order, orderBy);
|
||||||
|
}, [page, rowsPerPage, searchKeyword, order, orderBy, fetchData, refreshFlag]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
|
||||||
|
<Typography variant="h4">Midjourney</Typography>
|
||||||
|
</Stack>
|
||||||
|
<Card>
|
||||||
|
<Box component="form" noValidate>
|
||||||
|
<TableToolBar filterName={toolBarValue} handleFilterName={handleToolBarValue} userIsAdmin={userIsAdmin} />
|
||||||
|
</Box>
|
||||||
|
<Toolbar
|
||||||
|
sx={{
|
||||||
|
textAlign: 'right',
|
||||||
|
height: 50,
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
p: (theme) => theme.spacing(0, 1, 0, 3)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Container>
|
||||||
|
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
|
||||||
|
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
|
||||||
|
刷新/清除搜索条件
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<Button onClick={searchLogs} startIcon={<IconSearch width={'18px'} />}>
|
||||||
|
搜索
|
||||||
|
</Button>
|
||||||
|
</ButtonGroup>
|
||||||
|
</Container>
|
||||||
|
</Toolbar>
|
||||||
|
{searching && <LinearProgress />}
|
||||||
|
<PerfectScrollbar component="div">
|
||||||
|
<TableContainer sx={{ overflow: 'unset' }}>
|
||||||
|
<Table sx={{ minWidth: 800 }}>
|
||||||
|
<KeywordTableHead
|
||||||
|
order={order}
|
||||||
|
orderBy={orderBy}
|
||||||
|
onRequestSort={handleSort}
|
||||||
|
headLabel={[
|
||||||
|
{
|
||||||
|
id: 'mj_id',
|
||||||
|
label: '任务ID',
|
||||||
|
disableSort: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'submit_time',
|
||||||
|
label: '提交时间',
|
||||||
|
disableSort: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'channel_id',
|
||||||
|
label: '渠道',
|
||||||
|
disableSort: false,
|
||||||
|
hide: !userIsAdmin
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'user_id',
|
||||||
|
label: '用户',
|
||||||
|
disableSort: false,
|
||||||
|
hide: !userIsAdmin
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'action',
|
||||||
|
label: '类型',
|
||||||
|
disableSort: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'code',
|
||||||
|
label: '提交结果',
|
||||||
|
disableSort: false,
|
||||||
|
hide: !userIsAdmin
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'status',
|
||||||
|
label: '任务状态',
|
||||||
|
disableSort: false,
|
||||||
|
hide: !userIsAdmin
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'progress',
|
||||||
|
label: '进度',
|
||||||
|
disableSort: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'image_url',
|
||||||
|
label: '结果图片',
|
||||||
|
disableSort: true,
|
||||||
|
width: '120px'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'prompt',
|
||||||
|
label: 'Prompt',
|
||||||
|
disableSort: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'prompt_en',
|
||||||
|
label: 'PromptEn',
|
||||||
|
disableSort: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'fail_reason',
|
||||||
|
label: '失败原因',
|
||||||
|
disableSort: true
|
||||||
|
}
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
<TableBody>
|
||||||
|
{logs.map((row, index) => (
|
||||||
|
<LogTableRow item={row} key={`${row.id}_${index}`} userIsAdmin={userIsAdmin} />
|
||||||
|
))}
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
</TableContainer>
|
||||||
|
</PerfectScrollbar>
|
||||||
|
<TablePagination
|
||||||
|
page={page}
|
||||||
|
component="div"
|
||||||
|
count={listCount}
|
||||||
|
rowsPerPage={rowsPerPage}
|
||||||
|
onPageChange={handleChangePage}
|
||||||
|
rowsPerPageOptions={[10, 25, 30]}
|
||||||
|
onRowsPerPageChange={handleChangeRowsPerPage}
|
||||||
|
showFirstButton
|
||||||
|
showLastButton
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
33
web/src/views/Midjourney/type/Type.js
Normal file
33
web/src/views/Midjourney/type/Type.js
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
export const ACTION_TYPE = {
|
||||||
|
IMAGINE: { value: 'IMAGINE', text: '绘图', color: 'primary' },
|
||||||
|
UPSCALE: { value: 'UPSCALE', text: '放大', color: 'orange' },
|
||||||
|
VARIATION: { value: 'VARIATION', text: '变换', color: 'default' },
|
||||||
|
HIGH_VARIATION: { value: 'HIGH_VARIATION', text: '强变换', color: 'default' },
|
||||||
|
LOW_VARIATION: { value: 'LOW_VARIATION', text: '弱变换', color: 'default' },
|
||||||
|
PAN: { value: 'PAN', text: '平移', color: 'secondary' },
|
||||||
|
DESCRIBE: { value: 'DESCRIBE', text: '图生文', color: 'secondary' },
|
||||||
|
BLEND: { value: 'BLEND', text: '图混合', color: 'secondary' },
|
||||||
|
SHORTEN: { value: 'SHORTEN', text: '缩词', color: 'secondary' },
|
||||||
|
REROLL: { value: 'REROLL', text: '重绘', color: 'secondary' },
|
||||||
|
INPAINT: { value: 'INPAINT', text: '局部重绘-提交', color: 'secondary' },
|
||||||
|
ZOOM: { value: 'ZOOM', text: '变焦', color: 'secondary' },
|
||||||
|
CUSTOM_ZOOM: { value: 'CUSTOM_ZOOM', text: '自定义变焦-提交', color: 'secondary' },
|
||||||
|
MODAL: { value: 'MODAL', text: '窗口处理', color: 'secondary' },
|
||||||
|
SWAP_FACE: { value: 'SWAP_FACE', text: '换脸', color: 'secondary' }
|
||||||
|
};
|
||||||
|
|
||||||
|
export const CODE_TYPE = {
|
||||||
|
1: { value: 1, text: '已提交', color: 'primary' },
|
||||||
|
21: { value: 21, text: '等待中', color: 'orange' },
|
||||||
|
22: { value: 22, text: '重复提交', color: 'default' },
|
||||||
|
0: { value: 0, text: '未提交', color: 'default' }
|
||||||
|
};
|
||||||
|
|
||||||
|
export const STATUS_TYPE = {
|
||||||
|
SUCCESS: { value: 'SUCCESS', text: '成功', color: 'success' },
|
||||||
|
NOT_START: { value: 'NOT_START', text: '未启动', color: 'default' },
|
||||||
|
SUBMITTED: { value: 'SUBMITTED', text: '队列中', color: 'secondary' },
|
||||||
|
IN_PROGRESS: { value: 'IN_PROGRESS', text: '执行中', color: 'primary' },
|
||||||
|
FAILURE: { value: 'FAILURE', text: '失败', color: 'orange' },
|
||||||
|
MODAL: { value: 'MODAL', text: '窗口等待', color: 'default' }
|
||||||
|
};
|
@ -29,7 +29,8 @@ const OperationSetting = () => {
|
|||||||
DisplayTokenStatEnabled: '',
|
DisplayTokenStatEnabled: '',
|
||||||
ApproximateTokenEnabled: '',
|
ApproximateTokenEnabled: '',
|
||||||
RetryTimes: 0,
|
RetryTimes: 0,
|
||||||
RetryCooldownSeconds: 0
|
RetryCooldownSeconds: 0,
|
||||||
|
MjNotifyEnabled: ''
|
||||||
});
|
});
|
||||||
const [originInputs, setOriginInputs] = useState({});
|
const [originInputs, setOriginInputs] = useState({});
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
@ -278,6 +279,22 @@ const OperationSetting = () => {
|
|||||||
</Button>
|
</Button>
|
||||||
</Stack>
|
</Stack>
|
||||||
</SubCard>
|
</SubCard>
|
||||||
|
<SubCard title="其他设置">
|
||||||
|
<Stack justifyContent="flex-start" alignItems="flex-start" spacing={2}>
|
||||||
|
<Stack
|
||||||
|
direction={{ sm: 'column', md: 'row' }}
|
||||||
|
spacing={{ xs: 3, sm: 2, md: 4 }}
|
||||||
|
justifyContent="flex-start"
|
||||||
|
alignItems="flex-start"
|
||||||
|
>
|
||||||
|
<FormControlLabel
|
||||||
|
sx={{ marginLeft: '0px' }}
|
||||||
|
label="Midjourney 允许回调(会泄露服务器ip地址)"
|
||||||
|
control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />}
|
||||||
|
/>
|
||||||
|
</Stack>
|
||||||
|
</Stack>
|
||||||
|
</SubCard>
|
||||||
<SubCard title="日志设置">
|
<SubCard title="日志设置">
|
||||||
<Stack direction="column" justifyContent="flex-start" alignItems="flex-start" spacing={2}>
|
<Stack direction="column" justifyContent="flex-start" alignItems="flex-start" spacing={2}>
|
||||||
<FormControlLabel
|
<FormControlLabel
|
||||||
|
Loading…
Reference in New Issue
Block a user