feat: add Midjourney (#138)

* 🚧 stash

*  feat: add Midjourney

* 📝 doc: update readme
This commit is contained in:
Buer 2024-04-05 04:03:46 +08:00 committed by GitHub
parent 87bfecf3e9
commit c1fc32add7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 2479 additions and 84 deletions

View File

@ -58,6 +58,16 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开
请查看[文档](https://github.com/MartialBE/one-api/wiki)
## 感谢
- 本程序使用了以下开源项目
- [one-api](https://github.com/songquanpeng/one-api)为本项目的基础
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)为本项目的前端界面
- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react),使用了其中的部分样式
- [new api](https://github.com/Calcium-Ion/new-api)Midjourney 模块的代码来源于此
感谢以上项目的作者和贡献者
## 其他
<a href="https://next.ossinsight.io/widgets/official/analyze-repo-stars-history?repo_id=689214770" target="_blank" style="display: block" align="center">

View File

@ -37,6 +37,9 @@ var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
// mj
var MjNotifyEnabled = false
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
@ -161,6 +164,7 @@ const (
ChannelTypeGroq = 31
ChannelTypeBedrock = 32
ChannelTypeLingyi = 33
ChannelTypeMidjourney = 34
)
var ChannelBaseURLs = []string{
@ -198,6 +202,7 @@ var ChannelBaseURLs = []string{
"https://api.groq.com/openai", //31
"", //32
"https://api.lingyiwanwu.com", //33
"", //34
}
const (

32
common/go-channel.go Normal file
View File

@ -0,0 +1,32 @@
package common
import (
"fmt"
"runtime/debug"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSend(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}

View File

@ -106,7 +106,10 @@ func logHelper(ctx context.Context, level string, msg string) {
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
id, ok := ctx.Value(RequestIdKey).(string)
if !ok {
id = "unknown"
}
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here

View File

@ -23,6 +23,7 @@ type HTTPRequester struct {
CreateFormBuilder func(io.Writer) FormBuilder
ErrorHandler HttpErrorHandler
proxyAddr string
Context context.Context
}
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
@ -37,6 +38,7 @@ func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequ
},
ErrorHandler: errorHandler,
proxyAddr: proxyAddr,
Context: context.Background(),
}
}
@ -47,18 +49,18 @@ type requestOptions struct {
type requestOption func(*requestOptions)
func (r *HTTPRequester) getContext() context.Context {
func (r *HTTPRequester) setProxy() context.Context {
if r.proxyAddr == "" {
return context.Background()
return r.Context
}
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
if strings.HasPrefix(r.proxyAddr, "socks5://") {
return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr)
return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr)
}
// 否则使用 http 代理
return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr)
return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr)
}
@ -71,7 +73,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
for _, setter := range setters {
setter(args)
}
req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header)
req, err := r.requestBuilder.Build(r.setProxy(), method, url, args.body, args.header)
if err != nil {
return nil, err
}

285
controller/midjourney.go Normal file
View File

@ -0,0 +1,285 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: controller/midjourney.go
package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
provider "one-api/providers/midjourney"
"time"
"github.com/gin-gonic/gin"
)
var activeMidjourneyTask = make(chan bool, 1)
func InitMidjourneyTask() {
common.SafeGoroutine(func() {
midjourneyTask()
})
ActivateUpdateMidjourneyTaskBulk()
}
func midjourneyTask() {
for {
select {
case <-activeMidjourneyTask:
UpdateMidjourneyTaskBulk()
}
}
}
func ActivateUpdateMidjourneyTaskBulk() {
if len(activeMidjourneyTask) == 0 {
activeMidjourneyTask <- true
}
}
func UpdateMidjourneyTaskBulk() {
ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask")
for {
common.LogInfo(ctx, "running")
tasks := model.GetAllUnFinishTasks()
// 如果没有未完成的任务,则等待
if len(tasks) == 0 {
for len(activeMidjourneyTask) > 0 {
<-activeMidjourneyTask
}
common.LogInfo(ctx, "no tasks, waiting...")
return
}
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
for _, task := range tasks {
if task.MjId == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.Id)
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(nullTaskIds) > 0 {
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel := model.ChannelGroup.GetChannel(channelId)
if midjourneyChannel == nil {
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := requester.HTTPClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []provider.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
// 如果时间超过一小时且进度不是100%,则认为任务失败
if useTime > 3600000 && task.Progress != "100%" {
responseItem.FailReason = "上游任务超时超过1小时"
responseItem.Status = "FAILURE"
}
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if responseItem.Properties != nil {
propertiesStr, _ := json.Marshal(responseItem.Properties)
task.Properties = string(propertiesStr)
}
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}
time.Sleep(time.Duration(15) * time.Second)
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask provider.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}
if oldTask.Progress != newTask.Progress {
return true
}
if oldTask.PromptEn != newTask.PromptEn {
return true
}
if oldTask.State != newTask.State {
return true
}
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.ImageUrl != newTask.ImageUrl {
return true
}
if oldTask.Status != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}
return false
}
func GetAllMidjourney(c *gin.Context) {
var params model.TaskQueryParams
if err := c.ShouldBindQuery(&params); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
midjourneys, err := model.GetAllTasks(&params)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": midjourneys,
})
}
func GetUserMidjourney(c *gin.Context) {
userId := c.GetInt("id")
var params model.TaskQueryParams
if err := c.ShouldBindQuery(&params); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
midjourneys, err := model.GetAllUserTask(userId, &params)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": midjourneys,
})
}

View File

@ -40,6 +40,7 @@ func GetStatus(c *gin.Context) {
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"telegram_bot": telegram_bot,
"mj_notify_enabled": common.MjNotifyEnabled,
},
})
}

View File

@ -45,6 +45,8 @@ func main() {
// Initialize Telegram bot
telegram.InitTelegramBot()
controller.InitMidjourneyTask()
initHttpServer()
}

View File

@ -83,43 +83,54 @@ func RootAuth() func(c *gin.Context) {
}
}
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
channelId := common.String2Int(parts[1])
if channelId == 0 {
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
return
}
c.Set("specific_channel_id", channelId)
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
func tokenAuth(c *gin.Context, key string) {
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
channelId := common.String2Int(parts[1])
if channelId == 0 {
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
return
}
c.Set("specific_channel_id", channelId)
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
c.Next()
}
c.Next()
}
func OpenaiAuth() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("Authorization")
tokenAuth(c, key)
}
}
func MjAuth() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("mj-api-secret")
tokenAuth(c, key)
}
}

View File

@ -114,6 +114,17 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
return models, nil
}
func (cc *ChannelsChooser) GetChannel(channelId int) *Channel {
cc.RLock()
defer cc.RUnlock()
if choice, ok := cc.Channels[channelId]; ok {
return choice.Channel
}
return nil
}
var ChannelGroup = ChannelsChooser{}
func (cc *ChannelsChooser) Load() {

View File

@ -139,6 +139,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err

182
model/midjourney.go Normal file
View File

@ -0,0 +1,182 @@
// Copyright (c) 2024 Calcium-Ion
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
package model
type Midjourney struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action" gorm:"type:varchar(40);index"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
type TaskQueryParams struct {
ChannelID int `form:"channel_id"`
MjID string `form:"mj_id"`
StartTimestamp int `form:"start_timestamp"`
EndTimestamp int `form:"end_timestamp"`
PaginationParams
}
var allowedMidjourneyOrderFields = map[string]bool{
"id": true,
"user_id": true,
"code": true,
"action": true,
"mj_id": true,
"submit_time": true,
"start_time": true,
"finish_time": true,
"status": true,
"channel_id": true,
}
func GetAllUserTask(userId int, params *TaskQueryParams) (*DataResult[Midjourney], error) {
var tasks []*Midjourney
// 初始化查询构建器
query := DB.Where("user_id = ?", userId)
if params.MjID != "" {
query = query.Where("mj_id = ?", params.MjID)
}
if params.StartTimestamp != 0 {
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
query = query.Where("submit_time >= ?", params.StartTimestamp)
}
if params.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", params.EndTimestamp)
}
return PaginateAndOrder(query, &params.PaginationParams, &tasks, allowedMidjourneyOrderFields)
}
func GetAllTasks(params *TaskQueryParams) (*DataResult[Midjourney], error) {
var tasks []*Midjourney
// 初始化查询构建器
query := DB
// 添加过滤条件
if params.ChannelID != 0 {
query = query.Where("channel_id = ?", params.ChannelID)
}
if params.MjID != "" {
query = query.Where("mj_id = ?", params.MjID)
}
if params.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", params.StartTimestamp)
}
if params.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", params.EndTimestamp)
}
return PaginateAndOrder(query, &params.PaginationParams, &tasks, allowedMidjourneyOrderFields)
}
func GetAllUnFinishTasks() []*Midjourney {
var tasks []*Midjourney
// get all tasks progress is not 100%
err := DB.Where("progress != ?", "100%").Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetByOnlyMJId(mjId string) *Midjourney {
var mj *Midjourney
err := DB.Where("mj_id = ?", mjId).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetByMJId(userId int, mjId string) *Midjourney {
var mj *Midjourney
err := DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
var mj []*Midjourney
err := DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetMjByuId(id int) *Midjourney {
var mj *Midjourney
err := DB.Where("id = ?", id).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func UpdateProgress(id int, progress string) error {
return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
}
func (midjourney *Midjourney) Insert() error {
return DB.Create(midjourney).Error
}
func (midjourney *Midjourney) Update() error {
return DB.Save(midjourney).Error
}
func MjBulkUpdate(mjIds []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", mjIds).
Updates(params).Error
}
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("id in (?)", taskIDs).
Updates(params).Error
}

View File

@ -74,6 +74,8 @@ func InitOptionMap() {
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@ -138,6 +140,7 @@ var optionBoolMap = map[string]*bool{
"LogConsumeEnabled": &common.LogConsumeEnabled,
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
"MjNotifyEnabled": &common.MjNotifyEnabled,
}
var optionStringMap = map[string]*string{

View File

@ -301,5 +301,33 @@ func GetDefaultPrice() []*Price {
})
}
var DefaultMJPrice = map[string]float64{
"mj_imagine": 50,
"mj_variation": 50,
"mj_reroll": 50,
"mj_blend": 50,
"mj_modal": 50,
"mj_zoom": 50,
"mj_shorten": 50,
"mj_high_variation": 50,
"mj_low_variation": 50,
"mj_pan": 50,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 25,
"mj_upscale": 25,
"swap_face": 25,
}
for model, mjPrice := range DefaultMJPrice {
prices = append(prices, &Price{
Model: model,
Type: TimesPriceType,
ChannelType: common.ChannelTypeMidjourney,
Input: mjPrice,
Output: mjPrice,
})
}
return prices
}

View File

@ -0,0 +1,121 @@
package midjourney
import (
"bytes"
"context"
"encoding/json"
"io"
"log"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"time"
)
// 定义供应商工厂
type MidjourneyProviderFactory struct{}
// 创建 MidjourneyProvider
func (f MidjourneyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &MidjourneyProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, nil),
},
}
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "",
}
}
type MidjourneyProvider struct {
base.BaseProvider
}
func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
var mapResult map[string]interface{}
if p.Context.Request.Method != "GET" {
err := json.NewDecoder(p.Context.Request.Body).Decode(&mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
if !common.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
}
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
fullRequestURL := p.GetFullRequestURL(requestURL, "")
var cancel context.CancelFunc
p.Requester.Context, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
headers := p.GetRequestHeaders()
defer cancel()
req, err := p.Requester.NewRequest(p.Context.Request.Method, fullRequestURL, p.Requester.WithBody(bytes.NewBuffer(reqBody)), p.Requester.WithHeader(headers))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
resp, errWith := p.Requester.SendRequestRaw(req)
if errWith != nil {
common.SysError("do request failed: " + errWith.Error())
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
err = req.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
err = p.Context.Request.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse MidjourneyResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
err = resp.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
return &MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: midjResponse,
}, responseBody, nil
}
func (p *MidjourneyProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["mj-api-secret"] = p.Channel.Key
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
return headers
}

View File

@ -0,0 +1,69 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: relay/constant/relay_mode.go
package midjourney
const (
RelayModeUnknown = iota
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeMidjourneySwapFace
)
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: constant/midjourney.go
const (
MjErrorUnknown = 5
MjRequestError = 4
)
const (
MjActionImagine = "IMAGINE"
MjActionDescribe = "DESCRIBE"
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
)
var MidjourneyModel2Action = map[string]string{
"mj_imagine": MjActionImagine,
"mj_describe": MjActionDescribe,
"mj_blend": MjActionBlend,
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
"mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
"mj_zoom": MjActionZoom,
"mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
}

View File

@ -0,0 +1,18 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: service/error.go
package midjourney
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *MidjourneyResponseWithStatusCode {
return &MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: *MidjourneyErrorWrapper(code, desc),
}
}
func MidjourneyErrorWrapper(code int, desc string) *MidjourneyResponse {
return &MidjourneyResponse{
Code: code,
Description: desc,
}
}

View File

@ -0,0 +1,92 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: dto/midjourney.go
package midjourney
type SwapFaceRequest struct {
SourceBase64 string `json:"sourceBase64"`
TargetBase64 string `json:"targetBase64"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
Type string `json:"type,omitempty"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse
}
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
Buttons any `json:"buttons"`
MaskBase64 string `json:"maskBase64"`
Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
type ActionButton struct {
CustomId any `json:"customId"`
Emoji any `json:"emoji"`
Label any `json:"label"`
Type any `json:"type"`
Style any `json:"style"`
}
type Properties struct {
FinalPrompt string `json:"finalPrompt"`
FinalZhPrompt string `json:"finalZhPrompt"`
}

View File

@ -14,6 +14,7 @@ import (
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/groq"
"one-api/providers/midjourney"
"one-api/providers/minimax"
"one-api/providers/mistral"
"one-api/providers/openai"
@ -52,6 +53,7 @@ func init() {
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
}

View File

@ -27,7 +27,7 @@ type RelayBaseInterface interface {
}
func (r *relayBase) setProvider(modelName string) error {
provider, modelName, fail := getProvider(r.c, modelName)
provider, modelName, fail := GetProvider(r.c, modelName)
if fail != nil {
return fail
}

View File

@ -45,7 +45,7 @@ func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
return nil
}
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
channel, fail := fetchChannel(c, modeName)
if fail != nil {
return

19
relay/midjourney/LICENSE Normal file
View File

@ -0,0 +1,19 @@
Copyright (c) 2024 Calcium-Ion
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,578 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: relay/relay-mj.go
package midjourney
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"one-api/common"
"one-api/controller"
"one-api/model"
providersBase "one-api/providers/base"
provider "one-api/providers/midjourney"
"one-api/relay"
"one-api/relay/util"
"one-api/types"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByOnlyMJId(taskId)
if midjourneyTask == nil {
c.JSON(400, gin.H{
"error": "midjourney_task_not_found",
})
return
}
resp, err := http.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",
})
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
c.JSON(resp.StatusCode, gin.H{
"error": string(responseBody),
})
return
}
// 从Content-Type头获取MIME类型
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
// 如果无法确定内容类型则默认为jpeg
contentType = "image/jpeg"
}
// 设置响应的内容类型
c.Writer.Header().Set("Content-Type", contentType)
// 将图片流式传输到响应体
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
log.Println("Failed to stream image:", err)
}
}
func RelayMidjourneyNotify(c *gin.Context) *provider.MidjourneyResponse {
var midjRequest provider.MidjourneyDto
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
Properties: nil,
Result: "",
}
}
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "midjourney_task_not_found",
Properties: nil,
Result: "",
}
}
midjourneyTask.Progress = midjRequest.Progress
midjourneyTask.PromptEn = midjRequest.PromptEn
midjourneyTask.State = midjRequest.State
midjourneyTask.SubmitTime = midjRequest.SubmitTime
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "update_midjourney_task_failed",
}
}
return nil
}
func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provider.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
midjourneyTask.State = originTask.State
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
if originTask.Buttons != "" {
var buttons []provider.ActionButton
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
if err == nil {
midjourneyTask.Buttons = buttons
}
}
if originTask.Properties != "" {
var properties provider.Properties
err := json.Unmarshal([]byte(originTask.Properties), &properties)
if err == nil {
midjourneyTask.Properties = &properties
}
}
return
}
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
if errWithMJ != nil {
return errWithMJ
}
startTime := time.Now().UnixNano() / int64(time.Millisecond)
userId := c.GetInt("id")
var swapFaceRequest provider.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
}
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
}
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
if errWithOA != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: errWithOA.Message,
}
}
requestURL := getMjRequestPath(c.Request.URL.String())
mjResp, _, err := mjProvider.Send(60, requestURL)
if err != nil {
quotaInstance.Undo(c)
return &mjResp.Response
}
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1000, TotalTokens: 1000})
} else {
quotaInstance.Undo(c)
}
quota := int(quotaInstance.GetInputRatio() * 1000)
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: provider.MjActionSwapFace,
MjId: midjResponse.Result,
Prompt: "InsightFace",
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: startTime,
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
err = midjourneyTask.Insert()
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "insert_midjourney_task_failed")
}
// 开始激活任务
controller.ActivateUpdateMidjourneyTaskBulk()
c.Writer.WriteHeader(mjResp.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
}
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
if errWithMJ != nil {
return errWithMJ
}
requestURL := getMjRequestPath(c.Request.URL.String())
midjResponseWithStatus, _, err := mjProvider.Send(30, requestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
userId := c.GetInt("id")
var err error
var respBody []byte
switch relayMode {
case provider.RelayModeMidjourneyTaskFetch:
taskId := c.Param("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
midjourneyTask := coverMidjourneyTaskDto(originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
case provider.RelayModeMidjourneyTaskFetchByCondition:
var condition = struct {
IDs []string `json:"ids"`
}{}
err = c.BindJSON(&condition)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
var tasks []provider.MidjourneyDto
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
midjourneyTask := coverMidjourneyTaskDto(originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
tasks = make([]provider.MidjourneyDto, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
}
c.Writer.Header().Set("Content-Type", "application/json")
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
return nil
}
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
channelId := 0
userId := c.GetInt("id")
consumeQuota := true
var midjRequest provider.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
}
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus需要从customId中获取任务信息
mjErr := CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
return mjErr
}
relayMode = provider.RelayModeMidjourneyChange
}
if relayMode == provider.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "prompt_is_required")
}
midjRequest.Action = provider.MjActionImagine
} else if relayMode == provider.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = provider.MjActionDescribe
} else if relayMode == provider.RelayModeMidjourneyShorten { //缩短任务此类任务可重复plus only
midjRequest.Action = provider.MjActionShorten
} else if relayMode == provider.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = provider.MjActionBlend
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := ""
if relayMode == provider.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "action_is_required")
} else if midjRequest.Index == 0 {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "index_is_required")
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == provider.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_is_required")
}
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_parse_failed")
}
mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == provider.RelayModeMidjourneyModal {
//if midjRequest.MaskBase64 == "" {
// return provider.MidjourneyErrorWrapper(provider.MjRequestError, "mask_base64_is_required")
//}
mjId = midjRequest.TaskId
midjRequest.Action = provider.MjActionModal
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_not_found")
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channelId = originTask.ChannelId
log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %d", originTask.ChannelId)
}
midjRequest.Prompt = originTask.Prompt
//if channelType == common.ChannelTypeMidjourneyPlus {
// // plus
//} else {
// // 普通版渠道
//
//}
}
if midjRequest.Action == provider.MjActionInPaint || midjRequest.Action == provider.MjActionCustomZoom {
consumeQuota = false
}
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
if errWithMJ != nil {
return errWithMJ
}
//baseURL := common.ChannelBaseURLs[channelType]
requestURL := getMjRequestPath(c.Request.URL.String())
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
if errWithOA != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: errWithOA.Message,
}
}
midjResponseWithStatus, responseBody, err := mjProvider.Send(60, requestURL)
if err != nil {
quotaInstance.Undo(c)
return &midjResponseWithStatus.Response
}
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
} else {
quotaInstance.Undo(c)
}
quota := int(quotaInstance.GetInputRatio() * 1000)
midjResponse := &midjResponseWithStatus.Response
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
// 22-排队中 {"code":22,"description":"排队中前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误description为错误描述
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description
consumeQuota = false
}
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
// 将 properties 转换为一个 map
properties, ok := midjResponse.Properties.(map[string]interface{})
if ok {
imageUrl, ok1 := properties["imageUrl"].(string)
status, ok2 := properties["status"].(string)
if ok1 && ok2 {
midjourneyTask.ImageUrl = imageUrl
midjourneyTask.Status = status
if status == "SUCCESS" {
midjourneyTask.Progress = "100%"
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
midjResponse.Code = 1
}
}
}
//修改返回值
if midjRequest.Action != provider.MjActionInPaint && midjRequest.Action != provider.MjActionCustomZoom {
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
}
err = midjourneyTask.Insert()
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "insert_midjourney_task_failed",
}
}
// 开始激活任务
controller.ActivateUpdateMidjourneyTaskBulk()
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
//for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
//}
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
_, err = io.Copy(c.Writer, bodyReader)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = bodyReader.Close()
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
return nil
}
func getMjRequestPath(path string) string {
requestURL := path
if strings.Contains(requestURL, "/mj-") {
urls := strings.Split(requestURL, "/mj/")
if len(urls) < 2 {
return requestURL
}
requestURL = "/mj/" + urls[1]
}
return requestURL
}
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
// modelName = CoverActionToModelName(modelName)
return util.NewQuota(c, modelName, 1000)
}
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
var baseProvider providersBase.ProviderInterface
modelName := ""
if channel_id > 0 {
c.Set("specific_channel_id", channel_id)
}
if request != nil {
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
if mjErr != nil {
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
}
if midjourneyModel == "" {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
}
modelName = midjourneyModel
}
var err error
baseProvider, _, err = relay.GetProvider(c, modelName)
if err != nil {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
}
mjProvider, ok := baseProvider.(*provider.MidjourneyProvider)
if !ok {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法获取midjourney provider")
}
return mjProvider, nil
}

95
relay/midjourney/relay.go Normal file
View File

@ -0,0 +1,95 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: controller/relay.go
package midjourney
import (
"fmt"
"net/http"
"one-api/common"
provider "one-api/providers/midjourney"
"strings"
"github.com/gin-gonic/gin"
)
func RelayMidjourney(c *gin.Context) {
relayMode := Path2RelayModeMidjourney(c.Request.URL.Path)
var err *provider.MidjourneyResponse
switch relayMode {
case provider.RelayModeMidjourneyNotify:
err = RelayMidjourneyNotify(c)
case provider.RelayModeMidjourneyTaskFetch, provider.RelayModeMidjourneyTaskFetchByCondition:
err = RelayMidjourneyTask(c, relayMode)
case provider.RelayModeMidjourneyTaskImageSeed:
err = RelayMidjourneyTaskImageSeed(c)
case provider.RelayModeMidjourneySwapFace:
err = RelaySwapFace(c)
default:
err = RelayMidjourneySubmit(c, relayMode)
}
if err != nil {
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
}
typeMsg := "upstream_error"
if err.Type != "" {
typeMsg = err.Type
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": typeMsg,
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
}
}
func MidjourneyErrorFromInternal(code int, description string) *provider.MidjourneyResponse {
return &provider.MidjourneyResponse{
Code: code,
Description: description,
Type: "internal_error",
}
}
func Path2RelayModeMidjourney(path string) int {
relayMode := provider.RelayModeUnknown
if strings.HasSuffix(path, "/mj/submit/action") {
// midjourney plus
relayMode = provider.RelayModeMidjourneyAction
} else if strings.HasSuffix(path, "/mj/submit/modal") {
// midjourney plus
relayMode = provider.RelayModeMidjourneyModal
} else if strings.HasSuffix(path, "/mj/submit/shorten") {
// midjourney plus
relayMode = provider.RelayModeMidjourneyShorten
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = provider.RelayModeMidjourneySwapFace
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = provider.RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/blend") {
relayMode = provider.RelayModeMidjourneyBlend
} else if strings.HasSuffix(path, "/mj/submit/describe") {
relayMode = provider.RelayModeMidjourneyDescribe
} else if strings.HasSuffix(path, "/mj/notify") {
relayMode = provider.RelayModeMidjourneyNotify
} else if strings.HasSuffix(path, "/mj/submit/change") {
relayMode = provider.RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/mj/submit/simple-change") {
relayMode = provider.RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/fetch") {
relayMode = provider.RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(path, "/image-seed") {
relayMode = provider.RelayModeMidjourneyTaskImageSeed
} else if strings.HasSuffix(path, "/list-by-condition") {
relayMode = provider.RelayModeMidjourneyTaskFetchByCondition
}
return relayMode
}

148
relay/midjourney/service.go Normal file
View File

@ -0,0 +1,148 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: service/midjourney.go
package midjourney
import (
mjProvider "one-api/providers/midjourney"
"strconv"
"strings"
)
func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
if mjAction == mjProvider.MjActionSwapFace {
modelName = "swap_face"
}
return modelName
}
func GetMjRequestModel(relayMode int, midjRequest *mjProvider.MidjourneyRequest) (string, *mjProvider.MidjourneyResponse, bool) {
action := ""
if relayMode == mjProvider.RelayModeMidjourneyAction {
// plus request
err := CoverPlusActionToNormalAction(midjRequest)
if err != nil {
return "", err, false
}
action = midjRequest.Action
} else {
switch relayMode {
case mjProvider.RelayModeMidjourneyImagine:
action = mjProvider.MjActionImagine
case mjProvider.RelayModeMidjourneyDescribe:
action = mjProvider.MjActionDescribe
case mjProvider.RelayModeMidjourneyBlend:
action = mjProvider.MjActionBlend
case mjProvider.RelayModeMidjourneyShorten:
action = mjProvider.MjActionShorten
case mjProvider.RelayModeMidjourneyChange:
action = midjRequest.Action
case mjProvider.RelayModeMidjourneyModal:
action = mjProvider.MjActionModal
case mjProvider.RelayModeMidjourneySwapFace:
action = mjProvider.MjActionSwapFace
case mjProvider.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "invalid_request"), false
}
action = params.Action
case mjProvider.RelayModeMidjourneyTaskFetch, mjProvider.RelayModeMidjourneyTaskFetchByCondition, mjProvider.RelayModeMidjourneyNotify:
return "", nil, true
default:
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_relay_action"), false
}
}
modelName := CoverActionToModelName(action)
return modelName, nil, true
}
func CoverPlusActionToNormalAction(midjRequest *mjProvider.MidjourneyRequest) *mjProvider.MidjourneyResponse {
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
customId := midjRequest.CustomId
if customId == "" {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "custom_id_is_required")
}
splits := strings.Split(customId, "::")
var action string
if splits[1] == "JOB" {
action = splits[2]
} else {
action = splits[1]
}
if action == "" {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action")
}
if strings.Contains(action, "upsample") {
index, err := strconv.Atoi(splits[3])
if err != nil {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = mjProvider.MjActionUpscale
} else if strings.Contains(action, "variation") {
midjRequest.Index = 1
if action == "variation" {
index, err := strconv.Atoi(splits[3])
if err != nil {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = mjProvider.MjActionVariation
} else if action == "low_variation" {
midjRequest.Action = mjProvider.MjActionLowVariation
} else if action == "high_variation" {
midjRequest.Action = mjProvider.MjActionHighVariation
}
} else if strings.Contains(action, "pan") {
midjRequest.Action = mjProvider.MjActionPan
midjRequest.Index = 1
} else if strings.Contains(action, "reroll") {
midjRequest.Action = mjProvider.MjActionReRoll
midjRequest.Index = 1
} else if action == "Outpaint" {
midjRequest.Action = mjProvider.MjActionZoom
midjRequest.Index = 1
} else if action == "CustomZoom" {
midjRequest.Action = mjProvider.MjActionCustomZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
midjRequest.Action = mjProvider.MjActionInPaint
midjRequest.Index = 1
} else {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action:"+customId)
}
return nil
}
func ConvertSimpleChangeParams(content string) *mjProvider.MidjourneyRequest {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &mjProvider.MidjourneyRequest{}
changeParams.TaskId = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}

View File

@ -170,3 +170,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
}
}(c.Request.Context())
}
func (q *Quota) GetInputRatio() float64 {
return q.inputRatio
}

View File

@ -7,22 +7,23 @@ var ModelOwnedBy map[int]string
func init() {
ModelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeMidjourney: "Midjourney",
}
}

View File

@ -145,6 +145,9 @@ func SetApiRouter(router *gin.Engine) {
}
mjRoute := apiRouter.Group("/mj")
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
}
}

View File

@ -1,10 +1,11 @@
package router
import (
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
"one-api/controller"
"one-api/middleware"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
)
func SetDashboardRouter(router *gin.Engine) {
@ -12,7 +13,7 @@ func SetDashboardRouter(router *gin.Engine) {
apiRouter := router.Group("/")
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
apiRouter.Use(middleware.TokenAuth())
apiRouter.Use(middleware.OpenaiAuth())
{
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)

View File

@ -4,6 +4,7 @@ import (
"one-api/controller"
"one-api/middleware"
"one-api/relay"
"one-api/relay/midjourney"
"github.com/gin-gonic/gin"
)
@ -11,14 +12,19 @@ import (
func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
// https://platform.openai.com/docs/api-reference/introduction
setOpenAIRouter(router)
setMJRouter(router)
}
func setOpenAIRouter(router *gin.Engine) {
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth(), middleware.Distribute())
modelsRouter.Use(middleware.OpenaiAuth(), middleware.Distribute())
{
modelsRouter.GET("", relay.ListModels)
modelsRouter.GET("/:model", relay.RetrieveModel)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.OpenaiAuth(), middleware.Distribute())
{
relayV1Router.POST("/completions", relay.Relay)
relayV1Router.POST("/chat/completions", relay.Relay)
@ -71,3 +77,34 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
}
}
func setMJRouter(router *gin.Engine) {
relayMjRouter := router.Group("/mj")
registerMjRouterGroup(relayMjRouter)
relayMjModeRouter := router.Group("/:mode/mj")
registerMjRouterGroup(relayMjModeRouter)
}
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: router/relay-router.go
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.GET("/image/:id", midjourney.RelayMidjourneyImage)
relayMjRouter.Use(middleware.MjAuth(), middleware.Distribute())
{
relayMjRouter.POST("/submit/action", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/shorten", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/modal", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/change", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/simple-change", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/describe", midjourney.RelayMidjourney)
relayMjRouter.POST("/submit/blend", midjourney.RelayMidjourney)
relayMjRouter.POST("/notify", midjourney.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", midjourney.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", midjourney.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", midjourney.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", midjourney.RelayMidjourney)
}
}

View File

@ -7,7 +7,7 @@
使用了以下开源项目作为我们项目的一部分:
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)
- [minimal-ui-kit](minimal-ui-kit)
- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react)
## 许可证

View File

@ -132,6 +132,13 @@ export const CHANNEL_OPTIONS = {
color: 'primary',
url: 'https://platform.lingyiwanwu.com/details'
},
34: {
key: 34,
text: 'Midjourney',
value: 34,
color: 'orange',
url: ''
},
24: {
key: 24,
text: 'Azure Speech',

View File

@ -11,7 +11,8 @@ import {
IconUserScan,
IconActivity,
IconBrandTelegram,
IconReceipt2
IconReceipt2,
IconBrush
} from '@tabler/icons-react';
// constant
@ -27,7 +28,8 @@ const icons = {
IconUserScan,
IconActivity,
IconBrandTelegram,
IconReceipt2
IconReceipt2,
IconBrush
};
// ==============================|| DASHBOARD MENU ITEMS ||============================== //
@ -96,6 +98,14 @@ const panel = {
icon: icons.IconGardenCart,
breadcrumbs: false
},
{
id: 'midjourney',
title: 'Midjourney',
type: 'item',
url: '/panel/midjourney',
icon: icons.IconBrush,
breadcrumbs: false
},
{
id: 'user',
title: '用户',

View File

@ -16,6 +16,7 @@ const NotFoundView = Loadable(lazy(() => import('views/Error')));
const Analytics = Loadable(lazy(() => import('views/Analytics')));
const Telegram = Loadable(lazy(() => import('views/Telegram')));
const Pricing = Loadable(lazy(() => import('views/Pricing')));
const Midjourney = Loadable(lazy(() => import('views/Midjourney')));
// dashboard routing
const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
@ -81,6 +82,10 @@ const MainRoutes = {
{
path: 'pricing',
element: <Pricing />
},
{
path: 'midjourney',
element: <Midjourney />
}
]
};

View File

@ -12,15 +12,7 @@ export default function componentStyleOverrides(theme) {
}
}
},
MuiMenuItem: {
styleOverrides: {
root: {
'&:hover': {
backgroundColor: theme.colors?.grey100
}
}
}
}, //MuiAutocomplete-popper MuiPopover-root
//MuiAutocomplete-popper MuiPopover-root
MuiAutocomplete: {
styleOverrides: {
popper: {
@ -247,7 +239,7 @@ export default function componentStyleOverrides(theme) {
MuiTooltip: {
styleOverrides: {
tooltip: {
color: theme.paper,
color: theme.colors.paper,
background: theme.colors?.grey700
}
}
@ -266,6 +258,9 @@ export default function componentStyleOverrides(theme) {
.apexcharts-menu {
background: ${theme.backgroundDefault} !important
}
.apexcharts-gridline, .apexcharts-xaxistooltip-background, .apexcharts-yaxistooltip-background {
stroke: ${theme.divider} !important;
}
`
}
};

View File

@ -19,14 +19,14 @@ const Footer = () => {
{siteInfo.system_name} {process.env.REACT_APP_VERSION}{' '}
</Link>
{' '}
<Link href="https://github.com/songquanpeng" target="_blank">
JustSong
</Link>{' '}
构建
<Link href="https://github.com/MartialBE" target="_blank">
MartialBE
</Link>
修改源代码遵循
开发基于
<Link href="https://github.com/songquanpeng" target="_blank">
JustSong
</Link>{' '}
One API源代码遵循
<Link href="https://opensource.org/licenses/mit-license.php"> MIT 协议</Link>
</>
)}

View File

@ -234,6 +234,33 @@ const typeConfig = {
test_model: 'yi-34b-chat-0205'
},
modelGroup: 'Lingyiwanwu'
},
34: {
input: {
models: [
'mj_imagine',
'mj_variation',
'mj_reroll',
'mj_blend',
'mj_modal',
'mj_zoom',
'mj_shorten',
'mj_high_variation',
'mj_low_variation',
'mj_pan',
'mj_inpaint',
'mj_custom_zoom',
'mj_describe',
'mj_upscale',
'swap_face'
]
},
prompt: {
key: '密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填',
base_url: '地址填写midjourney-proxy部署的地址',
test_model: ''
},
modelGroup: 'Midjourney'
}
};

View File

@ -0,0 +1,174 @@
import PropTypes from 'prop-types';
import { useState } from 'react';
import {
TableRow,
TableCell,
Button,
Dialog,
DialogActions,
DialogContent,
ButtonGroup,
Popover,
MenuItem,
MenuList,
Tooltip
} from '@mui/material';
import { timestamp2string, copy } from 'utils/common';
import Label from 'ui-component/Label';
import { ACTION_TYPE, CODE_TYPE, STATUS_TYPE } from '../type/Type';
import { IconCaretDownFilled, IconCopy, IconDownload, IconExternalLink } from '@tabler/icons-react';
function renderType(types, type) {
const typeOption = types[type];
if (typeOption) {
return (
<Label variant="filled" color={typeOption.color}>
{' '}
{typeOption.text}{' '}
</Label>
);
} else {
return (
<Label variant="filled" color="error">
{' '}
未知{' '}
</Label>
);
}
}
async function downloadImage(url, filename) {
const response = await fetch(url);
const blob = await response.blob();
const blobUrl = URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = blobUrl;
link.download = filename;
link.click();
URL.revokeObjectURL(blobUrl);
}
function TruncatedText(text) {
const truncatedText = text.length > 30 ? text.substring(0, 100) + '...' : text;
return (
<Tooltip
placement="top"
title={text}
onClick={() => {
copy(text, '');
}}
>
<span>{truncatedText}</span>
</Tooltip>
);
}
export default function LogTableRow({ item, userIsAdmin }) {
const [open, setOpen] = useState(false);
const [menuOpen, setMenuOpen] = useState(null);
const handleClickOpen = () => {
setOpen(true);
};
const handleClose = () => {
setOpen(false);
};
const handleOpenMenu = (event) => {
setMenuOpen(event.currentTarget);
};
const handleCloseMenu = () => {
setMenuOpen(null);
};
return (
<>
<TableRow tabIndex={item.id}>
<TableCell>{item.mj_id}</TableCell>
<TableCell>{timestamp2string(item.submit_time / 1000)}</TableCell>
{userIsAdmin && <TableCell>{item.channel_id || ''}</TableCell>}
{userIsAdmin && <TableCell>{item.user_id || ''}</TableCell>}
<TableCell>{renderType(ACTION_TYPE, item.action)}</TableCell>
{userIsAdmin && <TableCell>{renderType(CODE_TYPE, item.code)}</TableCell>}
{userIsAdmin && <TableCell>{renderType(STATUS_TYPE, item.status)}</TableCell>}
<TableCell>{item.progress}</TableCell>
<TableCell>
{item.image_url == '' ? (
'无'
) : (
<ButtonGroup size="small" aria-label="split button">
<Button color="primary" onClick={handleClickOpen}>
显示
</Button>
<Button onClick={handleOpenMenu}>
<IconCaretDownFilled size={'16px'} />
</Button>
</ButtonGroup>
)}
</TableCell>
<TableCell>{TruncatedText(item.prompt)}</TableCell>
<TableCell>{TruncatedText(item.prompt_en)}</TableCell>
<TableCell>{TruncatedText(item.fail_reason)}</TableCell>
</TableRow>
<Dialog open={open} onClose={handleClose}>
<DialogContent>
<img src={item.image_url} alt="item" style={{ maxWidth: '100%', maxHeight: '100%' }} />
</DialogContent>
<DialogActions>
<Button onClick={handleClose} color="primary">
关闭
</Button>
</DialogActions>
</Dialog>
<Popover
open={!!menuOpen}
anchorEl={menuOpen}
onClose={handleCloseMenu}
anchorOrigin={{ vertical: 'top', horizontal: 'left' }}
transformOrigin={{ vertical: 'top', horizontal: 'right' }}
PaperProps={{
sx: { width: 140 }
}}
>
<MenuList>
<MenuItem
onClick={() => {
handleCloseMenu();
copy(item.image_url, '图片地址');
}}
>
<IconCopy style={{ marginRight: '16px' }} />
复制地址
</MenuItem>
<MenuItem
onClick={async () => {
handleCloseMenu();
await downloadImage(item.image_url, item.mj_id + '.png');
}}
>
<IconDownload style={{ marginRight: '16px' }} /> 下载图片{' '}
</MenuItem>
<MenuItem
onClick={() => {
handleCloseMenu();
}}
>
<IconExternalLink style={{ marginRight: '16px' }} /> 新窗口打开{' '}
</MenuItem>
</MenuList>
</Popover>
</>
);
}
LogTableRow.propTypes = {
item: PropTypes.object,
userIsAdmin: PropTypes.bool
};

View File

@ -0,0 +1,113 @@
import PropTypes from 'prop-types';
import { useTheme } from '@mui/material/styles';
import { IconBroadcast, IconCalendarEvent } from '@tabler/icons-react';
import { InputAdornment, OutlinedInput, Stack, FormControl, InputLabel } from '@mui/material';
import { LocalizationProvider, DateTimePicker } from '@mui/x-date-pickers';
import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs';
import dayjs from 'dayjs';
require('dayjs/locale/zh-cn');
// ----------------------------------------------------------------------
export default function TableToolBar({ filterName, handleFilterName, userIsAdmin }) {
const theme = useTheme();
const grey500 = theme.palette.grey[500];
return (
<>
<Stack direction={{ xs: 'column', sm: 'row' }} spacing={{ xs: 3, sm: 2, md: 4 }} padding={'24px'} paddingBottom={'0px'}>
{userIsAdmin && (
<FormControl>
<InputLabel htmlFor="channel-channel_id-label">渠道ID</InputLabel>
<OutlinedInput
id="channel_id"
name="channel_id"
sx={{
minWidth: '100%'
}}
label="渠道ID"
value={filterName.channel_id}
onChange={handleFilterName}
placeholder="渠道ID"
startAdornment={
<InputAdornment position="start">
<IconBroadcast stroke={1.5} size="20px" color={grey500} />
</InputAdornment>
}
/>
</FormControl>
)}
<FormControl>
<InputLabel htmlFor="channel-mj_id-label">任务ID</InputLabel>
<OutlinedInput
id="mj_id"
name="mj_id"
sx={{
minWidth: '100%'
}}
label="任务ID"
value={filterName.mj_id}
onChange={handleFilterName}
placeholder="任务ID"
startAdornment={
<InputAdornment position="start">
<IconCalendarEvent stroke={1.5} size="20px" color={grey500} />
</InputAdornment>
}
/>
</FormControl>
<FormControl>
<LocalizationProvider dateAdapter={AdapterDayjs} adapterLocale={'zh-cn'}>
<DateTimePicker
label="起始时间"
ampm={false}
name="start_timestamp"
value={filterName.start_timestamp === 0 ? null : dayjs.unix(filterName.start_timestamp / 1000)}
onChange={(value) => {
if (value === null) {
handleFilterName({ target: { name: 'start_timestamp', value: 0 } });
return;
}
handleFilterName({ target: { name: 'start_timestamp', value: value.unix() * 1000 } });
}}
slotProps={{
actionBar: {
actions: ['clear', 'today', 'accept']
}
}}
/>
</LocalizationProvider>
</FormControl>
<FormControl>
<LocalizationProvider dateAdapter={AdapterDayjs} adapterLocale={'zh-cn'}>
<DateTimePicker
label="结束时间"
name="end_timestamp"
ampm={false}
value={filterName.end_timestamp === 0 ? null : dayjs.unix(filterName.end_timestamp / 1000)}
onChange={(value) => {
if (value === null) {
handleFilterName({ target: { name: 'end_timestamp', value: 0 } });
return;
}
handleFilterName({ target: { name: 'end_timestamp', value: value.unix() * 1000 } });
}}
slotProps={{
actionBar: {
actions: ['clear', 'today', 'accept']
}
}}
/>
</LocalizationProvider>
</FormControl>
</Stack>
</>
);
}
TableToolBar.propTypes = {
filterName: PropTypes.object,
handleFilterName: PropTypes.func,
userIsAdmin: PropTypes.bool
};

View File

@ -0,0 +1,247 @@
import { useState, useEffect, useCallback } from 'react';
import { showError } from 'utils/common';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableContainer from '@mui/material/TableContainer';
import PerfectScrollbar from 'react-perfect-scrollbar';
import TablePagination from '@mui/material/TablePagination';
import LinearProgress from '@mui/material/LinearProgress';
import ButtonGroup from '@mui/material/ButtonGroup';
import Toolbar from '@mui/material/Toolbar';
import { Button, Card, Stack, Container, Typography, Box } from '@mui/material';
import LogTableRow from './component/TableRow';
import KeywordTableHead from 'ui-component/TableHead';
import TableToolBar from './component/TableToolBar';
import { API } from 'utils/api';
import { isAdmin } from 'utils/common';
import { ITEMS_PER_PAGE } from 'constants';
import { IconRefresh, IconSearch } from '@tabler/icons-react';
import dayjs from 'dayjs';
export default function Log() {
const originalKeyword = {
p: 0,
channel_id: '',
mj_id: '',
start_timestamp: 0,
end_timestamp: dayjs().unix() * 1000 + 3600
};
const [page, setPage] = useState(0);
const [order, setOrder] = useState('desc');
const [orderBy, setOrderBy] = useState('id');
const [rowsPerPage, setRowsPerPage] = useState(ITEMS_PER_PAGE);
const [listCount, setListCount] = useState(0);
const [searching, setSearching] = useState(false);
const [toolBarValue, setToolBarValue] = useState(originalKeyword);
const [searchKeyword, setSearchKeyword] = useState(originalKeyword);
const [refreshFlag, setRefreshFlag] = useState(false);
const [logs, setLogs] = useState([]);
const userIsAdmin = isAdmin();
const handleSort = (event, id) => {
const isAsc = orderBy === id && order === 'asc';
if (id !== '') {
setOrder(isAsc ? 'desc' : 'asc');
setOrderBy(id);
}
};
const handleChangePage = (event, newPage) => {
setPage(newPage);
};
const handleChangeRowsPerPage = (event) => {
setPage(0);
setRowsPerPage(parseInt(event.target.value, 10));
};
const searchLogs = async () => {
setPage(0);
setSearchKeyword(toolBarValue);
};
const handleToolBarValue = (event) => {
setToolBarValue({ ...toolBarValue, [event.target.name]: event.target.value });
};
const fetchData = useCallback(
async (page, rowsPerPage, keyword, order, orderBy) => {
setSearching(true);
try {
if (orderBy) {
orderBy = order === 'desc' ? '-' + orderBy : orderBy;
}
const url = userIsAdmin ? '/api/mj/' : '/api/mj/self/';
if (!userIsAdmin) {
delete keyword.channel_id;
}
const res = await API.get(url, {
params: {
page: page + 1,
size: rowsPerPage,
order: orderBy,
...keyword
}
});
const { success, message, data } = res.data;
if (success) {
setListCount(data.total_count);
setLogs(data.data);
} else {
showError(message);
}
} catch (error) {
console.error(error);
}
setSearching(false);
},
[userIsAdmin]
);
// 处理刷新
const handleRefresh = async () => {
setOrderBy('id');
setOrder('desc');
setToolBarValue(originalKeyword);
setSearchKeyword(originalKeyword);
setRefreshFlag(!refreshFlag);
};
useEffect(() => {
fetchData(page, rowsPerPage, searchKeyword, order, orderBy);
}, [page, rowsPerPage, searchKeyword, order, orderBy, fetchData, refreshFlag]);
return (
<>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
<Typography variant="h4">Midjourney</Typography>
</Stack>
<Card>
<Box component="form" noValidate>
<TableToolBar filterName={toolBarValue} handleFilterName={handleToolBarValue} userIsAdmin={userIsAdmin} />
</Box>
<Toolbar
sx={{
textAlign: 'right',
height: 50,
display: 'flex',
justifyContent: 'space-between',
p: (theme) => theme.spacing(0, 1, 0, 3)
}}
>
<Container>
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新/清除搜索条件
</Button>
<Button onClick={searchLogs} startIcon={<IconSearch width={'18px'} />}>
搜索
</Button>
</ButtonGroup>
</Container>
</Toolbar>
{searching && <LinearProgress />}
<PerfectScrollbar component="div">
<TableContainer sx={{ overflow: 'unset' }}>
<Table sx={{ minWidth: 800 }}>
<KeywordTableHead
order={order}
orderBy={orderBy}
onRequestSort={handleSort}
headLabel={[
{
id: 'mj_id',
label: '任务ID',
disableSort: false
},
{
id: 'submit_time',
label: '提交时间',
disableSort: false
},
{
id: 'channel_id',
label: '渠道',
disableSort: false,
hide: !userIsAdmin
},
{
id: 'user_id',
label: '用户',
disableSort: false,
hide: !userIsAdmin
},
{
id: 'action',
label: '类型',
disableSort: false
},
{
id: 'code',
label: '提交结果',
disableSort: false,
hide: !userIsAdmin
},
{
id: 'status',
label: '任务状态',
disableSort: false,
hide: !userIsAdmin
},
{
id: 'progress',
label: '进度',
disableSort: true
},
{
id: 'image_url',
label: '结果图片',
disableSort: true,
width: '120px'
},
{
id: 'prompt',
label: 'Prompt',
disableSort: true
},
{
id: 'prompt_en',
label: 'PromptEn',
disableSort: true
},
{
id: 'fail_reason',
label: '失败原因',
disableSort: true
}
]}
/>
<TableBody>
{logs.map((row, index) => (
<LogTableRow item={row} key={`${row.id}_${index}`} userIsAdmin={userIsAdmin} />
))}
</TableBody>
</Table>
</TableContainer>
</PerfectScrollbar>
<TablePagination
page={page}
component="div"
count={listCount}
rowsPerPage={rowsPerPage}
onPageChange={handleChangePage}
rowsPerPageOptions={[10, 25, 30]}
onRowsPerPageChange={handleChangeRowsPerPage}
showFirstButton
showLastButton
/>
</Card>
</>
);
}

View File

@ -0,0 +1,33 @@
export const ACTION_TYPE = {
IMAGINE: { value: 'IMAGINE', text: '绘图', color: 'primary' },
UPSCALE: { value: 'UPSCALE', text: '放大', color: 'orange' },
VARIATION: { value: 'VARIATION', text: '变换', color: 'default' },
HIGH_VARIATION: { value: 'HIGH_VARIATION', text: '强变换', color: 'default' },
LOW_VARIATION: { value: 'LOW_VARIATION', text: '弱变换', color: 'default' },
PAN: { value: 'PAN', text: '平移', color: 'secondary' },
DESCRIBE: { value: 'DESCRIBE', text: '图生文', color: 'secondary' },
BLEND: { value: 'BLEND', text: '图混合', color: 'secondary' },
SHORTEN: { value: 'SHORTEN', text: '缩词', color: 'secondary' },
REROLL: { value: 'REROLL', text: '重绘', color: 'secondary' },
INPAINT: { value: 'INPAINT', text: '局部重绘-提交', color: 'secondary' },
ZOOM: { value: 'ZOOM', text: '变焦', color: 'secondary' },
CUSTOM_ZOOM: { value: 'CUSTOM_ZOOM', text: '自定义变焦-提交', color: 'secondary' },
MODAL: { value: 'MODAL', text: '窗口处理', color: 'secondary' },
SWAP_FACE: { value: 'SWAP_FACE', text: '换脸', color: 'secondary' }
};
export const CODE_TYPE = {
1: { value: 1, text: '已提交', color: 'primary' },
21: { value: 21, text: '等待中', color: 'orange' },
22: { value: 22, text: '重复提交', color: 'default' },
0: { value: 0, text: '未提交', color: 'default' }
};
export const STATUS_TYPE = {
SUCCESS: { value: 'SUCCESS', text: '成功', color: 'success' },
NOT_START: { value: 'NOT_START', text: '未启动', color: 'default' },
SUBMITTED: { value: 'SUBMITTED', text: '队列中', color: 'secondary' },
IN_PROGRESS: { value: 'IN_PROGRESS', text: '执行中', color: 'primary' },
FAILURE: { value: 'FAILURE', text: '失败', color: 'orange' },
MODAL: { value: 'MODAL', text: '窗口等待', color: 'default' }
};

View File

@ -29,7 +29,8 @@ const OperationSetting = () => {
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
RetryTimes: 0,
RetryCooldownSeconds: 0
RetryCooldownSeconds: 0,
MjNotifyEnabled: ''
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@ -278,6 +279,22 @@ const OperationSetting = () => {
</Button>
</Stack>
</SubCard>
<SubCard title="其他设置">
<Stack justifyContent="flex-start" alignItems="flex-start" spacing={2}>
<Stack
direction={{ sm: 'column', md: 'row' }}
spacing={{ xs: 3, sm: 2, md: 4 }}
justifyContent="flex-start"
alignItems="flex-start"
>
<FormControlLabel
sx={{ marginLeft: '0px' }}
label="Midjourney 允许回调会泄露服务器ip地址"
control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />}
/>
</Stack>
</Stack>
</SubCard>
<SubCard title="日志设置">
<Stack direction="column" justifyContent="flex-start" alignItems="flex-start" spacing={2}>
<FormControlLabel