This commit is contained in:
CaIon 2023-08-14 22:16:32 +08:00
parent c134604cee
commit 8f2119e410
33 changed files with 3224 additions and 1138 deletions

254
Midjourney.md Normal file
View File

@ -0,0 +1,254 @@
# Midjourney Proxy API文档
**简介**:Midjourney Proxy API文档
**HOST**:https://api.nekoedu.com
**Version**:v2.3.5
[TOC]
# 任务提交
## 绘图变化
**接口地址**:`/mj/submit/change`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"action": "UPSCALE",
"index": 1,
"notifyHook": "",
"state": "",
"taskId": "1320098173412546"
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|changeDTO|changeDTO|body|true|变化任务提交参数|变化任务提交参数|
|  action|UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL||true|string||
|  index|序号(1~4), action为UPSCALE,VARIATION时必传||false|integer(int32)||
|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
|  state|自定义参数||false|string||
|  taskId|任务ID||true|string||
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|提交结果|
|201|Created||
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
|description|描述|string||
|properties|扩展字段|object||
|result|任务ID|string||
**响应示例**:
```javascript
{
"code": 1,
"description": "提交成功",
"properties": {},
"result": 1320098173412546
}
```
## 提交Imagine任务
**接口地址**:`/mj/submit/imagine`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"base64": "",
"notifyHook": "",
"prompt": "Cat",
"state": ""
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|imagineDTO|imagineDTO|body|true|Imagine提交参数|Imagine提交参数|
|  base64|垫图base64||false|string||
|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
|  prompt|提示词||true|string||
|  state|自定义参数||false|string||
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|提交结果|
|201|Created||
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
|description|描述|string||
|properties|扩展字段|object||
|result|任务ID|string||
**响应示例**:
```javascript
{
"code": 1,
"description": "提交成功",
"properties": {},
"result": 1320098173412546
}
```
# 任务查询
## 指定ID获取任务
**接口地址**:`/mj/task/{id}/fetch`
**请求方式**:`GET`
**请求数据类型**:`application/x-www-form-urlencoded`
**响应数据类型**:`*/*`
**接口描述**:
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|id|任务ID|path|false|string||
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|任务|
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|action|可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND|string||
|description|任务描述|string||
|failReason|失败原因|string||
|finishTime|结束时间|integer(int64)|integer(int64)|
|id|任务ID|string||
|imageUrl|图片url|string||
|progress|任务进度|string||
|prompt|提示词|string||
|promptEn|提示词-英文|string||
|startTime|开始执行时间|integer(int64)|integer(int64)|
|state|自定义参数|string||
|status|任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS|string||
|submitTime|提交时间|integer(int64)|integer(int64)|
**响应示例**:
```javascript
{
"action": "",
"description": "",
"failReason": "",
"finishTime": 0,
"id": "",
"imageUrl": "",
"progress": "",
"prompt": "",
"promptEn": "",
"startTime": 0,
"state": "",
"status": "",
"submitTime": 0
}
```

View File

@ -79,6 +79,10 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
var NormalPrice = 1.5
var StablePrice = 6.0
var BasePrice = 1.5
const (
RoleGuestUser = 0
RoleCommonUser = 1

View File

@ -94,6 +94,23 @@ func SearchUserLogs(c *gin.Context) {
})
}
func GetLogByKey(c *gin.Context) {
key := c.Query("key")
logs, err := model.GetLogByKey(key)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetLogsStat(c *gin.Context) {
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)

133
controller/midjourney.go Normal file
View File

@ -0,0 +1,133 @@
package controller
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
func UpdateMidjourneyTask() {
imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
//log.Printf("UpdateMidjourneyTask: %v", time.Now())
ids := make([]string, 0)
for _, task := range tasks {
ids = append(ids, task.MjId)
}
requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
requestBody := map[string]interface{}{
"ids": ids,
}
jsonStr, err := json.Marshal(requestBody)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
defer resp.Body.Close()
var response []Midjourney
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
for _, responseItem := range response {
var midjourneyTask *model.Midjourney
for _, mj := range tasks {
mj.MjId = responseItem.MjId
midjourneyTask = model.GetMjByuId(mj.Id)
}
if midjourneyTask != nil {
midjourneyTask.Code = 1
midjourneyTask.Progress = responseItem.Progress
midjourneyTask.PromptEn = responseItem.PromptEn
midjourneyTask.State = responseItem.State
midjourneyTask.SubmitTime = responseItem.SubmitTime
midjourneyTask.StartTime = responseItem.StartTime
midjourneyTask.FinishTime = responseItem.FinishTime
midjourneyTask.ImageUrl = responseItem.ImageUrl
midjourneyTask.Status = responseItem.Status
midjourneyTask.FailReason = responseItem.FailReason
if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
midjourneyTask.Progress = "100%"
err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
if err != nil {
log.Println("error update user quota cache: " + err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
model.RecordLog(midjourneyTask.UserId, 1, logContent)
}
}
}
err = midjourneyTask.Update()
if err != nil {
log.Printf("UpdateMidjourneyTaskFail: %v", err)
}
log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
}
}
}
}
}
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
log.Printf("userId = %d \n", userId)
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}

View File

@ -31,6 +31,9 @@ func GetStatus(c *gin.Context) {
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"normal_price": common.NormalPrice,
"stable_price": common.StablePrice,
"base_price": common.BasePrice,
},
})
return
@ -58,6 +61,17 @@ func GetAbout(c *gin.Context) {
return
}
func GetMidjourney(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["Midjourney"],
})
return
}
func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()

View File

@ -137,7 +137,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

388
controller/relay-mj.go Normal file
View File

@ -0,0 +1,388 @@
package controller
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
type Midjourney struct {
MjId string `json:"id"`
Action string `json:"action"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
}
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
var midjRequest Midjourney
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
Properties: nil,
Result: "",
}
}
midjourneyTask := model.GetByMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "midjourney_task_not_found",
Properties: nil,
Result: "",
}
}
midjourneyTask.Progress = midjRequest.Progress
midjourneyTask.PromptEn = midjRequest.PromptEn
midjourneyTask.State = midjRequest.State
midjourneyTask.SubmitTime = midjRequest.SubmitTime
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "update_midjourney_task_failed",
}
}
return nil
}
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
taskId := c.Param("id")
originTask := model.GetByMJId(taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
var midjourneyTask Midjourney
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
midjourneyTask.State = originTask.State
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = originTask.ImageUrl
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
jsonMap, err := json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
return nil
}
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var midjRequest MidjourneyRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
}
}
}
if relayMode == RelayModeMidjourneyImagine {
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
Code: 4,
Description: "prompt_is_required",
}
}
midjRequest.Action = "IMAGINE"
} else if midjRequest.TaskId != "" {
originTask := model.GetByMJId(midjRequest.TaskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
} else if originTask.Action == "UPSCALE" {
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
return &MidjourneyResponse{
Code: 4,
Description: "upscale_task_can_not_be_change",
}
} else if originTask.Status != "SUCCESS" {
return &MidjourneyResponse{
Code: 4,
Description: "task_status_is_not_success",
}
}
midjRequest.Prompt = originTask.Prompt
} else if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_model_mapping_failed",
}
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "marshal_text_request_failed",
}
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
quota := int(ratio * sizeRatio * 1000)
if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "create_request_failed",
}
}
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt)
resp, err := httpClient.Do(req)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
err = req.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
var midjResponse MidjourneyResponse
defer func() {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
//if consumeQuota {
//
//}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "read_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
if resp.StatusCode != 200 {
return &MidjourneyResponse{
Code: 4,
Description: "fail_to_fetch_midjourney",
}
}
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 {
consumeQuota = false
}
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: 0,
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
}
if midjResponse.Code == 4 {
midjourneyTask.FailReason = midjResponse.Description
}
err = midjourneyTask.Insert()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "insert_midjourney_task_failed",
}
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
return nil
}

View File

@ -95,6 +95,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
}
isStable := c.GetBool("stable")
//if common.NormalPrice == -1 && strings.HasPrefix(textRequest.Model, "gpt-4") {
// nowUser, err := model.GetUserById(userId, false)
// if err != nil {
// return errorWrapper(err, "get_user_info_failed", http.StatusInternalServerError)
// }
// if nowUser.StableMode {
// group = "svip"
// isStable = true
// ////stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
// //userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
// //if userMaxPrice < common.StablePrice {
// // return errorWrapper(errors.New("当前低价通道不可用,稳定渠道价格为"+strconv.FormatFloat(common.StablePrice, 'f', -1, 64)+"R/刀"), "当前低价通道不可用", http.StatusInternalServerError)
// //}
// //
// ////ratio = stableRatio * groupRatio
// //channel, err := model.CacheGetRandomSatisfiedChannel("svip", textRequest.Model)
// //if err != nil {
// // message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", "svip", textRequest.Model)
// // return errorWrapper(errors.New(message), "no_available_channel", http.StatusInternalServerError)
// //}
// //channelType = channel.Type
// } else {
// return errorWrapper(errors.New("当前低价通道不可用,请稍后再试,或者在后台开启稳定模式"), "当前低价通道不可用", http.StatusInternalServerError)
// }
//}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
@ -168,11 +195,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
//stableRatio := common.GetStableRatio(textRequest.Model)
modelRatio := common.GetModelRatio(textRequest.Model)
stableRatio := modelRatio
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if isStable {
stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
ratio = stableRatio * groupRatio
}
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
@ -301,6 +334,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
//if strings.HasPrefix(textRequest.Model, "gpt-4") {
// if quota < 5000 && quota != 0 {
// quota = 5000
// }
//}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
@ -312,8 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
var logContent string
if isStable {
logContent = fmt.Sprintf("(稳定模式)模型倍率 %.2f,分组倍率 %.2f", stableRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
}
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@ -2,6 +2,7 @@ package controller
import (
"fmt"
"log"
"net/http"
"one-api/common"
"strconv"
@ -24,6 +25,10 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
)
// https://platform.openai.com/docs/api-reference/chat
@ -128,6 +133,23 @@ type CompletionsStreamResponse struct {
} `json:"choices"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
@ -179,6 +201,54 @@ func Relay(c *gin.Context) {
}
}
func RelayMidjourney(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
relayMode = RelayModeMidjourneyTaskFetch
}
var err *MidjourneyResponse
switch relayMode {
case RelayModeMidjourneyNotify:
err = relayMidjourneyNotify(c)
case RelayModeMidjourneyTaskFetch:
err = relayMidjourneyTask(c, relayMode)
default:
err = relayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(400, gin.H{
"error": err.Result,
})
}
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
//if shouldDisableChannel(&err.OpenAIError) {
// channelId := c.GetInt("channel_id")
// channelName := c.GetString("channel_name")
// disableChannel(channelId, channelName, err.Result)
//}
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",

173
controller/topup.go Normal file
View File

@ -0,0 +1,173 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
epay "github.com/star-horizon/go-epay"
"log"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type EpayRequest struct {
Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"`
}
type AmountRequest struct {
Amount int `json:"amount"`
TopUpCode string `json:"top_up_code"`
}
var client, _ = epay.NewClientWithUrl(&epay.Config{
PartnerID: "1096",
Key: "n08V9LpE8JffA3NPP893689u8p39NV9J",
}, "https://api.lempay.org")
func GetAmount(id int, count float64, topUpCode string) float64 {
amount := count * 1.5
if topUpCode != "" {
if topUpCode == "nekoapi" {
if id == 89 {
amount = count * 1
} else if id == 98 || id == 105 || id == 107 {
amount = count * 1.2
} else if id == 1 {
amount = count * 1
}
}
}
return amount
}
func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
id := c.GetInt("id")
amount := GetAmount(id, float64(req.Amount), req.TopUpCode)
if id != 1 {
if req.Amount < 10 {
c.JSON(200, gin.H{"message": "最小充值10元", "data": amount, "count": 10})
return
}
}
if req.PaymentMethod == "zfb" {
if amount > 400 {
c.JSON(200, gin.H{"message": "支付宝最大充值400元", "data": amount, "count": 400})
return
}
req.PaymentMethod = "alipay"
}
if req.PaymentMethod == "wx" {
if amount > 600 {
c.JSON(200, gin.H{"message": "微信最大充值600元", "data": amount, "count": 600})
return
}
req.PaymentMethod = "wxpay"
}
returnUrl, _ := url.Parse("https://nekoapi.com/log")
notifyUrl, _ := url.Parse("https://nekoapi.com/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: epay.PurchaseType(req.PaymentMethod),
ServiceTradeNo: "A" + tradeNo,
Name: "B" + tradeNo,
Money: strconv.FormatFloat(amount*0.99, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
topUp := &model.TopUp{
UserId: id,
Amount: req.Amount,
Money: int(amount),
TradeNo: "A" + tradeNo,
CreateTime: time.Now().Unix(),
Status: "pending",
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v", common.LogQuota(topUp.Amount*500000)))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
}
func RequestAmount(c *gin.Context) {
var req AmountRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
id := c.GetInt("id")
if id != 1 {
if req.Amount < 10 {
c.JSON(200, gin.H{"message": "最小充值10刀", "data": GetAmount(id, 10, req.TopUpCode), "count": 10})
return
}
if req.Amount > 400 {
c.JSON(200, gin.H{"message": "最大充值400刀", "data": GetAmount(id, 400, req.TopUpCode), "count": 400})
return
}
}
c.JSON(200, gin.H{"message": "success", "data": GetAmount(id, float64(req.Amount), req.TopUpCode)})
}

View File

@ -3,6 +3,7 @@ package controller
import (
"encoding/json"
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
@ -79,6 +80,8 @@ func setupLogin(user *model.User, c *gin.Context) {
DisplayName: user.DisplayName,
Role: user.Role,
Status: user.Status,
StableMode: user.StableMode,
MaxPrice: user.MaxPrice,
}
c.JSON(http.StatusOK, gin.H{
"message": "",
@ -158,6 +161,8 @@ func Register(c *gin.Context) {
Password: user.Password,
DisplayName: user.Username,
InviterId: inviterId,
StableMode: user.StableMode,
MaxPrice: user.MaxPrice,
}
if common.EmailVerificationEnabled {
cleanUser.Email = user.Email
@ -420,6 +425,8 @@ func UpdateSelf(c *gin.Context) {
Username: user.Username,
Password: user.Password,
DisplayName: user.DisplayName,
StableMode: user.StableMode,
MaxPrice: user.MaxPrice,
}
if user.Password == "$I_LOVE_U" {
user.Password = "" // rollback to what it should be
@ -741,3 +748,52 @@ func TopUp(c *gin.Context) {
})
return
}
type StableModeRequest struct {
StableMode bool `json:"stableMode"`
MaxPrice string `json:"maxPrice"`
}
func SetTableMode(c *gin.Context) {
req := &StableModeRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
log.Println(req)
id := c.GetInt("id")
user := model.User{
Id: id,
}
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.StableMode = req.StableMode
if !req.StableMode {
req.MaxPrice = "0"
}
user.MaxPrice = req.MaxPrice
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": "设置成功",
})
return
}

4
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/pkoukk/tiktoken-go v0.1.1
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.9.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/sqlite v1.4.3
@ -42,12 +43,15 @@ require (
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/samber/lo v1.37.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect

10
go.sum
View File

@ -97,6 +97,8 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -117,6 +119,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@ -144,6 +150,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
@ -163,8 +171,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View File

@ -74,6 +74,7 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
go controller.UpdateMidjourneyTask()
// Initialize HTTP server
server := gin.Default()

View File

@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
@ -21,6 +22,7 @@ func Distribute() func(c *gin.Context) {
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
var err error
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
@ -56,10 +58,17 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
if modelRequest.Model == "" {
modelRequest.Model = "midjourney"
}
} else {
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
log.Println(err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的请求",
@ -69,6 +78,8 @@ func Distribute() func(c *gin.Context) {
c.Abort()
return
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
@ -84,14 +95,43 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
isStable := false
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
c.Set("stable", false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
if strings.HasPrefix(modelRequest.Model, "gpt-4") {
common.SysLog("GPT-4低价渠道宕机正在尝试转换")
nowUser, err := model.GetUserById(userId, false)
if err == nil {
if nowUser.StableMode {
userGroup = "svip"
//stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
if userMaxPrice < common.StablePrice {
message = "当前低价通道不可用,稳定渠道价格为" + strconv.FormatFloat(common.StablePrice, 'f', -1, 64) + "R/刀"
} else {
//common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道", nowUser.Username))
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message = "稳定渠道已经宕机,请联系管理员"
}
c.JSON(http.StatusServiceUnavailable, gin.H{
isStable = true
common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道 %v", nowUser.Username, channel))
c.Set("stable", true)
}
} else {
message = "当前低价通道不可用,请稍后再试,或者在后台开启稳定渠道模式"
}
}
}
//if channel == nil {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
// message = "数据库一致性已被破坏,请联系管理员"
//}
if !isStable {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": message,
"type": "one_api_error",
@ -101,6 +141,7 @@ func Distribute() func(c *gin.Context) {
return
}
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)

View File

@ -3,6 +3,7 @@ package model
import (
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Log struct {
@ -17,6 +18,7 @@ type Log struct {
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
TokenId int `json:"token_id" gorm:"default:0;index"`
}
const (
@ -27,6 +29,11 @@ const (
LogTypeSystem
)
func GetLogByKey(key string) (logs []*Log, err error) {
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error
return logs, err
}
func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
@ -44,7 +51,7 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
if !common.LogConsumeEnabled {
return
}
@ -59,6 +66,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
TokenId: tokenId,
}
err := DB.Create(log).Error
if err != nil {

View File

@ -88,6 +88,14 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = db.AutoMigrate(&TopUp{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err

87
model/midjourney.go Normal file
View File

@ -0,0 +1,87 @@
package model
type Midjourney struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
}
func GetAllUserTask(userId int, startIdx int, num int) []*Midjourney {
var tasks []*Midjourney
var err error
err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetAllTasks(startIdx int, num int) []*Midjourney {
var tasks []*Midjourney
var err error
err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetAllUnFinishTasks() []*Midjourney {
var tasks []*Midjourney
var err error
// get all tasks progress is not 100%
err = DB.Where("progress != ?", "100%").Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetByMJId(mjId string) *Midjourney {
var mj *Midjourney
var err error
err = DB.Where("mj_id = ?", mjId).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetMjByuId(id int) *Midjourney {
var mj *Midjourney
var err error
err = DB.Where("id = ?", id).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func UpdateProgress(id int, progress string) error {
return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
}
func (midjourney *Midjourney) Insert() error {
var err error
err = DB.Create(midjourney).Error
return err
}
func (midjourney *Midjourney) Update() error {
var err error
err = DB.Save(midjourney).Error
return err
}

View File

@ -69,6 +69,10 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["NormalPrice"] = strconv.FormatFloat(common.NormalPrice, 'f', -1, 64)
common.OptionMap["StablePrice"] = strconv.FormatFloat(common.StablePrice, 'f', -1, 64)
common.OptionMap["BasePrice"] = strconv.FormatFloat(common.BasePrice, 'f', -1, 64)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@ -207,6 +211,12 @@ func updateOptionMap(key string, value string) (err error) {
common.TopUpLink = value
case "ChatLink":
common.ChatLink = value
case "NormalPrice":
common.NormalPrice, _ = strconv.ParseFloat(value, 64)
case "BasePrice":
common.BasePrice, _ = strconv.ParseFloat(value, 64)
case "StablePrice":
common.StablePrice, _ = strconv.ParseFloat(value, 64)
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":

43
model/topup.go Normal file
View File

@ -0,0 +1,43 @@
package model
type TopUp struct {
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"`
Money int `json:"money"`
TradeNo string `json:"trade_no"`
CreateTime int64 `json:"create_time"`
Status string `json:"status"`
}
func (topUp *TopUp) Insert() error {
var err error
err = DB.Create(topUp).Error
return err
}
func (topUp *TopUp) Update() error {
var err error
err = DB.Save(topUp).Error
return err
}
func GetTopUpById(id int) *TopUp {
var topUp *TopUp
var err error
err = DB.Where("id = ?", id).First(&topUp).Error
if err != nil {
return nil
}
return topUp
}
func GetTopUpByTradeNo(tradeNo string) *TopUp {
var topUp *TopUp
var err error
err = DB.Where("trade_no = ?", tradeNo).First(&topUp).Error
if err != nil {
return nil
}
return topUp
}

View File

@ -28,6 +28,8 @@ type User struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
StableMode bool `json:"stable_mode" gorm:"type:tinyint;default:0;column:stable_mode"`
MaxPrice string `json:"max_price" gorm:"type:varchar(32);default:'7'"`
}
func GetMaxUserId() int {
@ -116,7 +118,14 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
err = DB.Model(user).Updates(user).Error
newUser := *user
err = DB.Model(user).UpdateColumns(map[string]interface{}{
"stable_mode": user.StableMode,
"max_price": user.MaxPrice,
}).Error
DB.First(&user, user.Id)
err = DB.Model(user).Updates(newUser).Error
return err
}

View File

@ -16,6 +16,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
@ -29,7 +30,9 @@ func SetApiRouter(router *gin.Engine) {
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth())
@ -40,6 +43,9 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/set_stable_mode", controller.SetTableMode)
}
adminRoute := userRoute.Group("/")
@ -102,10 +108,14 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
logRoute.GET("/token", controller.GetLogByKey)
groupRoute := apiRouter.Group("/group")
groupRoute.Use(middleware.AdminAuth())
{
groupRoute.GET("/", controller.GetGroups)
}
mjRoute := apiRouter.Group("/mj")
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
}
}

View File

@ -41,4 +41,12 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
}
relayMjRouter := router.Group("/mj")
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
}
}

View File

@ -1,20 +1,20 @@
import React, { lazy, Suspense, useContext, useEffect } from 'react';
import { Route, Routes } from 'react-router-dom';
import React, {lazy, Suspense, useContext, useEffect} from 'react';
import {Route, Routes} from 'react-router-dom';
import Loading from './components/Loading';
import User from './pages/User';
import { PrivateRoute } from './components/PrivateRoute';
import {PrivateRoute} from './components/PrivateRoute';
import RegisterForm from './components/RegisterForm';
import LoginForm from './components/LoginForm';
import NotFound from './pages/NotFound';
import Setting from './pages/Setting';
import EditUser from './pages/User/EditUser';
import AddUser from './pages/User/AddUser';
import { API, getLogo, getSystemName, showError, showNotice } from './helpers';
import {API, getLogo, getSystemName, showError, showNotice} from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User';
import { StatusContext } from './context/Status';
import {UserContext} from './context/User';
import {StatusContext} from './context/Status';
import Channel from './pages/Channel';
import Token from './pages/Token';
import EditToken from './pages/Token/EditToken';
@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
import Chat from './pages/Chat';
import Midjourney from './pages/Midjourney';
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
@ -36,15 +37,15 @@ function App() {
let user = localStorage.getItem('user');
if (user) {
let data = JSON.parse(user);
userDispatch({ type: 'login', payload: data });
userDispatch({type: 'login', payload: data});
}
};
const loadStatus = async () => {
const res = await API.get('/api/status');
const { success, data } = res.data;
const {success, data} = res.data;
if (success) {
localStorage.setItem('status', JSON.stringify(data));
statusDispatch({ type: 'set', payload: data });
statusDispatch({type: 'set', payload: data});
localStorage.setItem('system_name', data.system_name);
localStorage.setItem('logo', data.logo);
localStorage.setItem('footer_html', data.footer_html);
@ -69,6 +70,24 @@ function App() {
}
};
// const getOptions = async () => {
// const res = await API.get('/api/option/');
// const {success, message, data} = res.data;
// if (success) {
// let newInputs = {};
// data.forEach((item) => {
// if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
// item.value = JSON.stringify(JSON.parse(item.value), null, 2);
// }
// newInputs[item.key] = item.value;
// });
// setInputs(newInputs);
// setOriginInputs(newInputs);
// } else {
// showError(message);
// }
// };
useEffect(() => {
loadUser();
loadStatus().then();
@ -91,7 +110,7 @@ function App() {
path='/'
element={
<Suspense fallback={<Loading></Loading>}>
<Home />
<Home/>
</Suspense>
}
/>
@ -99,7 +118,7 @@ function App() {
path='/channel'
element={
<PrivateRoute>
<Channel />
<Channel/>
</PrivateRoute>
}
/>
@ -107,7 +126,7 @@ function App() {
path='/channel/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
<EditChannel/>
</Suspense>
}
/>
@ -115,7 +134,7 @@ function App() {
path='/channel/add'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
<EditChannel/>
</Suspense>
}
/>
@ -123,7 +142,7 @@ function App() {
path='/token'
element={
<PrivateRoute>
<Token />
<Token/>
</PrivateRoute>
}
/>
@ -131,7 +150,7 @@ function App() {
path='/token/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditToken />
<EditToken/>
</Suspense>
}
/>
@ -139,7 +158,7 @@ function App() {
path='/token/add'
element={
<Suspense fallback={<Loading></Loading>}>
<EditToken />
<EditToken/>
</Suspense>
}
/>
@ -147,7 +166,7 @@ function App() {
path='/redemption'
element={
<PrivateRoute>
<Redemption />
<Redemption/>
</PrivateRoute>
}
/>
@ -155,7 +174,7 @@ function App() {
path='/redemption/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditRedemption />
<EditRedemption/>
</Suspense>
}
/>
@ -163,7 +182,7 @@ function App() {
path='/redemption/add'
element={
<Suspense fallback={<Loading></Loading>}>
<EditRedemption />
<EditRedemption/>
</Suspense>
}
/>
@ -171,7 +190,7 @@ function App() {
path='/user'
element={
<PrivateRoute>
<User />
<User/>
</PrivateRoute>
}
/>
@ -179,7 +198,7 @@ function App() {
path='/user/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
<EditUser/>
</Suspense>
}
/>
@ -187,7 +206,7 @@ function App() {
path='/user/edit'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
<EditUser/>
</Suspense>
}
/>
@ -195,7 +214,7 @@ function App() {
path='/user/add'
element={
<Suspense fallback={<Loading></Loading>}>
<AddUser />
<AddUser/>
</Suspense>
}
/>
@ -203,7 +222,7 @@ function App() {
path='/user/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
<PasswordResetConfirm/>
</Suspense>
}
/>
@ -211,7 +230,7 @@ function App() {
path='/login'
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
<LoginForm/>
</Suspense>
}
/>
@ -219,7 +238,7 @@ function App() {
path='/register'
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
<RegisterForm/>
</Suspense>
}
/>
@ -227,7 +246,7 @@ function App() {
path='/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
<PasswordResetForm/>
</Suspense>
}
/>
@ -235,7 +254,7 @@ function App() {
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
<GitHubOAuth/>
</Suspense>
}
/>
@ -244,7 +263,7 @@ function App() {
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Setting />
<Setting/>
</Suspense>
</PrivateRoute>
}
@ -254,7 +273,7 @@ function App() {
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
<TopUp/>
</Suspense>
</PrivateRoute>
}
@ -263,7 +282,15 @@ function App() {
path='/log'
element={
<PrivateRoute>
<Log />
<Log/>
</PrivateRoute>
}
/>
<Route
path='/midjourney'
element={
<PrivateRoute>
<Midjourney/>
</PrivateRoute>
}
/>
@ -271,7 +298,7 @@ function App() {
path='/about'
element={
<Suspense fallback={<Loading></Loading>}>
<About />
<About/>
</Suspense>
}
/>
@ -279,11 +306,11 @@ function App() {
path='/chat'
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
<Chat/>
</Suspense>
}
/>
<Route path='*' element={NotFound} />
<Route path='*' element={NotFound}/>
</Routes>
);
}

View File

@ -4,7 +4,7 @@ import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
import {renderGroup, renderNumber, renderQuota} from '../helpers/render';
function renderTimestamp(timestamp) {
return (
@ -299,6 +299,7 @@ const ChannelsTable = () => {
onClick={() => {
sortChannel('group');
}}
width={1}
>
分组
</Table.HeaderCell>
@ -307,6 +308,7 @@ const ChannelsTable = () => {
onClick={() => {
sortChannel('type');
}}
width={2}
>
类型
</Table.HeaderCell>
@ -315,6 +317,7 @@ const ChannelsTable = () => {
onClick={() => {
sortChannel('status');
}}
width={2}
>
状态
</Table.HeaderCell>
@ -326,6 +329,15 @@ const ChannelsTable = () => {
>
响应时间
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('used_quota');
}}
width={1}
>
已使用
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
@ -361,6 +373,7 @@ const ChannelsTable = () => {
basic
/>
</Table.Cell>
<Table.Cell>{renderQuota(channel.used_quota)}</Table.Cell>
<Table.Cell>
<Popup
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}

View File

@ -46,6 +46,11 @@ let headerButtons = [
to: '/log',
icon: 'book'
},
{
name: 'Midjourney',
to: '/midjourney',
icon: 'images outline'
},
{
name: '设置',
to: '/setting',

View File

@ -0,0 +1,385 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Label, Pagination, Segment, Select, Table } from 'semantic-ui-react';
import { API, isAdmin, showError, timestamp2string } from '../helpers';
import { ITEMS_PER_PAGE } from '../constants';
import { renderQuota } from '../helpers/render';
import {Link} from "react-router-dom";
function renderTimestamp(timestamp) {
return (
<>
{timestamp2string(timestamp)}
</>
);
}
const MODE_OPTIONS = [
{ key: 'all', text: '全部用户', value: 'all' },
{ key: 'self', text: '当前用户', value: 'self' }
];
const LOG_OPTIONS = [
{ key: '0', text: '全部', value: 0 },
// { key: '1', text: '绘图', value: 1 },
// { key: '2', text: '放大', value: 2 },
// { key: '3', text: '变换', value: 3 },
// { key: '4', text: '图生文', value: 4 },
// { key: '5', text: '图片混合', value: 5 }
];
function renderType(type) {
switch (type) {
case 'IMAGINE':
return <Label basic color='blue'> 绘图 </Label>;
case 'UPSCALE':
return <Label basic color='orange'> 放大 </Label>;
case 'VARIATION':
return <Label basic color='purple'> 变换 </Label>;
case 'DESCRIBE':
return <Label basic color='yellow'> 图生文 </Label>;
case 'BLEAND':
return <Label basic color='olive'> 图混合 </Label>;
default:
return <Label basic color='black'> 未知 </Label>;
}
}
function renderCode(type) {
switch (type) {
case 1:
return <Label basic color='green'> 已提交 </Label>;
case 21:
return <Label basic color='olive'> 排队中 </Label>;
case 22:
return <Label basic color='orange'> 重复提交 </Label>;
default:
return <Label basic color='black'> 未知 </Label>;
}
}
function renderStatus(type) {
switch (type) {
case 'SUCCESS':
return <Label basic color='green'> 成功 </Label>;
case 'NOT_START':
return <Label basic color='black'> 未启动 </Label>;
case 'SUBMITTED':
return <Label basic color='yellow'> 队列中 </Label>;
case 'IN_PROGRESS':
return <Label basic color='blue'> 执行中 </Label>;
case 'FAILURE':
return <Label basic color='red'> 失败 </Label>;
default:
return <Label basic color='black'> 未知 </Label>;
}
}
const LogsTable = () => {
const [logs, setLogs] = useState([
]);
const [loading, setLoading] = useState(true);
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
let now = new Date();
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(0),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
});
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
const [stat, setStat] = useState({
quota: 0,
token: 0
});
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
} else {
showError(message);
}
};
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
} else {
showError(message);
}
};
const loadLogs = async (startIdx) => {
let url = '';
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/mj/?p=${startIdx}&username=${username}&token_name=${token_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} else {
url = `/api/mj/self/?p=${startIdx}&token_name=${token_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setLogs(data);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
setLogs(newLogs);
}
} else {
showError(message);
}
setLoading(false);
};
const onPaginationChange = (e, { activePage }) => {
(async () => {
if (activePage === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadLogs(activePage - 1);
}
setActivePage(activePage);
})();
};
const refresh = async () => {
setLoading(true);
setActivePage(1)
await loadLogs(0);
// if (isAdminUser) {
// getLogStat().then();
// } else {
// getLogSelfStat().then();
// }
};
useEffect(() => {
refresh().then();
}, [logType]);
const searchLogs = async () => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
await loadLogs(0);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`);
const { success, message, data } = res.data;
if (success) {
setLogs(data);
setActivePage(1);
} else {
showError(message);
}
setSearching(false);
};
const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim());
};
const sortLog = (key) => {
if (logs.length === 0) return;
setLoading(true);
let sortedLogs = [...logs];
if (typeof sortedLogs[0][key] === 'string'){
sortedLogs.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]);
});
} else {
sortedLogs.sort((a, b) => {
if (a[key] === b[key]) return 0;
if (a[key] > b[key]) return -1;
if (a[key] < b[key]) return 1;
});
}
if (sortedLogs[0].id === logs[0].id) {
sortedLogs.reverse();
}
setLogs(sortedLogs);
setLoading(false);
};
return (
<>
<Segment>
<Table basic compact size='small'>
<Table.Header>
<Table.Row>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('submit_time');
}}
width={2}
>
提交时间
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('action');
}}
width={1}
>
类型
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('mj_id');
}}
width={2}
>
任务ID
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('code');
}}
width={1}
>
提交结果
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('status');
}}
width={1}
>
任务状态
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('progress');
}}
width={1}
>
进度
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('image_url');
}}
width={1}
>
结果图片
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('prompt');
}}
width={3}
>
Prompt
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('fail_reason');
}}
width={1}
>
失败原因
</Table.HeaderCell>
</Table.Row>
</Table.Header>
<Table.Body>
{logs
.slice(
(activePage - 1) * ITEMS_PER_PAGE,
activePage * ITEMS_PER_PAGE
)
.map((log, idx) => {
if (log.deleted) return <></>;
return (
<Table.Row key={log.created_at}>
<Table.Cell>{renderTimestamp(log.submit_time/1000)}</Table.Cell>
{/*{*/}
{/* isAdminUser && (*/}
{/* <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>*/}
{/* )*/}
{/*}*/}
<Table.Cell>{renderType(log.action)}</Table.Cell>
<Table.Cell>{log.mj_id}</Table.Cell>
<Table.Cell>{renderCode(log.code)}</Table.Cell>
<Table.Cell>{renderStatus(log.status)}</Table.Cell>
<Table.Cell>{log.progress ? <Label basic>{log.progress}</Label> : ''}</Table.Cell>
<Table.Cell>
{
log.image_url ? (
<Link to={log.image_url} target='_blank'>点击查看</Link>
) : '暂未生成图片'
}
</Table.Cell>
<Table.Cell>{log.prompt}</Table.Cell>
<Table.Cell>{log.fail_reason ? log.fail_reason : '无'}</Table.Cell>
</Table.Row>
);
})}
</Table.Body>
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan={'9'}>
<Select
placeholder='选择明细分类'
options={LOG_OPTIONS}
style={{ marginRight: '8px' }}
name='logType'
value={logType}
onChange={(e, { name, value }) => {
setLogType(value);
}}
/>
<Button size='small' onClick={refresh} loading={loading}>刷新</Button>
<Pagination
floated='right'
activePage={activePage}
onPageChange={onPaginationChange}
size='small'
siblingRange={1}
totalPages={
Math.ceil(logs.length / ITEMS_PER_PAGE) +
(logs.length % ITEMS_PER_PAGE === 0 ? 1 : 0)
}
/>
</Table.HeaderCell>
</Table.Row>
</Table.Footer>
</Table>
</Segment>
</>
);
};
export default LogsTable;

View File

@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
import { API, showError, verifyJSON } from '../helpers';
import React, {useEffect, useState} from 'react';
import {Divider, Form, Grid, Header} from 'semantic-ui-react';
import {API, showError, verifyJSON} from '../helpers';
const OperationSetting = () => {
let [inputs, setInputs] = useState({
@ -21,13 +21,16 @@ const OperationSetting = () => {
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
RetryTimes: 0,
StablePrice: 6,
NormalPrice: 1.5,
BasePrice: 1.5,
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
const {success, message, data} = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
@ -56,20 +59,20 @@ const OperationSetting = () => {
key,
value
});
const { success, message } = res.data;
const {success, message} = res.data;
if (success) {
setInputs((inputs) => ({ ...inputs, [key]: value }));
setInputs((inputs) => ({...inputs, [key]: value}));
} else {
showError(message);
}
setLoading(false);
};
const handleInputChange = async (e, { name, value }) => {
const handleInputChange = async (e, {name, value}) => {
if (name.endsWith('Enabled')) {
await updateOption(name, value);
} else {
setInputs((inputs) => ({ ...inputs, [name]: value }));
setInputs((inputs) => ({...inputs, [name]: value}));
}
};
@ -83,6 +86,14 @@ const OperationSetting = () => {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
break;
case 'stable':
await updateOption('StablePrice', inputs.StablePrice);
await updateOption('NormalPrice', inputs.NormalPrice);
await updateOption('BasePrice', inputs.BasePrice);
localStorage.setItem('stable_price', inputs.StablePrice);
localStorage.setItem('normal_price', inputs.NormalPrice);
localStorage.setItem('base_price', inputs.BasePrice);
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
@ -207,7 +218,7 @@ const OperationSetting = () => {
<Form.Button onClick={() => {
submitConfig('general').then();
}}>保存通用设置</Form.Button>
<Divider />
<Divider/>
<Header as='h3'>
监控设置
</Header>
@ -244,7 +255,36 @@ const OperationSetting = () => {
<Form.Button onClick={() => {
submitConfig('monitor').then();
}}>保存监控设置</Form.Button>
<Divider />
<Divider/>
<Header as='h3'>
通道设置
</Header>
<Form.Group widths={3}>
<Form.Input
label='普通渠道价格'
name='NormalPrice'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.NormalPrice}
type='number'
// min='1.5'
placeholder='n元/刀'
/>
<Form.Input
label='稳定渠道价格'
name='StablePrice'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.StablePrice}
type='number'
// min='1.5'
placeholder='n元/刀'
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('stable').then();
}}>保存通道设置</Form.Button>
<Divider/>
<Header as='h3'>
额度设置
</Header>
@ -293,7 +333,7 @@ const OperationSetting = () => {
<Form.Button onClick={() => {
submitConfig('quota').then();
}}>保存额度设置</Form.Button>
<Divider />
<Divider/>
<Header as='h3'>
倍率设置
</Header>
@ -302,7 +342,7 @@ const OperationSetting = () => {
label='模型倍率'
name='ModelRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.ModelRatio}
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
@ -313,7 +353,7 @@ const OperationSetting = () => {
label='分组倍率'
name='GroupRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.GroupRatio}
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
@ -325,7 +365,8 @@ const OperationSetting = () => {
</Form>
</Grid.Column>
</Grid>
);
)
;
};
export default OperationSetting;

View File

@ -1,9 +1,9 @@
import React, { useContext, useEffect, useState } from 'react';
import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react';
import { Link, useNavigate } from 'react-router-dom';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import React, {useContext, useEffect, useState} from 'react';
import {Button, Checkbox, Divider, Form, Header, Image, Input, Message, Modal} from 'semantic-ui-react';
import {Link, useNavigate} from 'react-router-dom';
import {API, copy, showError, showInfo, showNotice, showSuccess} from '../helpers';
import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
import {UserContext} from '../context/User';
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@ -15,6 +15,10 @@ const PersonalSetting = () => {
email: '',
self_account_deletion_confirmation: ''
});
const [stableMode, setStableMode] = useState({
stableMode: false,
maxPrice: 7,
});
const [status, setStatus] = useState({});
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
const [showEmailBindModal, setShowEmailBindModal] = useState(false);
@ -26,6 +30,10 @@ const PersonalSetting = () => {
const [disableButton, setDisableButton] = useState(false);
const [countdown, setCountdown] = useState(30);
// setStableMode(userState.user.stableMode, userState.user.maxPrice);
console.log(userState.user)
useEffect(() => {
let status = localStorage.getItem('status');
if (status) {
@ -36,6 +44,9 @@ const PersonalSetting = () => {
setTurnstileSiteKey(status.turnstile_site_key);
}
}
// if (userState.user !== undefined) {
// setStableMode(userState.user.stable_mode, userState.user.max_price);
// }
}, []);
useEffect(() => {
@ -51,13 +62,27 @@ const PersonalSetting = () => {
return () => clearInterval(countdownInterval); // Clean up on unmount
}, [disableButton, countdown]);
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
useEffect(() => {
if (userState.user !== undefined) {
setStableMode({
stableMode: userState.user.stable_mode,
maxPrice: userState.user.max_price
})
// if (stableMode.localMaxPrice !== userState.user.max_price) {
// setStableMode({
// localMaxPrice: userState.user.max_price
// })
// }
}
}, [userState]);
const handleInputChange = (e, {name, value}) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
const generateAccessToken = async () => {
const res = await API.get('/api/user/token');
const { success, message, data } = res.data;
const {success, message, data} = res.data;
if (success) {
await copy(data);
showSuccess(`令牌已重置并已复制到剪贴板:${data}`);
@ -68,7 +93,7 @@ const PersonalSetting = () => {
const getAffLink = async () => {
const res = await API.get('/api/user/aff');
const { success, message, data } = res.data;
const {success, message, data} = res.data;
if (success) {
let link = `${window.location.origin}/register?aff=${data}`;
await copy(link);
@ -85,12 +110,12 @@ const PersonalSetting = () => {
}
const res = await API.delete('/api/user/self');
const { success, message } = res.data;
const {success, message} = res.data;
if (success) {
showSuccess('账户已删除!');
await API.get('/api/user/logout');
userDispatch({ type: 'logout' });
userDispatch({type: 'logout'});
localStorage.removeItem('user');
navigate('/login');
} else {
@ -103,7 +128,7 @@ const PersonalSetting = () => {
const res = await API.get(
`/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}`
);
const { success, message } = res.data;
const {success, message} = res.data;
if (success) {
showSuccess('微信账户绑定成功!');
setShowWeChatBindModal(false);
@ -129,7 +154,7 @@ const PersonalSetting = () => {
const res = await API.get(
`/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`
);
const { success, message } = res.data;
const {success, message} = res.data;
if (success) {
showSuccess('验证码发送成功,请检查邮箱!');
} else {
@ -144,7 +169,7 @@ const PersonalSetting = () => {
const res = await API.get(
`/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}`
);
const { success, message } = res.data;
const {success, message} = res.data;
if (success) {
showSuccess('邮箱账户绑定成功!');
setShowEmailBindModal(false);
@ -154,8 +179,10 @@ const PersonalSetting = () => {
setLoading(false);
};
// const setStableMod = ;
return (
<div style={{ lineHeight: '40px' }}>
<div style={{lineHeight: '40px'}}>
<Header as='h3'>通用设置</Header>
<Message>
注意此处生成的令牌用于系统管理而非用于请求 OpenAI 相关的服务请知悉
@ -168,7 +195,67 @@ const PersonalSetting = () => {
<Button onClick={() => {
setShowAccountDeleteModal(true);
}}>删除个人账户</Button>
<Divider />
<Divider/>
<Header as='h3'>GPT-4消费设置</Header>
<Form>
<Form.Field>
<Checkbox label="启用稳定模式当低价渠道宕机时自动选择已开启的渠道以保证稳定性仅影响GPT-4"
checked={stableMode.stableMode}
onChange={
(e, data) => {
setStableMode({
...stableMode,
stableMode: data.checked
})
}
}
></Checkbox>
</Form.Field>
<Form.Field
control={Input}
label='最高接受价格n元/刀)'
placeholder='7'
type={'number'}
value={stableMode.maxPrice}
onChange={
(e, data) => {
setStableMode({
...stableMode,
maxPrice: data.value
})
}
}
>
{/*<label></label>*/}
{/*<input placeholder='7' value= {stableMode.maxPrice}/>*/}
</Form.Field>
<Button type='submit' onClick={
async (e, data) => {
if (stableMode.localMaxPrice === '') return;
// console.log(data)
// post to /api/user/set_stable_mode
const res = await API.post(`/api/user/set_stable_mode`, stableMode)
const {success, message} = res.data;
if (success) {
// userDispatch({type: 'stable_mode', payload: stableMode})
userState.user.stable_mode = stableMode.stableMode
userState.user.max_price = stableMode.maxPrice
localStorage.setItem('user', JSON.stringify(userState.user));
showSuccess('设置成功!');
} else {
showError(message);
}
}
}>保存消费设置</Button>
</Form>
{/*<Checkbox label="启用稳定模式(当低价渠道宕机时,自动选择已开启的渠道,以保证稳定性)" onChange={*/}
{/* (e, data) => {*/}
{/* // if (inputs.email_verification_code === '') return;*/}
{/* console.log(data)*/}
{/* }*/}
{/*}></Checkbox>*/}
{/*<Input label="最高接受价格n元/刀)" type="integer"></Input>*/}
<Divider/>
<Header as='h3'>账号绑定</Header>
{
status.wechat_login && (
@ -189,8 +276,8 @@ const PersonalSetting = () => {
>
<Modal.Content>
<Modal.Description>
<Image src={status.wechat_qrcode} fluid />
<div style={{ textAlign: 'center' }}>
<Image src={status.wechat_qrcode} fluid/>
<div style={{textAlign: 'center'}}>
<p>
微信扫码关注公众号输入验证码获取验证码三分钟内有效
</p>
@ -227,7 +314,7 @@ const PersonalSetting = () => {
onOpen={() => setShowEmailBindModal(true)}
open={showEmailBindModal}
size={'tiny'}
style={{ maxWidth: '450px' }}
style={{maxWidth: '450px'}}
>
<Modal.Header>绑定邮箱地址</Modal.Header>
<Modal.Content>
@ -280,7 +367,7 @@ const PersonalSetting = () => {
onOpen={() => setShowAccountDeleteModal(true)}
open={showAccountDeleteModal}
size={'tiny'}
style={{ maxWidth: '450px' }}
style={{maxWidth: '450px'}}
>
<Modal.Header>确认删除自己的帐户</Modal.Header>
<Modal.Content>

View File

@ -235,6 +235,8 @@ const TokensTable = () => {
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
<Table.Cell>
<div>
<Popup
trigger={
<Button
size={'small'}
positive
@ -250,6 +252,10 @@ const TokensTable = () => {
>
复制
</Button>
}
on={'hover'}
content={"sk-" + token.key}
/>
<Popup
trigger={
<Button size='small' negative>

View File

@ -1,8 +1,8 @@
import React, { useContext, useEffect, useState } from 'react';
import { Card, Grid, Header, Segment } from 'semantic-ui-react';
import { API, showError, showNotice, timestamp2string } from '../../helpers';
import { StatusContext } from '../../context/Status';
import { marked } from 'marked';
import React, {useContext, useEffect, useState} from 'react';
import {Card, Grid, Header, Segment} from 'semantic-ui-react';
import {API, showError, showNotice, timestamp2string} from '../../helpers';
import {StatusContext} from '../../context/Status';
import {marked} from 'marked';
const Home = () => {
const [statusState, statusDispatch] = useContext(StatusContext);
@ -11,7 +11,7 @@ const Home = () => {
const displayNotice = async () => {
const res = await API.get('/api/notice');
const { success, message, data } = res.data;
const {success, message, data} = res.data;
if (success) {
let oldNotice = localStorage.getItem('notice');
if (data !== oldNotice && data !== '') {
@ -26,7 +26,7 @@ const Home = () => {
const displayHomePageContent = async () => {
setHomePageContent(localStorage.getItem('home_page_content') || '');
const res = await API.get('/api/home_page_content');
const { success, message, data } = res.data;
const {success, message, data} = res.data;
if (success) {
let content = data;
if (!data.startsWith('https://')) {
@ -53,28 +53,20 @@ const Home = () => {
return (
<>
{
homePageContentLoaded && homePageContent === '' ? <>
// homePageContentLoaded && homePageContent === '' ?
<>
<Segment>
<Header as='h3'>系统状况</Header>
<Header as='h3'>当前状态</Header>
<Grid columns={2} stackable>
<Grid.Column>
<Card fluid>
<Card.Content>
<Card.Header>系统信息</Card.Header>
<Card.Meta>系统信息总览</Card.Meta>
<Card.Header>GPT-3.5</Card.Header>
<Card.Meta>信息总览</Card.Meta>
<Card.Description>
<p>名称{statusState?.status?.system_name}</p>
<p>版本{statusState?.status?.version}</p>
<p>
源码
<a
href='https://github.com/songquanpeng/one-api'
target='_blank'
>
https://github.com/songquanpeng/one-api
</a>
</p>
<p>启动时间{getStartTimeString()}</p>
<p>通道官方通道</p>
<p>状态存活</p>
<p>价格{statusState?.status?.base_price}R&nbsp;/&nbsp;</p>
</Card.Description>
</Card.Content>
</Card>
@ -82,32 +74,26 @@ const Home = () => {
<Grid.Column>
<Card fluid>
<Card.Content>
<Card.Header>系统配置</Card.Header>
<Card.Meta>系统配置总览</Card.Meta>
<Card.Header>GPT-4</Card.Header>
<Card.Meta>信息总览</Card.Meta>
<Card.Description>
<p>通道官方通道低价通道</p>
<p>
邮箱验证
{statusState?.status?.email_verification === true
? '已启用'
: '未启用'}
状态
{statusState?.status?.stable_price===-1?
<span style={{color:'red'}}>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</span>
:
<span style={{color:'green'}}>&emsp;&emsp;</span>
}
{statusState?.status?.normal_price===-1?
<span style={{color:'red'}}>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</span>
:
<span style={{color:'green'}}>&emsp;&emsp;</span>
}
</p>
<p>
GitHub 身份验证
{statusState?.status?.github_oauth === true
? '已启用'
: '未启用'}
</p>
<p>
微信身份验证
{statusState?.status?.wechat_login === true
? '已启用'
: '未启用'}
</p>
<p>
Turnstile 用户校验
{statusState?.status?.turnstile_check === true
? '已启用'
: '未启用'}
价格{statusState?.status?.stable_price}R&nbsp;/&nbsp;刀|{statusState?.status?.normal_price}R&nbsp;/&nbsp;
</p>
</Card.Description>
</Card.Content>
@ -115,7 +101,6 @@ const Home = () => {
</Grid.Column>
</Grid>
</Segment>
</> : <>
{
homePageContent.startsWith('https://') ? <iframe
src={homePageContent}
@ -123,6 +108,10 @@ const Home = () => {
/> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: homePageContent }}></div>
}
</>
// :
// <>
// </>
}
</>

View File

@ -0,0 +1,11 @@
import React from 'react';
import { Header, Segment } from 'semantic-ui-react';
import MjLogsTable from '../../components/MjLogsTable';
const Midjourney = () => (
<>
<MjLogsTable />
</>
);
export default Midjourney;

View File

@ -1,10 +1,13 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
import { API, showError, showInfo, showSuccess } from '../../helpers';
import { renderQuota } from '../../helpers/render';
import React, {useEffect, useState} from 'react';
import {Button, Form, Grid, Header, Segment, Statistic} from 'semantic-ui-react';
import {API, showError, showInfo, showSuccess} from '../../helpers';
import {renderNumber, renderQuota} from '../../helpers/render';
const TopUp = () => {
const [redemptionCode, setRedemptionCode] = useState('');
const [topUpCode, setTopUpCode] = useState('');
const [topUpCount, setTopUpCount] = useState(10);
const [amount, setAmount] = useState(0);
const [topUpLink, setTopUpLink] = useState('');
const [userQuota, setUserQuota] = useState(0);
const [isSubmitting, setIsSubmitting] = useState(false);
@ -19,7 +22,7 @@ const TopUp = () => {
const res = await API.post('/api/user/topup', {
key: redemptionCode
});
const { success, message, data } = res.data;
const {success, message, data} = res.data;
if (success) {
showSuccess('充值成功!');
setUserQuota((quota) => {
@ -44,7 +47,51 @@ const TopUp = () => {
window.open(topUpLink, '_blank');
};
const getUserQuota = async ()=>{
const onlineTopUp = async (payment) => {
if (amount === 0) {
await getAmount();
}
try {
const res = await API.post('/api/user/pay', {
amount: parseInt(topUpCount),
top_up_code: topUpCode,
PaymentMethod: payment
});
if (res !== undefined) {
const {message, data} = res.data;
// showInfo(message);
if (message === 'success') {
let params = data
let url = res.data.url
let form = document.createElement('form')
form.action = url
form.method = 'POST'
form.target = '_blank'
for (let key in params) {
let input = document.createElement('input')
input.type = 'hidden'
input.name = key
input.value = params[key]
form.appendChild(input)
}
document.body.appendChild(form)
form.submit()
document.body.removeChild(form)
} else {
showError(message);
// setTopUpCount(parseInt(res.data.count));
setAmount(parseInt(data));
}
} else {
showError(res);
}
} catch (err) {
console.log(err);
} finally {
}
}
const getUserQuota = async () => {
let res = await API.get(`/api/user/self`);
const {success, message, data} = res.data;
if (success) {
@ -65,7 +112,41 @@ const TopUp = () => {
getUserQuota().then();
}, []);
const renderAmount = () => {
console.log(amount);
return amount + '元';
}
const getAmount = async (value) => {
if (value === undefined) {
value = topUpCount;
}
try {
const res = await API.post('/api/user/amount', {
amount: parseFloat(value),
top_up_code: topUpCode
});
if (res !== undefined) {
const {message, data} = res.data;
// showInfo(message);
if (message === 'success') {
setAmount(parseInt(data));
} else {
showError(message);
// setTopUpCount(parseInt(res.data.count));
setAmount(parseInt(data));
}
} else {
showError(res);
}
} catch (err) {
console.log(err);
} finally {
}
}
return (
<div>
<Segment>
<Header as='h3'>充值额度</Header>
<Grid columns={2} stackable>
@ -79,9 +160,9 @@ const TopUp = () => {
setRedemptionCode(e.target.value);
}}
/>
<Button color='green' onClick={openTopUpLink}>
获取兑换码
</Button>
{/*<Button color='green' onClick={openTopUpLink}>*/}
{/* 获取兑换码*/}
{/*</Button>*/}
<Button color='yellow' onClick={topUp} disabled={isSubmitting}>
{isSubmitting ? '兑换中...' : '兑换'}
</Button>
@ -97,6 +178,58 @@ const TopUp = () => {
</Grid.Column>
</Grid>
</Segment>
<Segment>
<Header as='h3'>在线充值最小10刀</Header>
<Grid columns={2} stackable>
<Grid.Column>
<Form>
<Form.Input
placeholder='充值金额最低10,最高400'
name='redemptionCount'
type={'number'}
value={topUpCount}
autoComplete={'off'}
onChange={async (e) => {
setTopUpCount(e.target.value);
await getAmount(e.target.value);
}}
/>
<Form.Input
placeholder='充值码,如果你没有充值码,可不填写'
name='redemptionCount'
value={topUpCode}
onChange={(e) => {
setTopUpCode(e.target.value);
}}
/>
<Button color='blue' onClick={
async () => {
onlineTopUp('zfb')
}
}>
支付宝最大400元
</Button>
<Button color='green' onClick={
async () => {
onlineTopUp('wx')
}
}>
微信最大600元
</Button>
</Form>
</Grid.Column>
<Grid.Column>
<Statistic.Group widths='one'>
<Statistic>
<Statistic.Value>{renderAmount()}</Statistic.Value>
<Statistic.Label>支付金额</Statistic.Label>
</Statistic>
</Statistic.Group>
</Grid.Column>
</Grid>
</Segment>
</div>
);
};