diff --git a/Midjourney.md b/Midjourney.md
new file mode 100644
index 00000000..8141d448
--- /dev/null
+++ b/Midjourney.md
@@ -0,0 +1,254 @@
+# Midjourney Proxy API文档
+
+
+**简介**:Midjourney Proxy API文档
+
+
+**HOST**:https://api.nekoedu.com
+
+
+**Version**:v2.3.5
+
+
+[TOC]
+
+
+
+
+
+
+# 任务提交
+
+
+## 绘图变化
+
+
+**接口地址**:`/mj/submit/change`
+
+
+**请求方式**:`POST`
+
+
+**请求数据类型**:`application/json`
+
+
+**响应数据类型**:`*/*`
+
+
+**接口描述**:
+
+
+**请求示例**:
+
+
+```javascript
+{
+ "action": "UPSCALE",
+ "index": 1,
+ "notifyHook": "",
+ "state": "",
+ "taskId": "1320098173412546"
+}
+```
+
+
+**请求参数**:
+
+
+| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
+| -------- | -------- | ----- | -------- | -------- | ------ |
+|changeDTO|changeDTO|body|true|变化任务提交参数|变化任务提交参数|
+| action|UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL||true|string||
+| index|序号(1~4), action为UPSCALE,VARIATION时必传||false|integer(int32)||
+| notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
+| state|自定义参数||false|string||
+| taskId|任务ID||true|string||
+
+
+**响应状态**:
+
+
+| 状态码 | 说明 | schema |
+| -------- | -------- | ----- |
+|200|OK|提交结果|
+|201|Created||
+|401|Unauthorized||
+|403|Forbidden||
+|404|Not Found||
+
+
+**响应参数**:
+
+
+| 参数名称 | 参数说明 | 类型 | schema |
+| -------- | -------- | ----- |----- |
+|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
+|description|描述|string||
+|properties|扩展字段|object||
+|result|任务ID|string||
+
+
+**响应示例**:
+```javascript
+{
+ "code": 1,
+ "description": "提交成功",
+ "properties": {},
+ "result": 1320098173412546
+}
+```
+
+## 提交Imagine任务
+
+
+**接口地址**:`/mj/submit/imagine`
+
+
+**请求方式**:`POST`
+
+
+**请求数据类型**:`application/json`
+
+
+**响应数据类型**:`*/*`
+
+
+**接口描述**:
+
+
+**请求示例**:
+
+
+```javascript
+{
+ "base64": "",
+ "notifyHook": "",
+ "prompt": "Cat",
+ "state": ""
+}
+```
+
+
+**请求参数**:
+
+
+| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
+| -------- | -------- | ----- | -------- | -------- | ------ |
+|imagineDTO|imagineDTO|body|true|Imagine提交参数|Imagine提交参数|
+| base64|垫图base64||false|string||
+| notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
+| prompt|提示词||true|string||
+| state|自定义参数||false|string||
+
+
+**响应状态**:
+
+
+| 状态码 | 说明 | schema |
+| -------- | -------- | ----- |
+|200|OK|提交结果|
+|201|Created||
+|401|Unauthorized||
+|403|Forbidden||
+|404|Not Found||
+
+
+**响应参数**:
+
+
+| 参数名称 | 参数说明 | 类型 | schema |
+| -------- | -------- | ----- |----- |
+|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
+|description|描述|string||
+|properties|扩展字段|object||
+|result|任务ID|string||
+
+
+**响应示例**:
+```javascript
+{
+ "code": 1,
+ "description": "提交成功",
+ "properties": {},
+ "result": 1320098173412546
+}
+```
+
+
+# 任务查询
+
+## 指定ID获取任务
+
+
+**接口地址**:`/mj/task/{id}/fetch`
+
+
+**请求方式**:`GET`
+
+
+**请求数据类型**:`application/x-www-form-urlencoded`
+
+
+**响应数据类型**:`*/*`
+
+
+**接口描述**:
+
+
+**请求参数**:
+
+
+| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
+| -------- | -------- | ----- | -------- | -------- | ------ |
+|id|任务ID|path|false|string||
+
+
+**响应状态**:
+
+
+| 状态码 | 说明 | schema |
+| -------- | -------- | ----- |
+|200|OK|任务|
+|401|Unauthorized||
+|403|Forbidden||
+|404|Not Found||
+
+
+**响应参数**:
+
+
+| 参数名称 | 参数说明 | 类型 | schema |
+| -------- | -------- | ----- |----- |
+|action|可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND|string||
+|description|任务描述|string||
+|failReason|失败原因|string||
+|finishTime|结束时间|integer(int64)|integer(int64)|
+|id|任务ID|string||
+|imageUrl|图片url|string||
+|progress|任务进度|string||
+|prompt|提示词|string||
+|promptEn|提示词-英文|string||
+|startTime|开始执行时间|integer(int64)|integer(int64)|
+|state|自定义参数|string||
+|status|任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS|string||
+|submitTime|提交时间|integer(int64)|integer(int64)|
+
+
+**响应示例**:
+```javascript
+{
+ "action": "",
+ "description": "",
+ "failReason": "",
+ "finishTime": 0,
+ "id": "",
+ "imageUrl": "",
+ "progress": "",
+ "prompt": "",
+ "promptEn": "",
+ "startTime": 0,
+ "state": "",
+ "status": "",
+ "submitTime": 0
+}
+```
\ No newline at end of file
diff --git a/common/constants.go b/common/constants.go
index 81f98163..95f5b44e 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -79,6 +79,10 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
+var NormalPrice = 1.5
+var StablePrice = 6.0
+var BasePrice = 1.5
+
const (
RoleGuestUser = 0
RoleCommonUser = 1
diff --git a/controller/log.go b/controller/log.go
index ba043349..1ff10519 100644
--- a/controller/log.go
+++ b/controller/log.go
@@ -94,6 +94,23 @@ func SearchUserLogs(c *gin.Context) {
})
}
+func GetLogByKey(c *gin.Context) {
+ key := c.Query("key")
+ logs, err := model.GetLogByKey(key)
+ if err != nil {
+ c.JSON(200, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "",
+ "data": logs,
+ })
+}
+
func GetLogsStat(c *gin.Context) {
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
diff --git a/controller/midjourney.go b/controller/midjourney.go
new file mode 100644
index 00000000..fa2c4d9b
--- /dev/null
+++ b/controller/midjourney.go
@@ -0,0 +1,133 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+ "time"
+)
+
+func UpdateMidjourneyTask() {
+ imageModel := "midjourney"
+ for {
+ time.Sleep(time.Duration(15) * time.Second)
+ tasks := model.GetAllUnFinishTasks()
+ if len(tasks) != 0 {
+ //log.Printf("UpdateMidjourneyTask: %v", time.Now())
+ ids := make([]string, 0)
+ for _, task := range tasks {
+ ids = append(ids, task.MjId)
+ }
+ requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
+ requestBody := map[string]interface{}{
+ "ids": ids,
+ }
+ jsonStr, err := json.Marshal(requestBody)
+ if err != nil {
+ log.Printf("UpdateMidjourneyTask: %v", err)
+ }
+ req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
+ if err != nil {
+ log.Printf("UpdateMidjourneyTask: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ log.Printf("UpdateMidjourneyTask: %v", err)
+ }
+ defer resp.Body.Close()
+ var response []Midjourney
+ err = json.NewDecoder(resp.Body).Decode(&response)
+ if err != nil {
+ log.Printf("UpdateMidjourneyTask: %v", err)
+ }
+ for _, responseItem := range response {
+ var midjourneyTask *model.Midjourney
+ for _, mj := range tasks {
+ mj.MjId = responseItem.MjId
+ midjourneyTask = model.GetMjByuId(mj.Id)
+ }
+ if midjourneyTask != nil {
+ midjourneyTask.Code = 1
+ midjourneyTask.Progress = responseItem.Progress
+ midjourneyTask.PromptEn = responseItem.PromptEn
+ midjourneyTask.State = responseItem.State
+ midjourneyTask.SubmitTime = responseItem.SubmitTime
+ midjourneyTask.StartTime = responseItem.StartTime
+ midjourneyTask.FinishTime = responseItem.FinishTime
+ midjourneyTask.ImageUrl = responseItem.ImageUrl
+ midjourneyTask.Status = responseItem.Status
+ midjourneyTask.FailReason = responseItem.FailReason
+ if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
+ log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
+ midjourneyTask.Progress = "100%"
+ err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
+ if err != nil {
+ log.Println("error update user quota cache: " + err.Error())
+ } else {
+ modelRatio := common.GetModelRatio(imageModel)
+ groupRatio := common.GetGroupRatio("default")
+ ratio := modelRatio * groupRatio
+ quota := int(ratio * 1 * 1000)
+ if quota != 0 {
+ err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
+ if err != nil {
+ log.Println("fail to increase user quota")
+ }
+ logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
+ model.RecordLog(midjourneyTask.UserId, 1, logContent)
+ }
+ }
+ }
+
+ err = midjourneyTask.Update()
+ if err != nil {
+ log.Printf("UpdateMidjourneyTaskFail: %v", err)
+ }
+ log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
+ }
+ }
+ }
+ }
+}
+
+func GetAllMidjourney(c *gin.Context) {
+ p, _ := strconv.Atoi(c.Query("p"))
+ if p < 0 {
+ p = 0
+ }
+ logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
+ if logs == nil {
+ logs = make([]*model.Midjourney, 0)
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "",
+ "data": logs,
+ })
+}
+
+func GetUserMidjourney(c *gin.Context) {
+ p, _ := strconv.Atoi(c.Query("p"))
+ if p < 0 {
+ p = 0
+ }
+ userId := c.GetInt("id")
+ log.Printf("userId = %d \n", userId)
+ logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
+ if logs == nil {
+ logs = make([]*model.Midjourney, 0)
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "",
+ "data": logs,
+ })
+}
diff --git a/controller/misc.go b/controller/misc.go
index 958a3716..5cbbb9d2 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -31,6 +31,9 @@ func GetStatus(c *gin.Context) {
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
+ "normal_price": common.NormalPrice,
+ "stable_price": common.StablePrice,
+ "base_price": common.BasePrice,
},
})
return
@@ -58,6 +61,17 @@ func GetAbout(c *gin.Context) {
return
}
+func GetMidjourney(c *gin.Context) {
+ common.OptionMapRWMutex.RLock()
+ defer common.OptionMapRWMutex.RUnlock()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": common.OptionMap["Midjourney"],
+ })
+ return
+}
+
func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
diff --git a/controller/relay-image.go b/controller/relay-image.go
index de623288..e66a1800 100644
--- a/controller/relay-image.go
+++ b/controller/relay-image.go
@@ -137,7 +137,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
+ model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
diff --git a/controller/relay-mj.go b/controller/relay-mj.go
new file mode 100644
index 00000000..8989f036
--- /dev/null
+++ b/controller/relay-mj.go
@@ -0,0 +1,388 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Midjourney struct {
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+}
+
+func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
+ var midjRequest Midjourney
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "bind_request_body_failed",
+ Properties: nil,
+ Result: "",
+ }
+ }
+ midjourneyTask := model.GetByMJId(midjRequest.MjId)
+ if midjourneyTask == nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "midjourney_task_not_found",
+ Properties: nil,
+ Result: "",
+ }
+ }
+ midjourneyTask.Progress = midjRequest.Progress
+ midjourneyTask.PromptEn = midjRequest.PromptEn
+ midjourneyTask.State = midjRequest.State
+ midjourneyTask.SubmitTime = midjRequest.SubmitTime
+ midjourneyTask.StartTime = midjRequest.StartTime
+ midjourneyTask.FinishTime = midjRequest.FinishTime
+ midjourneyTask.ImageUrl = midjRequest.ImageUrl
+ midjourneyTask.Status = midjRequest.Status
+ midjourneyTask.FailReason = midjRequest.FailReason
+ err = midjourneyTask.Update()
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "update_midjourney_task_failed",
+ }
+ }
+
+ return nil
+}
+
+func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
+ taskId := c.Param("id")
+ originTask := model.GetByMJId(taskId)
+ if originTask == nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "task_no_found",
+ }
+ }
+ var midjourneyTask Midjourney
+ midjourneyTask.MjId = originTask.MjId
+ midjourneyTask.Progress = originTask.Progress
+ midjourneyTask.PromptEn = originTask.PromptEn
+ midjourneyTask.State = originTask.State
+ midjourneyTask.SubmitTime = originTask.SubmitTime
+ midjourneyTask.StartTime = originTask.StartTime
+ midjourneyTask.FinishTime = originTask.FinishTime
+ midjourneyTask.ImageUrl = originTask.ImageUrl
+ midjourneyTask.Status = originTask.Status
+ midjourneyTask.FailReason = originTask.FailReason
+ midjourneyTask.Action = originTask.Action
+ midjourneyTask.Description = originTask.Description
+ midjourneyTask.Prompt = originTask.Prompt
+ jsonMap, err := json.Marshal(midjourneyTask)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "unmarshal_response_body_failed",
+ }
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "copy_response_body_failed",
+ }
+ }
+ return nil
+}
+
+func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
+ imageModel := "midjourney"
+
+ tokenId := c.GetInt("token_id")
+ channelType := c.GetInt("channel")
+ userId := c.GetInt("id")
+ consumeQuota := c.GetBool("consume_quota")
+ group := c.GetString("group")
+
+ var midjRequest MidjourneyRequest
+ if consumeQuota {
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "bind_request_body_failed",
+ }
+ }
+ }
+ if relayMode == RelayModeMidjourneyImagine {
+ if midjRequest.Prompt == "" {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "prompt_is_required",
+ }
+ }
+ midjRequest.Action = "IMAGINE"
+ } else if midjRequest.TaskId != "" {
+ originTask := model.GetByMJId(midjRequest.TaskId)
+ if originTask == nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "task_no_found",
+ }
+ } else if originTask.Action == "UPSCALE" {
+ //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "upscale_task_can_not_be_change",
+ }
+ } else if originTask.Status != "SUCCESS" {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "task_status_is_not_success",
+ }
+
+ }
+ midjRequest.Prompt = originTask.Prompt
+ } else if relayMode == RelayModeMidjourneyChange {
+ if midjRequest.TaskId == "" {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "taskId_is_required",
+ }
+ } else if midjRequest.Action == "" {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "action_is_required",
+ }
+ } else if midjRequest.Index == 0 {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "index_can_only_be_1_2_3_4",
+ }
+ }
+ }
+
+ // map model name
+ modelMapping := c.GetString("model_mapping")
+ isModelMapped := false
+ if modelMapping != "" {
+ modelMap := make(map[string]string)
+ err := json.Unmarshal([]byte(modelMapping), &modelMap)
+ if err != nil {
+ //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "unmarshal_model_mapping_failed",
+ }
+ }
+ if modelMap[imageModel] != "" {
+ imageModel = modelMap[imageModel]
+ isModelMapped = true
+ }
+ }
+
+ baseURL := common.ChannelBaseURLs[channelType]
+ requestURL := c.Request.URL.String()
+
+ if c.GetString("base_url") != "" {
+ baseURL = c.GetString("base_url")
+ }
+
+ //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
+
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+
+ var requestBody io.Reader
+ if isModelMapped {
+ jsonStr, err := json.Marshal(midjRequest)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "marshal_text_request_failed",
+ }
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ } else {
+ requestBody = c.Request.Body
+ }
+
+ modelRatio := common.GetModelRatio(imageModel)
+ groupRatio := common.GetGroupRatio(group)
+ ratio := modelRatio * groupRatio
+ userQuota, err := model.CacheGetUserQuota(userId)
+
+ sizeRatio := 1.0
+
+ quota := int(ratio * sizeRatio * 1000)
+
+ if consumeQuota && userQuota-quota < 0 {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "quota_not_enough",
+ }
+ }
+
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "create_request_failed",
+ }
+ }
+ //req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
+ // print request header
+ log.Printf("request header: %s", req.Header)
+ log.Printf("request body: %s", midjRequest.Prompt)
+
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "do_request_failed",
+ }
+ }
+
+ err = req.Body.Close()
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "close_request_body_failed",
+ }
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "close_request_body_failed",
+ }
+ }
+ var midjResponse MidjourneyResponse
+
+ defer func() {
+ if consumeQuota {
+ err := model.PostConsumeTokenQuota(tokenId, quota)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+ }
+ }()
+
+ //if consumeQuota {
+ //
+ //}
+ responseBody, err := io.ReadAll(resp.Body)
+
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "read_response_body_failed",
+ }
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "close_response_body_failed",
+ }
+ }
+
+ err = json.Unmarshal(responseBody, &midjResponse)
+ log.Printf("responseBody: %s", string(responseBody))
+ log.Printf("midjResponse: %v", midjResponse)
+ if resp.StatusCode != 200 {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "fail_to_fetch_midjourney",
+ }
+ }
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "unmarshal_response_body_failed",
+ }
+ }
+ if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 {
+ consumeQuota = false
+ }
+
+ midjourneyTask := &model.Midjourney{
+ UserId: userId,
+ Code: midjResponse.Code,
+ Action: midjRequest.Action,
+ MjId: midjResponse.Result,
+ Prompt: midjRequest.Prompt,
+ PromptEn: "",
+ Description: midjResponse.Description,
+ State: "",
+ SubmitTime: 0,
+ StartTime: 0,
+ FinishTime: 0,
+ ImageUrl: "",
+ Status: "",
+ Progress: "0%",
+ FailReason: "",
+ }
+ if midjResponse.Code == 4 {
+ midjourneyTask.FailReason = midjResponse.Description
+ }
+ err = midjourneyTask.Insert()
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "insert_midjourney_task_failed",
+ }
+ }
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "copy_response_body_failed",
+ }
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return &MidjourneyResponse{
+ Code: 4,
+ Description: "close_response_body_failed",
+ }
+ }
+ return nil
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 52e10f2b..621ae336 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -95,6 +95,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
}
+ isStable := c.GetBool("stable")
+
+ //if common.NormalPrice == -1 && strings.HasPrefix(textRequest.Model, "gpt-4") {
+ // nowUser, err := model.GetUserById(userId, false)
+ // if err != nil {
+ // return errorWrapper(err, "get_user_info_failed", http.StatusInternalServerError)
+ // }
+ // if nowUser.StableMode {
+ // group = "svip"
+ // isStable = true
+ // ////stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
+ // //userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
+ // //if userMaxPrice < common.StablePrice {
+ // // return errorWrapper(errors.New("当前低价通道不可用,稳定渠道价格为"+strconv.FormatFloat(common.StablePrice, 'f', -1, 64)+"R/刀"), "当前低价通道不可用", http.StatusInternalServerError)
+ // //}
+ // //
+ // ////ratio = stableRatio * groupRatio
+ // //channel, err := model.CacheGetRandomSatisfiedChannel("svip", textRequest.Model)
+ // //if err != nil {
+ // // message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", "svip", textRequest.Model)
+ // // return errorWrapper(errors.New(message), "no_available_channel", http.StatusInternalServerError)
+ // //}
+ // //channelType = channel.Type
+ // } else {
+ // return errorWrapper(errors.New("当前低价通道不可用,请稍后再试,或者在后台开启稳定模式"), "当前低价通道不可用", http.StatusInternalServerError)
+ // }
+ //}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
@@ -168,11 +195,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
+ //stableRatio := common.GetStableRatio(textRequest.Model)
modelRatio := common.GetModelRatio(textRequest.Model)
+ stableRatio := modelRatio
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
+ if isStable {
+ stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
+ ratio = stableRatio * groupRatio
+ }
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
@@ -301,6 +334,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
+ //if strings.HasPrefix(textRequest.Model, "gpt-4") {
+ // if quota < 5000 && quota != 0 {
+ // quota = 5000
+ // }
+ //}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
@@ -312,8 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
if quota != 0 {
tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
+ var logContent string
+ if isStable {
+ logContent = fmt.Sprintf("(稳定模式)模型倍率 %.2f,分组倍率 %.2f", stableRatio, groupRatio)
+ } else {
+ logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ }
+ model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
diff --git a/controller/relay.go b/controller/relay.go
index 9cfa5c4f..96abf6ee 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,6 +2,7 @@ package controller
import (
"fmt"
+ "log"
"net/http"
"one-api/common"
"strconv"
@@ -24,6 +25,10 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
+ RelayModeMidjourneyImagine
+ RelayModeMidjourneyChange
+ RelayModeMidjourneyNotify
+ RelayModeMidjourneyTaskFetch
)
// https://platform.openai.com/docs/api-reference/chat
@@ -128,6 +133,23 @@ type CompletionsStreamResponse struct {
} `json:"choices"`
}
+type MidjourneyRequest struct {
+ Prompt string `json:"prompt"`
+ NotifyHook string `json:"notifyHook"`
+ Action string `json:"action"`
+ Index int `json:"index"`
+ State string `json:"state"`
+ TaskId string `json:"taskId"`
+ Base64Array []string `json:"base64Array"`
+}
+
+type MidjourneyResponse struct {
+ Code int `json:"code"`
+ Description string `json:"description"`
+ Properties interface{} `json:"properties"`
+ Result string `json:"result"`
+}
+
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
@@ -179,6 +201,54 @@ func Relay(c *gin.Context) {
}
}
+func RelayMidjourney(c *gin.Context) {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
+ relayMode = RelayModeMidjourneyImagine
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
+ relayMode = RelayModeMidjourneyNotify
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
+ relayMode = RelayModeMidjourneyTaskFetch
+ }
+ var err *MidjourneyResponse
+ switch relayMode {
+ case RelayModeMidjourneyNotify:
+ err = relayMidjourneyNotify(c)
+ case RelayModeMidjourneyTaskFetch:
+ err = relayMidjourneyTask(c, relayMode)
+ default:
+ err = relayMidjourneySubmit(c, relayMode)
+ }
+ //err = relayMidjourneySubmit(c, relayMode)
+ log.Println(err)
+ if err != nil {
+ retryTimesStr := c.Query("retry")
+ retryTimes, _ := strconv.Atoi(retryTimesStr)
+ if retryTimesStr == "" {
+ retryTimes = common.RetryTimes
+ }
+ if retryTimes > 0 {
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
+ } else {
+ if err.Code == 30 {
+ err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ }
+ c.JSON(400, gin.H{
+ "error": err.Result,
+ })
+ }
+ channelId := c.GetInt("channel_id")
+ common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
+ //if shouldDisableChannel(&err.OpenAIError) {
+ // channelId := c.GetInt("channel_id")
+ // channelName := c.GetString("channel_name")
+ // disableChannel(channelId, channelName, err.Result)
+ //}
+ }
+}
+
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",
diff --git a/controller/topup.go b/controller/topup.go
new file mode 100644
index 00000000..9f6b3a82
--- /dev/null
+++ b/controller/topup.go
@@ -0,0 +1,173 @@
+package controller
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/samber/lo"
+ epay "github.com/star-horizon/go-epay"
+ "log"
+ "net/url"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+ "time"
+)
+
+type EpayRequest struct {
+ Amount int `json:"amount"`
+ PaymentMethod string `json:"payment_method"`
+ TopUpCode string `json:"top_up_code"`
+}
+
+type AmountRequest struct {
+ Amount int `json:"amount"`
+ TopUpCode string `json:"top_up_code"`
+}
+
+var client, _ = epay.NewClientWithUrl(&epay.Config{
+ PartnerID: "1096",
+ Key: "n08V9LpE8JffA3NPP893689u8p39NV9J",
+}, "https://api.lempay.org")
+
+func GetAmount(id int, count float64, topUpCode string) float64 {
+ amount := count * 1.5
+ if topUpCode != "" {
+ if topUpCode == "nekoapi" {
+ if id == 89 {
+ amount = count * 1
+ } else if id == 98 || id == 105 || id == 107 {
+ amount = count * 1.2
+ } else if id == 1 {
+ amount = count * 1
+ }
+ }
+ }
+ return amount
+}
+
+func RequestEpay(c *gin.Context) {
+ var req EpayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": err.Error(), "data": 10})
+ return
+ }
+ id := c.GetInt("id")
+ amount := GetAmount(id, float64(req.Amount), req.TopUpCode)
+ if id != 1 {
+ if req.Amount < 10 {
+ c.JSON(200, gin.H{"message": "最小充值10元", "data": amount, "count": 10})
+ return
+ }
+ }
+ if req.PaymentMethod == "zfb" {
+ if amount > 400 {
+ c.JSON(200, gin.H{"message": "支付宝最大充值400元", "data": amount, "count": 400})
+ return
+ }
+ req.PaymentMethod = "alipay"
+ }
+ if req.PaymentMethod == "wx" {
+ if amount > 600 {
+ c.JSON(200, gin.H{"message": "微信最大充值600元", "data": amount, "count": 600})
+ return
+ }
+ req.PaymentMethod = "wxpay"
+ }
+
+ returnUrl, _ := url.Parse("https://nekoapi.com/log")
+ notifyUrl, _ := url.Parse("https://nekoapi.com/api/user/epay/notify")
+ tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
+ uri, params, err := client.Purchase(&epay.PurchaseArgs{
+ Type: epay.PurchaseType(req.PaymentMethod),
+ ServiceTradeNo: "A" + tradeNo,
+ Name: "B" + tradeNo,
+ Money: strconv.FormatFloat(amount*0.99, 'f', 2, 64),
+ Device: epay.PC,
+ NotifyUrl: notifyUrl,
+ ReturnUrl: returnUrl,
+ })
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
+ return
+ }
+ topUp := &model.TopUp{
+ UserId: id,
+ Amount: req.Amount,
+ Money: int(amount),
+ TradeNo: "A" + tradeNo,
+ CreateTime: time.Now().Unix(),
+ Status: "pending",
+ }
+ err = topUp.Insert()
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
+ return
+ }
+ c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
+}
+
+func EpayNotify(c *gin.Context) {
+ params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
+ r[t] = c.Request.URL.Query().Get(t)
+ return r
+ }, map[string]string{})
+ verifyInfo, err := client.Verify(params)
+ if err == nil && verifyInfo.VerifyStatus {
+ _, err := c.Writer.Write([]byte("success"))
+ if err != nil {
+ log.Println("易支付回调写入失败")
+ }
+ } else {
+ _, err := c.Writer.Write([]byte("fail"))
+ if err != nil {
+ log.Println("易支付回调写入失败")
+ }
+ }
+
+ if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
+ log.Println(verifyInfo)
+ topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
+ if topUp.Status == "pending" {
+ topUp.Status = "success"
+ err := topUp.Update()
+ if err != nil {
+ log.Printf("易支付回调更新订单失败: %v", topUp)
+ return
+ }
+ //user, _ := model.GetUserById(topUp.UserId, false)
+ //user.Quota += topUp.Amount * 500000
+ err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
+ if err != nil {
+ log.Printf("易支付回调更新用户失败: %v", topUp)
+ return
+ }
+ log.Printf("易支付回调更新用户成功 %v", topUp)
+ model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v", common.LogQuota(topUp.Amount*500000)))
+ }
+ } else {
+ log.Printf("易支付异常回调: %v", verifyInfo)
+ }
+}
+
+func RequestAmount(c *gin.Context) {
+ var req AmountRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": err.Error(), "data": 10})
+ return
+ }
+ id := c.GetInt("id")
+ if id != 1 {
+ if req.Amount < 10 {
+ c.JSON(200, gin.H{"message": "最小充值10刀", "data": GetAmount(id, 10, req.TopUpCode), "count": 10})
+ return
+ }
+ if req.Amount > 400 {
+ c.JSON(200, gin.H{"message": "最大充值400刀", "data": GetAmount(id, 400, req.TopUpCode), "count": 400})
+ return
+ }
+ }
+
+ c.JSON(200, gin.H{"message": "success", "data": GetAmount(id, float64(req.Amount), req.TopUpCode)})
+}
diff --git a/controller/user.go b/controller/user.go
index 8fd10b82..a072b54e 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -3,6 +3,7 @@ package controller
import (
"encoding/json"
"fmt"
+ "log"
"net/http"
"one-api/common"
"one-api/model"
@@ -79,6 +80,8 @@ func setupLogin(user *model.User, c *gin.Context) {
DisplayName: user.DisplayName,
Role: user.Role,
Status: user.Status,
+ StableMode: user.StableMode,
+ MaxPrice: user.MaxPrice,
}
c.JSON(http.StatusOK, gin.H{
"message": "",
@@ -158,6 +161,8 @@ func Register(c *gin.Context) {
Password: user.Password,
DisplayName: user.Username,
InviterId: inviterId,
+ StableMode: user.StableMode,
+ MaxPrice: user.MaxPrice,
}
if common.EmailVerificationEnabled {
cleanUser.Email = user.Email
@@ -420,6 +425,8 @@ func UpdateSelf(c *gin.Context) {
Username: user.Username,
Password: user.Password,
DisplayName: user.DisplayName,
+ StableMode: user.StableMode,
+ MaxPrice: user.MaxPrice,
}
if user.Password == "$I_LOVE_U" {
user.Password = "" // rollback to what it should be
@@ -741,3 +748,52 @@ func TopUp(c *gin.Context) {
})
return
}
+
+type StableModeRequest struct {
+ StableMode bool `json:"stableMode"`
+ MaxPrice string `json:"maxPrice"`
+}
+
+func SetTableMode(c *gin.Context) {
+ req := &StableModeRequest{}
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ log.Println(req)
+ id := c.GetInt("id")
+ user := model.User{
+ Id: id,
+ }
+ err = user.FillUserById()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user.StableMode = req.StableMode
+ if !req.StableMode {
+ req.MaxPrice = "0"
+ }
+ user.MaxPrice = req.MaxPrice
+ err = user.Update(false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": "设置成功",
+ })
+ return
+}
diff --git a/go.mod b/go.mod
index 2e0cf017..ca4a8bf1 100644
--- a/go.mod
+++ b/go.mod
@@ -14,6 +14,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/pkoukk/tiktoken-go v0.1.1
+ github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.9.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/sqlite v1.4.3
@@ -42,12 +43,15 @@ require (
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
+ github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
+ github.com/samber/lo v1.37.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
+ golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
diff --git a/go.sum b/go.sum
index 7287206a..ba518d99 100644
--- a/go.sum
+++ b/go.sum
@@ -97,6 +97,8 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
+github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
+github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -117,6 +119,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
+github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
+github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA=
+github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
+github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -144,6 +150,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
+golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
+golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
@@ -163,8 +171,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
diff --git a/main.go b/main.go
index d6d0c75b..c7182299 100644
--- a/main.go
+++ b/main.go
@@ -74,6 +74,7 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
+ go controller.UpdateMidjourneyTask()
// Initialize HTTP server
server := gin.Default()
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 91c00e1a..ab296903 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
+ "log"
"net/http"
"one-api/common"
"one-api/model"
@@ -21,6 +22,7 @@ func Distribute() func(c *gin.Context) {
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
+ var err error
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
@@ -56,19 +58,28 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
+
// Select a channel for the user
var modelRequest ModelRequest
- err := common.UnmarshalBodyReusable(c, &modelRequest)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": gin.H{
- "message": "无效的请求",
- "type": "one_api_error",
- },
- })
- c.Abort()
- return
+ if strings.HasPrefix(c.Request.URL.Path, "/mj") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "midjourney"
+ }
+ } else {
+ err := common.UnmarshalBodyReusable(c, &modelRequest)
+ if err != nil {
+ log.Println(err)
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "message": "无效的请求",
+ "type": "one_api_error",
+ },
+ })
+ c.Abort()
+ return
+ }
}
+
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
@@ -84,21 +95,51 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
+ isStable := false
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ c.Set("stable", false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
- if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
- message = "数据库一致性已被破坏,请联系管理员"
+ if strings.HasPrefix(modelRequest.Model, "gpt-4") {
+ common.SysLog("GPT-4低价渠道宕机,正在尝试转换")
+ nowUser, err := model.GetUserById(userId, false)
+ if err == nil {
+ if nowUser.StableMode {
+ userGroup = "svip"
+ //stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
+ userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
+ if userMaxPrice < common.StablePrice {
+ message = "当前低价通道不可用,稳定渠道价格为" + strconv.FormatFloat(common.StablePrice, 'f', -1, 64) + "R/刀"
+ } else {
+ //common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道", nowUser.Username))
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ if err != nil {
+ message = "稳定渠道已经宕机,请联系管理员"
+ }
+ isStable = true
+ common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道 %v", nowUser.Username, channel))
+ c.Set("stable", true)
+ }
+
+ } else {
+ message = "当前低价通道不可用,请稍后再试,或者在后台开启稳定渠道模式"
+ }
+ }
+ }
+ //if channel == nil {
+ // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ // message = "数据库一致性已被破坏,请联系管理员"
+ //}
+ if !isStable {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": gin.H{
+ "message": message,
+ "type": "one_api_error",
+ },
+ })
+ c.Abort()
+ return
}
- c.JSON(http.StatusServiceUnavailable, gin.H{
- "error": gin.H{
- "message": message,
- "type": "one_api_error",
- },
- })
- c.Abort()
- return
}
}
c.Set("channel", channel.Type)
diff --git a/model/log.go b/model/log.go
index b0d6409a..e79c881f 100644
--- a/model/log.go
+++ b/model/log.go
@@ -3,6 +3,7 @@ package model
import (
"gorm.io/gorm"
"one-api/common"
+ "strings"
)
type Log struct {
@@ -17,6 +18,7 @@ type Log struct {
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
+ TokenId int `json:"token_id" gorm:"default:0;index"`
}
const (
@@ -27,6 +29,11 @@ const (
LogTypeSystem
)
+func GetLogByKey(key string) (logs []*Log, err error) {
+ err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error
+ return logs, err
+}
+
func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
@@ -44,7 +51,7 @@ func RecordLog(userId int, logType int, content string) {
}
}
-func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
+func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
if !common.LogConsumeEnabled {
return
}
@@ -59,6 +66,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
+ TokenId: tokenId,
}
err := DB.Create(log).Error
if err != nil {
diff --git a/model/main.go b/model/main.go
index 5bc5ce19..3c80426f 100644
--- a/model/main.go
+++ b/model/main.go
@@ -88,6 +88,14 @@ func InitDB() (err error) {
if err != nil {
return err
}
+ err = db.AutoMigrate(&Midjourney{})
+ if err != nil {
+ return err
+ }
+ err = db.AutoMigrate(&TopUp{})
+ if err != nil {
+ return err
+ }
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
diff --git a/model/midjourney.go b/model/midjourney.go
new file mode 100644
index 00000000..bb723e8c
--- /dev/null
+++ b/model/midjourney.go
@@ -0,0 +1,87 @@
+package model
+
+type Midjourney struct {
+ Id int `json:"id"`
+ Code int `json:"code"`
+ UserId int `json:"user_id" gorm:"index"`
+ Action string `json:"action"`
+ MjId string `json:"mj_id" gorm:"index"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"prompt_en"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ ImageUrl string `json:"image_url"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"fail_reason"`
+}
+
+func GetAllUserTask(userId int, startIdx int, num int) []*Midjourney {
+ var tasks []*Midjourney
+ var err error
+ err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+ return tasks
+}
+
+func GetAllTasks(startIdx int, num int) []*Midjourney {
+ var tasks []*Midjourney
+ var err error
+ err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+ return tasks
+}
+
+func GetAllUnFinishTasks() []*Midjourney {
+ var tasks []*Midjourney
+ var err error
+ // get all tasks progress is not 100%
+ err = DB.Where("progress != ?", "100%").Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+ return tasks
+}
+
+func GetByMJId(mjId string) *Midjourney {
+ var mj *Midjourney
+ var err error
+ err = DB.Where("mj_id = ?", mjId).First(&mj).Error
+ if err != nil {
+ return nil
+ }
+ return mj
+}
+
+func GetMjByuId(id int) *Midjourney {
+ var mj *Midjourney
+ var err error
+ err = DB.Where("id = ?", id).First(&mj).Error
+ if err != nil {
+ return nil
+ }
+ return mj
+}
+
+func UpdateProgress(id int, progress string) error {
+ return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
+}
+
+func (midjourney *Midjourney) Insert() error {
+ var err error
+ err = DB.Create(midjourney).Error
+ return err
+}
+
+func (midjourney *Midjourney) Update() error {
+ var err error
+ err = DB.Save(midjourney).Error
+ return err
+}
diff --git a/model/option.go b/model/option.go
index e7bc6806..aa4da949 100644
--- a/model/option.go
+++ b/model/option.go
@@ -69,6 +69,10 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
+ common.OptionMap["NormalPrice"] = strconv.FormatFloat(common.NormalPrice, 'f', -1, 64)
+ common.OptionMap["StablePrice"] = strconv.FormatFloat(common.StablePrice, 'f', -1, 64)
+ common.OptionMap["BasePrice"] = strconv.FormatFloat(common.BasePrice, 'f', -1, 64)
+
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -207,6 +211,12 @@ func updateOptionMap(key string, value string) (err error) {
common.TopUpLink = value
case "ChatLink":
common.ChatLink = value
+ case "NormalPrice":
+ common.NormalPrice, _ = strconv.ParseFloat(value, 64)
+ case "BasePrice":
+ common.BasePrice, _ = strconv.ParseFloat(value, 64)
+ case "StablePrice":
+ common.StablePrice, _ = strconv.ParseFloat(value, 64)
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
diff --git a/model/topup.go b/model/topup.go
new file mode 100644
index 00000000..876cd230
--- /dev/null
+++ b/model/topup.go
@@ -0,0 +1,43 @@
+package model
+
+type TopUp struct {
+ Id int `json:"id"`
+ UserId int `json:"user_id" gorm:"index"`
+ Amount int `json:"amount"`
+ Money int `json:"money"`
+ TradeNo string `json:"trade_no"`
+ CreateTime int64 `json:"create_time"`
+ Status string `json:"status"`
+}
+
+func (topUp *TopUp) Insert() error {
+ var err error
+ err = DB.Create(topUp).Error
+ return err
+}
+
+func (topUp *TopUp) Update() error {
+ var err error
+ err = DB.Save(topUp).Error
+ return err
+}
+
+func GetTopUpById(id int) *TopUp {
+ var topUp *TopUp
+ var err error
+ err = DB.Where("id = ?", id).First(&topUp).Error
+ if err != nil {
+ return nil
+ }
+ return topUp
+}
+
+func GetTopUpByTradeNo(tradeNo string) *TopUp {
+ var topUp *TopUp
+ var err error
+ err = DB.Where("trade_no = ?", tradeNo).First(&topUp).Error
+ if err != nil {
+ return nil
+ }
+ return topUp
+}
diff --git a/model/user.go b/model/user.go
index 7c771840..b96b955b 100644
--- a/model/user.go
+++ b/model/user.go
@@ -28,6 +28,8 @@ type User struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
+ StableMode bool `json:"stable_mode" gorm:"type:tinyint;default:0;column:stable_mode"`
+ MaxPrice string `json:"max_price" gorm:"type:varchar(32);default:'7'"`
}
func GetMaxUserId() int {
@@ -116,7 +118,14 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
- err = DB.Model(user).Updates(user).Error
+ newUser := *user
+ err = DB.Model(user).UpdateColumns(map[string]interface{}{
+ "stable_mode": user.StableMode,
+ "max_price": user.MaxPrice,
+ }).Error
+
+ DB.First(&user, user.Id)
+ err = DB.Model(user).Updates(newUser).Error
return err
}
diff --git a/router/api-router.go b/router/api-router.go
index 383133fa..eae6a7e6 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -16,6 +16,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
+ apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
@@ -29,7 +30,9 @@ func SetApiRouter(router *gin.Engine) {
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login)
+ //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
+ userRoute.GET("/epay/notify", controller.EpayNotify)
selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth())
@@ -40,6 +43,9 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
+ selfRoute.POST("/pay", controller.RequestEpay)
+ selfRoute.POST("/amount", controller.RequestAmount)
+ selfRoute.POST("/set_stable_mode", controller.SetTableMode)
}
adminRoute := userRoute.Group("/")
@@ -102,10 +108,14 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
+ logRoute.GET("/token", controller.GetLogByKey)
groupRoute := apiRouter.Group("/group")
groupRoute.Use(middleware.AdminAuth())
{
groupRoute.GET("/", controller.GetGroups)
}
+ mjRoute := apiRouter.Group("/mj")
+ mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
+ mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
}
}
diff --git a/router/relay-router.go b/router/relay-router.go
index c3c84d8b..ca8e1cc8 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -41,4 +41,12 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
}
+ relayMjRouter := router.Group("/mj")
+ relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
+ relayMjRouter.POST("/notify", controller.RelayMidjourney)
+ relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+ }
}
diff --git a/web/src/App.js b/web/src/App.js
index c967ce2c..422b1522 100644
--- a/web/src/App.js
+++ b/web/src/App.js
@@ -1,20 +1,20 @@
-import React, { lazy, Suspense, useContext, useEffect } from 'react';
-import { Route, Routes } from 'react-router-dom';
+import React, {lazy, Suspense, useContext, useEffect} from 'react';
+import {Route, Routes} from 'react-router-dom';
import Loading from './components/Loading';
import User from './pages/User';
-import { PrivateRoute } from './components/PrivateRoute';
+import {PrivateRoute} from './components/PrivateRoute';
import RegisterForm from './components/RegisterForm';
import LoginForm from './components/LoginForm';
import NotFound from './pages/NotFound';
import Setting from './pages/Setting';
import EditUser from './pages/User/EditUser';
import AddUser from './pages/User/AddUser';
-import { API, getLogo, getSystemName, showError, showNotice } from './helpers';
+import {API, getLogo, getSystemName, showError, showNotice} from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm';
-import { UserContext } from './context/User';
-import { StatusContext } from './context/Status';
+import {UserContext} from './context/User';
+import {StatusContext} from './context/Status';
import Channel from './pages/Channel';
import Token from './pages/Token';
import EditToken from './pages/Token/EditToken';
@@ -24,268 +24,295 @@ import EditRedemption from './pages/Redemption/EditRedemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
import Chat from './pages/Chat';
+import Midjourney from './pages/Midjourney';
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
function App() {
- const [userState, userDispatch] = useContext(UserContext);
- const [statusState, statusDispatch] = useContext(StatusContext);
+ const [userState, userDispatch] = useContext(UserContext);
+ const [statusState, statusDispatch] = useContext(StatusContext);
- const loadUser = () => {
- let user = localStorage.getItem('user');
- if (user) {
- let data = JSON.parse(user);
- userDispatch({ type: 'login', payload: data });
- }
- };
- const loadStatus = async () => {
- const res = await API.get('/api/status');
- const { success, data } = res.data;
- if (success) {
- localStorage.setItem('status', JSON.stringify(data));
- statusDispatch({ type: 'set', payload: data });
- localStorage.setItem('system_name', data.system_name);
- localStorage.setItem('logo', data.logo);
- localStorage.setItem('footer_html', data.footer_html);
- localStorage.setItem('quota_per_unit', data.quota_per_unit);
- localStorage.setItem('display_in_currency', data.display_in_currency);
- if (data.chat_link) {
- localStorage.setItem('chat_link', data.chat_link);
- } else {
- localStorage.removeItem('chat_link');
- }
- if (
- data.version !== process.env.REACT_APP_VERSION &&
- data.version !== 'v0.0.0' &&
- process.env.REACT_APP_VERSION !== ''
- ) {
- showNotice(
- `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面`
- );
- }
- } else {
- showError('无法正常连接至服务器!');
- }
- };
+ const loadUser = () => {
+ let user = localStorage.getItem('user');
+ if (user) {
+ let data = JSON.parse(user);
+ userDispatch({type: 'login', payload: data});
+ }
+ };
+ const loadStatus = async () => {
+ const res = await API.get('/api/status');
+ const {success, data} = res.data;
+ if (success) {
+ localStorage.setItem('status', JSON.stringify(data));
+ statusDispatch({type: 'set', payload: data});
+ localStorage.setItem('system_name', data.system_name);
+ localStorage.setItem('logo', data.logo);
+ localStorage.setItem('footer_html', data.footer_html);
+ localStorage.setItem('quota_per_unit', data.quota_per_unit);
+ localStorage.setItem('display_in_currency', data.display_in_currency);
+ if (data.chat_link) {
+ localStorage.setItem('chat_link', data.chat_link);
+ } else {
+ localStorage.removeItem('chat_link');
+ }
+ if (
+ data.version !== process.env.REACT_APP_VERSION &&
+ data.version !== 'v0.0.0' &&
+ process.env.REACT_APP_VERSION !== ''
+ ) {
+ showNotice(
+ `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面`
+ );
+ }
+ } else {
+ showError('无法正常连接至服务器!');
+ }
+ };
- useEffect(() => {
- loadUser();
- loadStatus().then();
- let systemName = getSystemName();
- if (systemName) {
- document.title = systemName;
- }
- let logo = getLogo();
- if (logo) {
- let linkElement = document.querySelector("link[rel~='icon']");
- if (linkElement) {
- linkElement.href = logo;
- }
- }
- }, []);
+ // const getOptions = async () => {
+ // const res = await API.get('/api/option/');
+ // const {success, message, data} = res.data;
+ // if (success) {
+ // let newInputs = {};
+ // data.forEach((item) => {
+ // if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
+ // item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+ // }
+ // newInputs[item.key] = item.value;
+ // });
+ // setInputs(newInputs);
+ // setOriginInputs(newInputs);
+ // } else {
+ // showError(message);
+ // }
+ // };
- return (
-
- 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) -
-+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +
+通道:官方通道
+状态:存活
+价格:{statusState?.status?.base_price}R / 刀
+通道:官方通道|低价通道
++ 状态: + {statusState?.status?.stable_price===-1? + 不 可 用 + : + 可 用 + } + | + {statusState?.status?.normal_price===-1? + 不 可 用 + : + 可 用 + } +
++ 价格:{statusState?.status?.stable_price}R / 刀|{statusState?.status?.normal_price}R / 刀 +
+名称:{statusState?.status?.system_name}
-版本:{statusState?.status?.version}
-- 源码: - - https://github.com/songquanpeng/one-api - -
-启动时间:{getStartTimeString()}
-- 邮箱验证: - {statusState?.status?.email_verification === true - ? '已启用' - : '未启用'} -
-- GitHub 身份验证: - {statusState?.status?.github_oauth === true - ? '已启用' - : '未启用'} -
-- 微信身份验证: - {statusState?.status?.wechat_login === true - ? '已启用' - : '未启用'} -
-- Turnstile 用户校验: - {statusState?.status?.turnstile_check === true - ? '已启用' - : '未启用'} -
-