diff --git a/Midjourney.md b/Midjourney.md new file mode 100644 index 00000000..8141d448 --- /dev/null +++ b/Midjourney.md @@ -0,0 +1,254 @@ +# Midjourney Proxy API文档 + + +**简介**:Midjourney Proxy API文档 + + +**HOST**:https://api.nekoedu.com + + +**Version**:v2.3.5 + + +[TOC] + + + + + + +# 任务提交 + + +## 绘图变化 + + +**接口地址**:`/mj/submit/change` + + +**请求方式**:`POST` + + +**请求数据类型**:`application/json` + + +**响应数据类型**:`*/*` + + +**接口描述**: + + +**请求示例**: + + +```javascript +{ + "action": "UPSCALE", + "index": 1, + "notifyHook": "", + "state": "", + "taskId": "1320098173412546" +} +``` + + +**请求参数**: + + +| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | +| -------- | -------- | ----- | -------- | -------- | ------ | +|changeDTO|changeDTO|body|true|变化任务提交参数|变化任务提交参数| +|  action|UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL||true|string|| +|  index|序号(1~4), action为UPSCALE,VARIATION时必传||false|integer(int32)|| +|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string|| +|  state|自定义参数||false|string|| +|  taskId|任务ID||true|string|| + + +**响应状态**: + + +| 状态码 | 说明 | schema | +| -------- | -------- | ----- | +|200|OK|提交结果| +|201|Created|| +|401|Unauthorized|| +|403|Forbidden|| +|404|Not Found|| + + +**响应参数**: + + +| 参数名称 | 参数说明 | 类型 | schema | +| -------- | -------- | ----- |----- | +|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)| +|description|描述|string|| +|properties|扩展字段|object|| +|result|任务ID|string|| + + +**响应示例**: +```javascript +{ + "code": 1, + "description": "提交成功", + "properties": {}, + "result": 1320098173412546 +} +``` + +## 提交Imagine任务 + + +**接口地址**:`/mj/submit/imagine` + + +**请求方式**:`POST` + + +**请求数据类型**:`application/json` + + +**响应数据类型**:`*/*` + + +**接口描述**: + + +**请求示例**: + + +```javascript +{ + "base64": "", + "notifyHook": "", + "prompt": "Cat", + "state": "" +} +``` + + +**请求参数**: + + +| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | +| -------- | -------- | ----- | -------- | -------- | ------ | +|imagineDTO|imagineDTO|body|true|Imagine提交参数|Imagine提交参数| +|  base64|垫图base64||false|string|| +|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string|| +|  prompt|提示词||true|string|| +|  state|自定义参数||false|string|| + + +**响应状态**: + + +| 状态码 | 说明 | schema | +| -------- | -------- | ----- | +|200|OK|提交结果| +|201|Created|| +|401|Unauthorized|| +|403|Forbidden|| +|404|Not Found|| + + +**响应参数**: + + +| 参数名称 | 参数说明 | 类型 | schema | +| -------- | -------- | ----- |----- | +|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)| +|description|描述|string|| +|properties|扩展字段|object|| +|result|任务ID|string|| + + +**响应示例**: +```javascript +{ + "code": 1, + "description": "提交成功", + "properties": {}, + "result": 1320098173412546 +} +``` + + +# 任务查询 + +## 指定ID获取任务 + + +**接口地址**:`/mj/task/{id}/fetch` + + +**请求方式**:`GET` + + +**请求数据类型**:`application/x-www-form-urlencoded` + + +**响应数据类型**:`*/*` + + +**接口描述**: + + +**请求参数**: + + +| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | +| -------- | -------- | ----- | -------- | -------- | ------ | +|id|任务ID|path|false|string|| + + +**响应状态**: + + +| 状态码 | 说明 | schema | +| -------- | -------- | ----- | +|200|OK|任务| +|401|Unauthorized|| +|403|Forbidden|| +|404|Not Found|| + + +**响应参数**: + + +| 参数名称 | 参数说明 | 类型 | schema | +| -------- | -------- | ----- |----- | +|action|可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND|string|| +|description|任务描述|string|| +|failReason|失败原因|string|| +|finishTime|结束时间|integer(int64)|integer(int64)| +|id|任务ID|string|| +|imageUrl|图片url|string|| +|progress|任务进度|string|| +|prompt|提示词|string|| +|promptEn|提示词-英文|string|| +|startTime|开始执行时间|integer(int64)|integer(int64)| +|state|自定义参数|string|| +|status|任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS|string|| +|submitTime|提交时间|integer(int64)|integer(int64)| + + +**响应示例**: +```javascript +{ + "action": "", + "description": "", + "failReason": "", + "finishTime": 0, + "id": "", + "imageUrl": "", + "progress": "", + "prompt": "", + "promptEn": "", + "startTime": 0, + "state": "", + "status": "", + "submitTime": 0 +} +``` \ No newline at end of file diff --git a/common/constants.go b/common/constants.go index 81f98163..95f5b44e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -79,6 +79,10 @@ var RequestInterval = time.Duration(requestInterval) * time.Second var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var NormalPrice = 1.5 +var StablePrice = 6.0 +var BasePrice = 1.5 + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/controller/log.go b/controller/log.go index ba043349..1ff10519 100644 --- a/controller/log.go +++ b/controller/log.go @@ -94,6 +94,23 @@ func SearchUserLogs(c *gin.Context) { }) } +func GetLogByKey(c *gin.Context) { + key := c.Query("key") + logs, err := model.GetLogByKey(key) + if err != nil { + c.JSON(200, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + func GetLogsStat(c *gin.Context) { logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) diff --git a/controller/midjourney.go b/controller/midjourney.go new file mode 100644 index 00000000..fa2c4d9b --- /dev/null +++ b/controller/midjourney.go @@ -0,0 +1,133 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "log" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "time" +) + +func UpdateMidjourneyTask() { + imageModel := "midjourney" + for { + time.Sleep(time.Duration(15) * time.Second) + tasks := model.GetAllUnFinishTasks() + if len(tasks) != 0 { + //log.Printf("UpdateMidjourneyTask: %v", time.Now()) + ids := make([]string, 0) + for _, task := range tasks { + ids = append(ids, task.MjId) + } + requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition" + requestBody := map[string]interface{}{ + "ids": ids, + } + jsonStr, err := json.Marshal(requestBody) + if err != nil { + log.Printf("UpdateMidjourneyTask: %v", err) + } + req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr)) + if err != nil { + log.Printf("UpdateMidjourneyTask: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu") + resp, err := httpClient.Do(req) + if err != nil { + log.Printf("UpdateMidjourneyTask: %v", err) + } + defer resp.Body.Close() + var response []Midjourney + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + log.Printf("UpdateMidjourneyTask: %v", err) + } + for _, responseItem := range response { + var midjourneyTask *model.Midjourney + for _, mj := range tasks { + mj.MjId = responseItem.MjId + midjourneyTask = model.GetMjByuId(mj.Id) + } + if midjourneyTask != nil { + midjourneyTask.Code = 1 + midjourneyTask.Progress = responseItem.Progress + midjourneyTask.PromptEn = responseItem.PromptEn + midjourneyTask.State = responseItem.State + midjourneyTask.SubmitTime = responseItem.SubmitTime + midjourneyTask.StartTime = responseItem.StartTime + midjourneyTask.FinishTime = responseItem.FinishTime + midjourneyTask.ImageUrl = responseItem.ImageUrl + midjourneyTask.Status = responseItem.Status + midjourneyTask.FailReason = responseItem.FailReason + if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" { + log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason) + midjourneyTask.Progress = "100%" + err = model.CacheUpdateUserQuota(midjourneyTask.UserId) + if err != nil { + log.Println("error update user quota cache: " + err.Error()) + } else { + modelRatio := common.GetModelRatio(imageModel) + groupRatio := common.GetGroupRatio("default") + ratio := modelRatio * groupRatio + quota := int(ratio * 1 * 1000) + if quota != 0 { + err := model.IncreaseUserQuota(midjourneyTask.UserId, quota) + if err != nil { + log.Println("fail to increase user quota") + } + logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota)) + model.RecordLog(midjourneyTask.UserId, 1, logContent) + } + } + } + + err = midjourneyTask.Update() + if err != nil { + log.Printf("UpdateMidjourneyTaskFail: %v", err) + } + log.Printf("UpdateMidjourneyTask: %v", midjourneyTask) + } + } + } + } +} + +func GetAllMidjourney(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage) + if logs == nil { + logs = make([]*model.Midjourney, 0) + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func GetUserMidjourney(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + userId := c.GetInt("id") + log.Printf("userId = %d \n", userId) + logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage) + if logs == nil { + logs = make([]*model.Midjourney, 0) + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} diff --git a/controller/misc.go b/controller/misc.go index 958a3716..5cbbb9d2 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -31,6 +31,9 @@ func GetStatus(c *gin.Context) { "chat_link": common.ChatLink, "quota_per_unit": common.QuotaPerUnit, "display_in_currency": common.DisplayInCurrencyEnabled, + "normal_price": common.NormalPrice, + "stable_price": common.StablePrice, + "base_price": common.BasePrice, }, }) return @@ -58,6 +61,17 @@ func GetAbout(c *gin.Context) { return } +func GetMidjourney(c *gin.Context) { + common.OptionMapRWMutex.RLock() + defer common.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": common.OptionMap["Midjourney"], + }) + return +} + func GetHomePageContent(c *gin.Context) { common.OptionMapRWMutex.RLock() defer common.OptionMapRWMutex.RUnlock() diff --git a/controller/relay-image.go b/controller/relay-image.go index de623288..e66a1800 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -137,7 +137,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay-mj.go b/controller/relay-mj.go new file mode 100644 index 00000000..8989f036 --- /dev/null +++ b/controller/relay-mj.go @@ -0,0 +1,388 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "one-api/common" + "one-api/model" + "strings" + + "github.com/gin-gonic/gin" +) + +type Midjourney struct { + MjId string `json:"id"` + Action string `json:"action"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` +} + +func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { + var midjRequest Midjourney + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "bind_request_body_failed", + Properties: nil, + Result: "", + } + } + midjourneyTask := model.GetByMJId(midjRequest.MjId) + if midjourneyTask == nil { + return &MidjourneyResponse{ + Code: 4, + Description: "midjourney_task_not_found", + Properties: nil, + Result: "", + } + } + midjourneyTask.Progress = midjRequest.Progress + midjourneyTask.PromptEn = midjRequest.PromptEn + midjourneyTask.State = midjRequest.State + midjourneyTask.SubmitTime = midjRequest.SubmitTime + midjourneyTask.StartTime = midjRequest.StartTime + midjourneyTask.FinishTime = midjRequest.FinishTime + midjourneyTask.ImageUrl = midjRequest.ImageUrl + midjourneyTask.Status = midjRequest.Status + midjourneyTask.FailReason = midjRequest.FailReason + err = midjourneyTask.Update() + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "update_midjourney_task_failed", + } + } + + return nil +} + +func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { + taskId := c.Param("id") + originTask := model.GetByMJId(taskId) + if originTask == nil { + return &MidjourneyResponse{ + Code: 4, + Description: "task_no_found", + } + } + var midjourneyTask Midjourney + midjourneyTask.MjId = originTask.MjId + midjourneyTask.Progress = originTask.Progress + midjourneyTask.PromptEn = originTask.PromptEn + midjourneyTask.State = originTask.State + midjourneyTask.SubmitTime = originTask.SubmitTime + midjourneyTask.StartTime = originTask.StartTime + midjourneyTask.FinishTime = originTask.FinishTime + midjourneyTask.ImageUrl = originTask.ImageUrl + midjourneyTask.Status = originTask.Status + midjourneyTask.FailReason = originTask.FailReason + midjourneyTask.Action = originTask.Action + midjourneyTask.Description = originTask.Description + midjourneyTask.Prompt = originTask.Prompt + jsonMap, err := json.Marshal(midjourneyTask) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "unmarshal_response_body_failed", + } + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap)) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "copy_response_body_failed", + } + } + return nil +} + +func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { + imageModel := "midjourney" + + tokenId := c.GetInt("token_id") + channelType := c.GetInt("channel") + userId := c.GetInt("id") + consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + + var midjRequest MidjourneyRequest + if consumeQuota { + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "bind_request_body_failed", + } + } + } + if relayMode == RelayModeMidjourneyImagine { + if midjRequest.Prompt == "" { + return &MidjourneyResponse{ + Code: 4, + Description: "prompt_is_required", + } + } + midjRequest.Action = "IMAGINE" + } else if midjRequest.TaskId != "" { + originTask := model.GetByMJId(midjRequest.TaskId) + if originTask == nil { + return &MidjourneyResponse{ + Code: 4, + Description: "task_no_found", + } + } else if originTask.Action == "UPSCALE" { + //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest). + return &MidjourneyResponse{ + Code: 4, + Description: "upscale_task_can_not_be_change", + } + } else if originTask.Status != "SUCCESS" { + return &MidjourneyResponse{ + Code: 4, + Description: "task_status_is_not_success", + } + + } + midjRequest.Prompt = originTask.Prompt + } else if relayMode == RelayModeMidjourneyChange { + if midjRequest.TaskId == "" { + return &MidjourneyResponse{ + Code: 4, + Description: "taskId_is_required", + } + } else if midjRequest.Action == "" { + return &MidjourneyResponse{ + Code: 4, + Description: "action_is_required", + } + } else if midjRequest.Index == 0 { + return &MidjourneyResponse{ + Code: 4, + Description: "index_can_only_be_1_2_3_4", + } + } + } + + // map model name + modelMapping := c.GetString("model_mapping") + isModelMapped := false + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return &MidjourneyResponse{ + Code: 4, + Description: "unmarshal_model_mapping_failed", + } + } + if modelMap[imageModel] != "" { + imageModel = modelMap[imageModel] + isModelMapped = true + } + } + + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + + //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" + + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + var requestBody io.Reader + if isModelMapped { + jsonStr, err := json.Marshal(midjRequest) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "marshal_text_request_failed", + } + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + + modelRatio := common.GetModelRatio(imageModel) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + userQuota, err := model.CacheGetUserQuota(userId) + + sizeRatio := 1.0 + + quota := int(ratio * sizeRatio * 1000) + + if consumeQuota && userQuota-quota < 0 { + return &MidjourneyResponse{ + Code: 4, + Description: "quota_not_enough", + } + } + + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "create_request_failed", + } + } + //req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) + // print request header + log.Printf("request header: %s", req.Header) + log.Printf("request body: %s", midjRequest.Prompt) + + resp, err := httpClient.Do(req) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "do_request_failed", + } + } + + err = req.Body.Close() + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "close_request_body_failed", + } + } + err = c.Request.Body.Close() + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "close_request_body_failed", + } + } + var midjResponse MidjourneyResponse + + defer func() { + if consumeQuota { + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + } + }() + + //if consumeQuota { + // + //} + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "read_response_body_failed", + } + } + err = resp.Body.Close() + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "close_response_body_failed", + } + } + + err = json.Unmarshal(responseBody, &midjResponse) + log.Printf("responseBody: %s", string(responseBody)) + log.Printf("midjResponse: %v", midjResponse) + if resp.StatusCode != 200 { + return &MidjourneyResponse{ + Code: 4, + Description: "fail_to_fetch_midjourney", + } + } + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "unmarshal_response_body_failed", + } + } + if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 { + consumeQuota = false + } + + midjourneyTask := &model.Midjourney{ + UserId: userId, + Code: midjResponse.Code, + Action: midjRequest.Action, + MjId: midjResponse.Result, + Prompt: midjRequest.Prompt, + PromptEn: "", + Description: midjResponse.Description, + State: "", + SubmitTime: 0, + StartTime: 0, + FinishTime: 0, + ImageUrl: "", + Status: "", + Progress: "0%", + FailReason: "", + } + if midjResponse.Code == 4 { + midjourneyTask.FailReason = midjResponse.Description + } + err = midjourneyTask.Insert() + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "insert_midjourney_task_failed", + } + } + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "copy_response_body_failed", + } + } + err = resp.Body.Close() + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "close_response_body_failed", + } + } + return nil +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 52e10f2b..621ae336 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -95,6 +95,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case common.ChannelTypeZhipu: apiType = APITypeZhipu } + isStable := c.GetBool("stable") + + //if common.NormalPrice == -1 && strings.HasPrefix(textRequest.Model, "gpt-4") { + // nowUser, err := model.GetUserById(userId, false) + // if err != nil { + // return errorWrapper(err, "get_user_info_failed", http.StatusInternalServerError) + // } + // if nowUser.StableMode { + // group = "svip" + // isStable = true + // ////stableRatio = (common.StablePrice / common.BasePrice) * modelRatio + // //userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64) + // //if userMaxPrice < common.StablePrice { + // // return errorWrapper(errors.New("当前低价通道不可用,稳定渠道价格为"+strconv.FormatFloat(common.StablePrice, 'f', -1, 64)+"R/刀"), "当前低价通道不可用", http.StatusInternalServerError) + // //} + // // + // ////ratio = stableRatio * groupRatio + // //channel, err := model.CacheGetRandomSatisfiedChannel("svip", textRequest.Model) + // //if err != nil { + // // message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", "svip", textRequest.Model) + // // return errorWrapper(errors.New(message), "no_available_channel", http.StatusInternalServerError) + // //} + // //channelType = channel.Type + // } else { + // return errorWrapper(errors.New("当前低价通道不可用,请稍后再试,或者在后台开启稳定模式"), "当前低价通道不可用", http.StatusInternalServerError) + // } + //} baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { @@ -168,11 +195,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + textRequest.MaxTokens } + //stableRatio := common.GetStableRatio(textRequest.Model) modelRatio := common.GetModelRatio(textRequest.Model) + stableRatio := modelRatio groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) userQuota, err := model.CacheGetUserQuota(userId) + if isStable { + stableRatio = (common.StablePrice / common.BasePrice) * modelRatio + ratio = stableRatio * groupRatio + } if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -301,6 +334,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // we cannot just return, because we may have to return the pre-consumed quota quota = 0 } + //if strings.HasPrefix(textRequest.Model, "gpt-4") { + // if quota < 5000 && quota != 0 { + // quota = 5000 + // } + //} quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { @@ -312,8 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } if quota != 0 { tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + var logContent string + if isStable { + logContent = fmt.Sprintf("(稳定模式)模型倍率 %.2f,分组倍率 %.2f", stableRatio, groupRatio) + } else { + logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + } + model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay.go b/controller/relay.go index 9cfa5c4f..96abf6ee 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "log" "net/http" "one-api/common" "strconv" @@ -24,6 +25,10 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits + RelayModeMidjourneyImagine + RelayModeMidjourneyChange + RelayModeMidjourneyNotify + RelayModeMidjourneyTaskFetch ) // https://platform.openai.com/docs/api-reference/chat @@ -128,6 +133,23 @@ type CompletionsStreamResponse struct { } `json:"choices"` } +type MidjourneyRequest struct { + Prompt string `json:"prompt"` + NotifyHook string `json:"notifyHook"` + Action string `json:"action"` + Index int `json:"index"` + State string `json:"state"` + TaskId string `json:"taskId"` + Base64Array []string `json:"base64Array"` +} + +type MidjourneyResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Properties interface{} `json:"properties"` + Result string `json:"result"` +} + func Relay(c *gin.Context) { relayMode := RelayModeUnknown if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { @@ -179,6 +201,54 @@ func Relay(c *gin.Context) { } } +func RelayMidjourney(c *gin.Context) { + relayMode := RelayModeUnknown + if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { + relayMode = RelayModeMidjourneyImagine + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") { + relayMode = RelayModeMidjourneyNotify + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") { + relayMode = RelayModeMidjourneyChange + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") { + relayMode = RelayModeMidjourneyTaskFetch + } + var err *MidjourneyResponse + switch relayMode { + case RelayModeMidjourneyNotify: + err = relayMidjourneyNotify(c) + case RelayModeMidjourneyTaskFetch: + err = relayMidjourneyTask(c, relayMode) + default: + err = relayMidjourneySubmit(c, relayMode) + } + //err = relayMidjourneySubmit(c, relayMode) + log.Println(err) + if err != nil { + retryTimesStr := c.Query("retry") + retryTimes, _ := strconv.Atoi(retryTimesStr) + if retryTimesStr == "" { + retryTimes = common.RetryTimes + } + if retryTimes > 0 { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) + } else { + if err.Code == 30 { + err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + } + c.JSON(400, gin.H{ + "error": err.Result, + }) + } + channelId := c.GetInt("channel_id") + common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result)) + //if shouldDisableChannel(&err.OpenAIError) { + // channelId := c.GetInt("channel_id") + // channelName := c.GetString("channel_name") + // disableChannel(channelId, channelName, err.Result) + //} + } +} + func RelayNotImplemented(c *gin.Context) { err := OpenAIError{ Message: "API not implemented", diff --git a/controller/topup.go b/controller/topup.go new file mode 100644 index 00000000..9f6b3a82 --- /dev/null +++ b/controller/topup.go @@ -0,0 +1,173 @@ +package controller + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/samber/lo" + epay "github.com/star-horizon/go-epay" + "log" + "net/url" + "one-api/common" + "one-api/model" + "strconv" + "time" +) + +type EpayRequest struct { + Amount int `json:"amount"` + PaymentMethod string `json:"payment_method"` + TopUpCode string `json:"top_up_code"` +} + +type AmountRequest struct { + Amount int `json:"amount"` + TopUpCode string `json:"top_up_code"` +} + +var client, _ = epay.NewClientWithUrl(&epay.Config{ + PartnerID: "1096", + Key: "n08V9LpE8JffA3NPP893689u8p39NV9J", +}, "https://api.lempay.org") + +func GetAmount(id int, count float64, topUpCode string) float64 { + amount := count * 1.5 + if topUpCode != "" { + if topUpCode == "nekoapi" { + if id == 89 { + amount = count * 1 + } else if id == 98 || id == 105 || id == 107 { + amount = count * 1.2 + } else if id == 1 { + amount = count * 1 + } + } + } + return amount +} + +func RequestEpay(c *gin.Context) { + var req EpayRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(200, gin.H{"message": err.Error(), "data": 10}) + return + } + id := c.GetInt("id") + amount := GetAmount(id, float64(req.Amount), req.TopUpCode) + if id != 1 { + if req.Amount < 10 { + c.JSON(200, gin.H{"message": "最小充值10元", "data": amount, "count": 10}) + return + } + } + if req.PaymentMethod == "zfb" { + if amount > 400 { + c.JSON(200, gin.H{"message": "支付宝最大充值400元", "data": amount, "count": 400}) + return + } + req.PaymentMethod = "alipay" + } + if req.PaymentMethod == "wx" { + if amount > 600 { + c.JSON(200, gin.H{"message": "微信最大充值600元", "data": amount, "count": 600}) + return + } + req.PaymentMethod = "wxpay" + } + + returnUrl, _ := url.Parse("https://nekoapi.com/log") + notifyUrl, _ := url.Parse("https://nekoapi.com/api/user/epay/notify") + tradeNo := strconv.FormatInt(time.Now().Unix(), 10) + uri, params, err := client.Purchase(&epay.PurchaseArgs{ + Type: epay.PurchaseType(req.PaymentMethod), + ServiceTradeNo: "A" + tradeNo, + Name: "B" + tradeNo, + Money: strconv.FormatFloat(amount*0.99, 'f', 2, 64), + Device: epay.PC, + NotifyUrl: notifyUrl, + ReturnUrl: returnUrl, + }) + if err != nil { + c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) + return + } + topUp := &model.TopUp{ + UserId: id, + Amount: req.Amount, + Money: int(amount), + TradeNo: "A" + tradeNo, + CreateTime: time.Now().Unix(), + Status: "pending", + } + err = topUp.Insert() + if err != nil { + c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) + return + } + c.JSON(200, gin.H{"message": "success", "data": params, "url": uri}) +} + +func EpayNotify(c *gin.Context) { + params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string { + r[t] = c.Request.URL.Query().Get(t) + return r + }, map[string]string{}) + verifyInfo, err := client.Verify(params) + if err == nil && verifyInfo.VerifyStatus { + _, err := c.Writer.Write([]byte("success")) + if err != nil { + log.Println("易支付回调写入失败") + } + } else { + _, err := c.Writer.Write([]byte("fail")) + if err != nil { + log.Println("易支付回调写入失败") + } + } + + if verifyInfo.TradeStatus == epay.StatusTradeSuccess { + log.Println(verifyInfo) + topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo) + if topUp.Status == "pending" { + topUp.Status = "success" + err := topUp.Update() + if err != nil { + log.Printf("易支付回调更新订单失败: %v", topUp) + return + } + //user, _ := model.GetUserById(topUp.UserId, false) + //user.Quota += topUp.Amount * 500000 + err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000) + if err != nil { + log.Printf("易支付回调更新用户失败: %v", topUp) + return + } + log.Printf("易支付回调更新用户成功 %v", topUp) + model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v", common.LogQuota(topUp.Amount*500000))) + } + } else { + log.Printf("易支付异常回调: %v", verifyInfo) + } +} + +func RequestAmount(c *gin.Context) { + var req AmountRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(200, gin.H{"message": err.Error(), "data": 10}) + return + } + id := c.GetInt("id") + if id != 1 { + if req.Amount < 10 { + c.JSON(200, gin.H{"message": "最小充值10刀", "data": GetAmount(id, 10, req.TopUpCode), "count": 10}) + return + } + if req.Amount > 400 { + c.JSON(200, gin.H{"message": "最大充值400刀", "data": GetAmount(id, 400, req.TopUpCode), "count": 400}) + return + } + } + + c.JSON(200, gin.H{"message": "success", "data": GetAmount(id, float64(req.Amount), req.TopUpCode)}) +} diff --git a/controller/user.go b/controller/user.go index 8fd10b82..a072b54e 100644 --- a/controller/user.go +++ b/controller/user.go @@ -3,6 +3,7 @@ package controller import ( "encoding/json" "fmt" + "log" "net/http" "one-api/common" "one-api/model" @@ -79,6 +80,8 @@ func setupLogin(user *model.User, c *gin.Context) { DisplayName: user.DisplayName, Role: user.Role, Status: user.Status, + StableMode: user.StableMode, + MaxPrice: user.MaxPrice, } c.JSON(http.StatusOK, gin.H{ "message": "", @@ -158,6 +161,8 @@ func Register(c *gin.Context) { Password: user.Password, DisplayName: user.Username, InviterId: inviterId, + StableMode: user.StableMode, + MaxPrice: user.MaxPrice, } if common.EmailVerificationEnabled { cleanUser.Email = user.Email @@ -420,6 +425,8 @@ func UpdateSelf(c *gin.Context) { Username: user.Username, Password: user.Password, DisplayName: user.DisplayName, + StableMode: user.StableMode, + MaxPrice: user.MaxPrice, } if user.Password == "$I_LOVE_U" { user.Password = "" // rollback to what it should be @@ -741,3 +748,52 @@ func TopUp(c *gin.Context) { }) return } + +type StableModeRequest struct { + StableMode bool `json:"stableMode"` + MaxPrice string `json:"maxPrice"` +} + +func SetTableMode(c *gin.Context) { + req := &StableModeRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + log.Println(req) + id := c.GetInt("id") + user := model.User{ + Id: id, + } + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.StableMode = req.StableMode + if !req.StableMode { + req.MaxPrice = "0" + } + user.MaxPrice = req.MaxPrice + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": "设置成功", + }) + return +} diff --git a/go.mod b/go.mod index 2e0cf017..ca4a8bf1 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/pkoukk/tiktoken-go v0.1.1 + github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 golang.org/x/crypto v0.9.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/sqlite v1.4.3 @@ -42,12 +43,15 @@ require ( github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/samber/lo v1.37.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect + golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum index 7287206a..ba518d99 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,8 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -117,6 +119,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw= +github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA= +github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI= +github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -144,6 +150,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= @@ -163,8 +171,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= diff --git a/main.go b/main.go index d6d0c75b..c7182299 100644 --- a/main.go +++ b/main.go @@ -74,6 +74,7 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + go controller.UpdateMidjourneyTask() // Initialize HTTP server server := gin.Default() diff --git a/middleware/distributor.go b/middleware/distributor.go index 91c00e1a..ab296903 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "log" "net/http" "one-api/common" "one-api/model" @@ -21,6 +22,7 @@ func Distribute() func(c *gin.Context) { userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) var channel *model.Channel + var err error channelId, ok := c.Get("channelId") if ok { id, err := strconv.Atoi(channelId.(string)) @@ -56,19 +58,28 @@ func Distribute() func(c *gin.Context) { return } } else { + // Select a channel for the user var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的请求", - "type": "one_api_error", - }, - }) - c.Abort() - return + if strings.HasPrefix(c.Request.URL.Path, "/mj") { + if modelRequest.Model == "" { + modelRequest.Model = "midjourney" + } + } else { + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + log.Println(err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "无效的请求", + "type": "one_api_error", + }, + }) + c.Abort() + return + } } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if modelRequest.Model == "" { modelRequest.Model = "text-moderation-stable" @@ -84,21 +95,51 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "dall-e" } } + isStable := false channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + c.Set("stable", false) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) - if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - message = "数据库一致性已被破坏,请联系管理员" + if strings.HasPrefix(modelRequest.Model, "gpt-4") { + common.SysLog("GPT-4低价渠道宕机,正在尝试转换") + nowUser, err := model.GetUserById(userId, false) + if err == nil { + if nowUser.StableMode { + userGroup = "svip" + //stableRatio = (common.StablePrice / common.BasePrice) * modelRatio + userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64) + if userMaxPrice < common.StablePrice { + message = "当前低价通道不可用,稳定渠道价格为" + strconv.FormatFloat(common.StablePrice, 'f', -1, 64) + "R/刀" + } else { + //common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道", nowUser.Username)) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + if err != nil { + message = "稳定渠道已经宕机,请联系管理员" + } + isStable = true + common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道 %v", nowUser.Username, channel)) + c.Set("stable", true) + } + + } else { + message = "当前低价通道不可用,请稍后再试,或者在后台开启稳定渠道模式" + } + } + } + //if channel == nil { + // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + // message = "数据库一致性已被破坏,请联系管理员" + //} + if !isStable { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": message, + "type": "one_api_error", + }, + }) + c.Abort() + return } - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "message": message, - "type": "one_api_error", - }, - }) - c.Abort() - return } } c.Set("channel", channel.Type) diff --git a/model/log.go b/model/log.go index b0d6409a..e79c881f 100644 --- a/model/log.go +++ b/model/log.go @@ -3,6 +3,7 @@ package model import ( "gorm.io/gorm" "one-api/common" + "strings" ) type Log struct { @@ -17,6 +18,7 @@ type Log struct { Quota int `json:"quota" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + TokenId int `json:"token_id" gorm:"default:0;index"` } const ( @@ -27,6 +29,11 @@ const ( LogTypeSystem ) +func GetLogByKey(key string) (logs []*Log, err error) { + err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error + return logs, err +} + func RecordLog(userId int, logType int, content string) { if logType == LogTypeConsume && !common.LogConsumeEnabled { return @@ -44,7 +51,7 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) { if !common.LogConsumeEnabled { return } @@ -59,6 +66,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN TokenName: tokenName, ModelName: modelName, Quota: quota, + TokenId: tokenId, } err := DB.Create(log).Error if err != nil { diff --git a/model/main.go b/model/main.go index 5bc5ce19..3c80426f 100644 --- a/model/main.go +++ b/model/main.go @@ -88,6 +88,14 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&Midjourney{}) + if err != nil { + return err + } + err = db.AutoMigrate(&TopUp{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err diff --git a/model/midjourney.go b/model/midjourney.go new file mode 100644 index 00000000..bb723e8c --- /dev/null +++ b/model/midjourney.go @@ -0,0 +1,87 @@ +package model + +type Midjourney struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + ImageUrl string `json:"image_url"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"fail_reason"` +} + +func GetAllUserTask(userId int, startIdx int, num int) []*Midjourney { + var tasks []*Midjourney + var err error + err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + +func GetAllTasks(startIdx int, num int) []*Midjourney { + var tasks []*Midjourney + var err error + err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + +func GetAllUnFinishTasks() []*Midjourney { + var tasks []*Midjourney + var err error + // get all tasks progress is not 100% + err = DB.Where("progress != ?", "100%").Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + +func GetByMJId(mjId string) *Midjourney { + var mj *Midjourney + var err error + err = DB.Where("mj_id = ?", mjId).First(&mj).Error + if err != nil { + return nil + } + return mj +} + +func GetMjByuId(id int) *Midjourney { + var mj *Midjourney + var err error + err = DB.Where("id = ?", id).First(&mj).Error + if err != nil { + return nil + } + return mj +} + +func UpdateProgress(id int, progress string) error { + return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error +} + +func (midjourney *Midjourney) Insert() error { + var err error + err = DB.Create(midjourney).Error + return err +} + +func (midjourney *Midjourney) Update() error { + var err error + err = DB.Save(midjourney).Error + return err +} diff --git a/model/option.go b/model/option.go index e7bc6806..aa4da949 100644 --- a/model/option.go +++ b/model/option.go @@ -69,6 +69,10 @@ func InitOptionMap() { common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["NormalPrice"] = strconv.FormatFloat(common.NormalPrice, 'f', -1, 64) + common.OptionMap["StablePrice"] = strconv.FormatFloat(common.StablePrice, 'f', -1, 64) + common.OptionMap["BasePrice"] = strconv.FormatFloat(common.BasePrice, 'f', -1, 64) + common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -207,6 +211,12 @@ func updateOptionMap(key string, value string) (err error) { common.TopUpLink = value case "ChatLink": common.ChatLink = value + case "NormalPrice": + common.NormalPrice, _ = strconv.ParseFloat(value, 64) + case "BasePrice": + common.BasePrice, _ = strconv.ParseFloat(value, 64) + case "StablePrice": + common.StablePrice, _ = strconv.ParseFloat(value, 64) case "ChannelDisableThreshold": common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) case "QuotaPerUnit": diff --git a/model/topup.go b/model/topup.go new file mode 100644 index 00000000..876cd230 --- /dev/null +++ b/model/topup.go @@ -0,0 +1,43 @@ +package model + +type TopUp struct { + Id int `json:"id"` + UserId int `json:"user_id" gorm:"index"` + Amount int `json:"amount"` + Money int `json:"money"` + TradeNo string `json:"trade_no"` + CreateTime int64 `json:"create_time"` + Status string `json:"status"` +} + +func (topUp *TopUp) Insert() error { + var err error + err = DB.Create(topUp).Error + return err +} + +func (topUp *TopUp) Update() error { + var err error + err = DB.Save(topUp).Error + return err +} + +func GetTopUpById(id int) *TopUp { + var topUp *TopUp + var err error + err = DB.Where("id = ?", id).First(&topUp).Error + if err != nil { + return nil + } + return topUp +} + +func GetTopUpByTradeNo(tradeNo string) *TopUp { + var topUp *TopUp + var err error + err = DB.Where("trade_no = ?", tradeNo).First(&topUp).Error + if err != nil { + return nil + } + return topUp +} diff --git a/model/user.go b/model/user.go index 7c771840..b96b955b 100644 --- a/model/user.go +++ b/model/user.go @@ -28,6 +28,8 @@ type User struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` + StableMode bool `json:"stable_mode" gorm:"type:tinyint;default:0;column:stable_mode"` + MaxPrice string `json:"max_price" gorm:"type:varchar(32);default:'7'"` } func GetMaxUserId() int { @@ -116,7 +118,14 @@ func (user *User) Update(updatePassword bool) error { return err } } - err = DB.Model(user).Updates(user).Error + newUser := *user + err = DB.Model(user).UpdateColumns(map[string]interface{}{ + "stable_mode": user.StableMode, + "max_price": user.MaxPrice, + }).Error + + DB.First(&user, user.Id) + err = DB.Model(user).Updates(newUser).Error return err } diff --git a/router/api-router.go b/router/api-router.go index 383133fa..eae6a7e6 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -16,6 +16,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/status", controller.GetStatus) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) + apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) @@ -29,7 +30,9 @@ func SetApiRouter(router *gin.Engine) { { userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login) + //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog) userRoute.GET("/logout", controller.Logout) + userRoute.GET("/epay/notify", controller.EpayNotify) selfRoute := userRoute.Group("/") selfRoute.Use(middleware.UserAuth()) @@ -40,6 +43,9 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.POST("/pay", controller.RequestEpay) + selfRoute.POST("/amount", controller.RequestAmount) + selfRoute.POST("/set_stable_mode", controller.SetTableMode) } adminRoute := userRoute.Group("/") @@ -102,10 +108,14 @@ func SetApiRouter(router *gin.Engine) { logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) + logRoute.GET("/token", controller.GetLogByKey) groupRoute := apiRouter.Group("/group") groupRoute.Use(middleware.AdminAuth()) { groupRoute.GET("/", controller.GetGroups) } + mjRoute := apiRouter.Group("/mj") + mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney) + mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney) } } diff --git a/router/relay-router.go b/router/relay-router.go index c3c84d8b..ca8e1cc8 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -41,4 +41,12 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) relayV1Router.POST("/moderations", controller.Relay) } + relayMjRouter := router.Group("/mj") + relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) + { + relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) + relayMjRouter.POST("/submit/change", controller.RelayMidjourney) + relayMjRouter.POST("/notify", controller.RelayMidjourney) + relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) + } } diff --git a/web/src/App.js b/web/src/App.js index c967ce2c..422b1522 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -1,20 +1,20 @@ -import React, { lazy, Suspense, useContext, useEffect } from 'react'; -import { Route, Routes } from 'react-router-dom'; +import React, {lazy, Suspense, useContext, useEffect} from 'react'; +import {Route, Routes} from 'react-router-dom'; import Loading from './components/Loading'; import User from './pages/User'; -import { PrivateRoute } from './components/PrivateRoute'; +import {PrivateRoute} from './components/PrivateRoute'; import RegisterForm from './components/RegisterForm'; import LoginForm from './components/LoginForm'; import NotFound from './pages/NotFound'; import Setting from './pages/Setting'; import EditUser from './pages/User/EditUser'; import AddUser from './pages/User/AddUser'; -import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; +import {API, getLogo, getSystemName, showError, showNotice} from './helpers'; import PasswordResetForm from './components/PasswordResetForm'; import GitHubOAuth from './components/GitHubOAuth'; import PasswordResetConfirm from './components/PasswordResetConfirm'; -import { UserContext } from './context/User'; -import { StatusContext } from './context/Status'; +import {UserContext} from './context/User'; +import {StatusContext} from './context/Status'; import Channel from './pages/Channel'; import Token from './pages/Token'; import EditToken from './pages/Token/EditToken'; @@ -24,268 +24,295 @@ import EditRedemption from './pages/Redemption/EditRedemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; +import Midjourney from './pages/Midjourney'; const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); function App() { - const [userState, userDispatch] = useContext(UserContext); - const [statusState, statusDispatch] = useContext(StatusContext); + const [userState, userDispatch] = useContext(UserContext); + const [statusState, statusDispatch] = useContext(StatusContext); - const loadUser = () => { - let user = localStorage.getItem('user'); - if (user) { - let data = JSON.parse(user); - userDispatch({ type: 'login', payload: data }); - } - }; - const loadStatus = async () => { - const res = await API.get('/api/status'); - const { success, data } = res.data; - if (success) { - localStorage.setItem('status', JSON.stringify(data)); - statusDispatch({ type: 'set', payload: data }); - localStorage.setItem('system_name', data.system_name); - localStorage.setItem('logo', data.logo); - localStorage.setItem('footer_html', data.footer_html); - localStorage.setItem('quota_per_unit', data.quota_per_unit); - localStorage.setItem('display_in_currency', data.display_in_currency); - if (data.chat_link) { - localStorage.setItem('chat_link', data.chat_link); - } else { - localStorage.removeItem('chat_link'); - } - if ( - data.version !== process.env.REACT_APP_VERSION && - data.version !== 'v0.0.0' && - process.env.REACT_APP_VERSION !== '' - ) { - showNotice( - `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` - ); - } - } else { - showError('无法正常连接至服务器!'); - } - }; + const loadUser = () => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + userDispatch({type: 'login', payload: data}); + } + }; + const loadStatus = async () => { + const res = await API.get('/api/status'); + const {success, data} = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + statusDispatch({type: 'set', payload: data}); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + localStorage.removeItem('chat_link'); + } + if ( + data.version !== process.env.REACT_APP_VERSION && + data.version !== 'v0.0.0' && + process.env.REACT_APP_VERSION !== '' + ) { + showNotice( + `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` + ); + } + } else { + showError('无法正常连接至服务器!'); + } + }; - useEffect(() => { - loadUser(); - loadStatus().then(); - let systemName = getSystemName(); - if (systemName) { - document.title = systemName; - } - let logo = getLogo(); - if (logo) { - let linkElement = document.querySelector("link[rel~='icon']"); - if (linkElement) { - linkElement.href = logo; - } - } - }, []); + // const getOptions = async () => { + // const res = await API.get('/api/option/'); + // const {success, message, data} = res.data; + // if (success) { + // let newInputs = {}; + // data.forEach((item) => { + // if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { + // item.value = JSON.stringify(JSON.parse(item.value), null, 2); + // } + // newInputs[item.key] = item.value; + // }); + // setInputs(newInputs); + // setOriginInputs(newInputs); + // } else { + // showError(message); + // } + // }; - return ( - - }> - - + useEffect(() => { + loadUser(); + loadStatus().then(); + let systemName = getSystemName(); + if (systemName) { + document.title = systemName; } - /> - - - + let logo = getLogo(); + if (logo) { + let linkElement = document.querySelector("link[rel~='icon']"); + if (linkElement) { + linkElement.href = logo; + } } - /> - }> - - - } - /> - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - - }> - - - - } - /> - - }> - - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - ); + }, []); + + return ( + + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + }> + + + + } + /> + + }> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + ); } export default App; diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 4ea6965d..aa6d1e25 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -4,7 +4,7 @@ import { Link } from 'react-router-dom'; import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; -import { renderGroup, renderNumber } from '../helpers/render'; +import {renderGroup, renderNumber, renderQuota} from '../helpers/render'; function renderTimestamp(timestamp) { return ( @@ -299,6 +299,7 @@ const ChannelsTable = () => { onClick={() => { sortChannel('group'); }} + width={1} > 分组 @@ -307,6 +308,7 @@ const ChannelsTable = () => { onClick={() => { sortChannel('type'); }} + width={2} > 类型 @@ -315,6 +317,7 @@ const ChannelsTable = () => { onClick={() => { sortChannel('status'); }} + width={2} > 状态 @@ -326,6 +329,15 @@ const ChannelsTable = () => { > 响应时间 + { + sortChannel('used_quota'); + }} + width={1} + > + 已使用 + { @@ -361,6 +373,7 @@ const ChannelsTable = () => { basic /> + {renderQuota(channel.used_quota)} + {timestamp2string(timestamp)} + + ); +} + +const MODE_OPTIONS = [ + { key: 'all', text: '全部用户', value: 'all' }, + { key: 'self', text: '当前用户', value: 'self' } +]; + +const LOG_OPTIONS = [ + { key: '0', text: '全部', value: 0 }, + // { key: '1', text: '绘图', value: 1 }, + // { key: '2', text: '放大', value: 2 }, + // { key: '3', text: '变换', value: 3 }, + // { key: '4', text: '图生文', value: 4 }, + // { key: '5', text: '图片混合', value: 5 } +]; + +function renderType(type) { + switch (type) { + case 'IMAGINE': + return ; + case 'UPSCALE': + return ; + case 'VARIATION': + return ; + case 'DESCRIBE': + return ; + case 'BLEAND': + return ; + default: + return ; + } +} + +function renderCode(type) { + switch (type) { + case 1: + return ; + case 21: + return ; + case 22: + return ; + default: + return ; + } +} + +function renderStatus(type) { + switch (type) { + case 'SUCCESS': + return ; + case 'NOT_START': + return ; + case 'SUBMITTED': + return ; + case 'IN_PROGRESS': + return ; + case 'FAILURE': + return ; + default: + return ; + } +} + +const LogsTable = () => { + const [logs, setLogs] = useState([ + + ]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + let now = new Date(); + const [inputs, setInputs] = useState({ + username: '', + token_name: '', + model_name: '', + start_timestamp: timestamp2string(0), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + }); + const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; + + const [stat, setStat] = useState({ + quota: 0, + token: 0 + }); + + const handleInputChange = (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const getLogSelfStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const getLogStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const loadLogs = async (startIdx) => { + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + if (isAdminUser) { + url = `/api/mj/?p=${startIdx}&username=${username}&token_name=${token_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/mj/self/?p=${startIdx}&token_name=${token_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogs(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogs(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadLogs(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + setLoading(true); + setActivePage(1) + await loadLogs(0); + // if (isAdminUser) { + // getLogStat().then(); + // } else { + // getLogSelfStat().then(); + // } + }; + + useEffect(() => { + refresh().then(); + }, [logType]); + + const searchLogs = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadLogs(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setLogs(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (e, { value }) => { + setSearchKeyword(value.trim()); + }; + + const sortLog = (key) => { + if (logs.length === 0) return; + setLoading(true); + let sortedLogs = [...logs]; + if (typeof sortedLogs[0][key] === 'string'){ + sortedLogs.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + } else { + sortedLogs.sort((a, b) => { + if (a[key] === b[key]) return 0; + if (a[key] > b[key]) return -1; + if (a[key] < b[key]) return 1; + }); + } + if (sortedLogs[0].id === logs[0].id) { + sortedLogs.reverse(); + } + setLogs(sortedLogs); + setLoading(false); + }; + + return ( + <> + + + + + { + sortLog('submit_time'); + }} + width={2} + > + 提交时间 + + { + sortLog('action'); + }} + width={1} + > + 类型 + + { + sortLog('mj_id'); + }} + width={2} + > + 任务ID + + { + sortLog('code'); + }} + width={1} + > + 提交结果 + + { + sortLog('status'); + }} + width={1} + > + 任务状态 + + { + sortLog('progress'); + }} + width={1} + > + 进度 + + { + sortLog('image_url'); + }} + width={1} + > + 结果图片 + + { + sortLog('prompt'); + }} + width={3} + > + Prompt + + { + sortLog('fail_reason'); + }} + width={1} + > + 失败原因 + + + + + + {logs + .slice( + (activePage - 1) * ITEMS_PER_PAGE, + activePage * ITEMS_PER_PAGE + ) + .map((log, idx) => { + if (log.deleted) return <>; + return ( + + {renderTimestamp(log.submit_time/1000)} + {/*{*/} + {/* isAdminUser && (*/} + {/* {log.username ? : ''}*/} + {/* )*/} + {/*}*/} + {renderType(log.action)} + {log.mj_id} + {renderCode(log.code)} + {renderStatus(log.status)} + {log.progress ? : ''} + + { + log.image_url ? ( + 点击查看 + ) : '暂未生成图片' + } + + {log.prompt} + {log.fail_reason ? log.fail_reason : '无'} + + ); + })} + + + + + +
+
+ + ); +}; + +export default LogsTable; diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 2adc7fa4..4822f69e 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -1,331 +1,372 @@ -import React, { useEffect, useState } from 'react'; -import { Divider, Form, Grid, Header } from 'semantic-ui-react'; -import { API, showError, verifyJSON } from '../helpers'; +import React, {useEffect, useState} from 'react'; +import {Divider, Form, Grid, Header} from 'semantic-ui-react'; +import {API, showError, verifyJSON} from '../helpers'; const OperationSetting = () => { - let [inputs, setInputs] = useState({ - QuotaForNewUser: 0, - QuotaForInviter: 0, - QuotaForInvitee: 0, - QuotaRemindThreshold: 0, - PreConsumedQuota: 0, - ModelRatio: '', - GroupRatio: '', - TopUpLink: '', - ChatLink: '', - QuotaPerUnit: 0, - AutomaticDisableChannelEnabled: '', - ChannelDisableThreshold: 0, - LogConsumeEnabled: '', - DisplayInCurrencyEnabled: '', - DisplayTokenStatEnabled: '', - ApproximateTokenEnabled: '', - RetryTimes: 0, - }); - const [originInputs, setOriginInputs] = useState({}); - let [loading, setLoading] = useState(false); - - const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } - newInputs[item.key] = item.value; - }); - setInputs(newInputs); - setOriginInputs(newInputs); - } else { - showError(message); - } - }; - - useEffect(() => { - getOptions().then(); - }, []); - - const updateOption = async (key, value) => { - setLoading(true); - if (key.endsWith('Enabled')) { - value = inputs[key] === 'true' ? 'false' : 'true'; - } - const res = await API.put('/api/option/', { - key, - value + let [inputs, setInputs] = useState({ + QuotaForNewUser: 0, + QuotaForInviter: 0, + QuotaForInvitee: 0, + QuotaRemindThreshold: 0, + PreConsumedQuota: 0, + ModelRatio: '', + GroupRatio: '', + TopUpLink: '', + ChatLink: '', + QuotaPerUnit: 0, + AutomaticDisableChannelEnabled: '', + ChannelDisableThreshold: 0, + LogConsumeEnabled: '', + DisplayInCurrencyEnabled: '', + DisplayTokenStatEnabled: '', + ApproximateTokenEnabled: '', + RetryTimes: 0, + StablePrice: 6, + NormalPrice: 1.5, + BasePrice: 1.5, }); - const { success, message } = res.data; - if (success) { - setInputs((inputs) => ({ ...inputs, [key]: value })); - } else { - showError(message); - } - setLoading(false); - }; + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); - const handleInputChange = async (e, { name, value }) => { - if (name.endsWith('Enabled')) { - await updateOption(name, value); - } else { - setInputs((inputs) => ({ ...inputs, [name]: value })); - } - }; + const getOptions = async () => { + const res = await API.get('/api/option/'); + const {success, message, data} = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + newInputs[item.key] = item.value; + }); + setInputs(newInputs); + setOriginInputs(newInputs); + } else { + showError(message); + } + }; - const submitConfig = async (group) => { - switch (group) { - case 'monitor': - if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { - await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); - } - if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { - await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); - } - break; - case 'ratio': - if (originInputs['ModelRatio'] !== inputs.ModelRatio) { - if (!verifyJSON(inputs.ModelRatio)) { - showError('模型倍率不是合法的 JSON 字符串'); - return; - } - await updateOption('ModelRatio', inputs.ModelRatio); - } - if (originInputs['GroupRatio'] !== inputs.GroupRatio) { - if (!verifyJSON(inputs.GroupRatio)) { - showError('分组倍率不是合法的 JSON 字符串'); - return; - } - await updateOption('GroupRatio', inputs.GroupRatio); - } - break; - case 'quota': - if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { - await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); - } - if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { - await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); - } - if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { - await updateOption('QuotaForInviter', inputs.QuotaForInviter); - } - if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { - await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); - } - break; - case 'general': - if (originInputs['TopUpLink'] !== inputs.TopUpLink) { - await updateOption('TopUpLink', inputs.TopUpLink); - } - if (originInputs['ChatLink'] !== inputs.ChatLink) { - await updateOption('ChatLink', inputs.ChatLink); - } - if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { - await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); - } - if (originInputs['RetryTimes'] !== inputs.RetryTimes) { - await updateOption('RetryTimes', inputs.RetryTimes); - } - break; - } - }; + useEffect(() => { + getOptions().then(); + }, []); - return ( - - -
-
- 通用设置 -
- - - - - - - - - - - - - { - submitConfig('general').then(); - }}>保存通用设置 - -
- 监控设置 -
- - - - - - - - { - submitConfig('monitor').then(); - }}>保存监控设置 - -
- 额度设置 -
- - - - - - - { - submitConfig('quota').then(); - }}>保存额度设置 - -
- 倍率设置 -
- - - - - - - { - submitConfig('ratio').then(); - }}>保存倍率设置 - -
-
- ); + const updateOption = async (key, value) => { + setLoading(true); + if (key.endsWith('Enabled')) { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + const res = await API.put('/api/option/', { + key, + value + }); + const {success, message} = res.data; + if (success) { + setInputs((inputs) => ({...inputs, [key]: value})); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, {name, value}) => { + if (name.endsWith('Enabled')) { + await updateOption(name, value); + } else { + setInputs((inputs) => ({...inputs, [name]: value})); + } + }; + + const submitConfig = async (group) => { + switch (group) { + case 'monitor': + if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { + await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); + } + if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { + await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); + } + break; + case 'stable': + await updateOption('StablePrice', inputs.StablePrice); + await updateOption('NormalPrice', inputs.NormalPrice); + await updateOption('BasePrice', inputs.BasePrice); + localStorage.setItem('stable_price', inputs.StablePrice); + localStorage.setItem('normal_price', inputs.NormalPrice); + localStorage.setItem('base_price', inputs.BasePrice); + break; + case 'ratio': + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); + } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } + break; + case 'quota': + if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { + await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); + } + if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { + await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); + } + if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { + await updateOption('QuotaForInviter', inputs.QuotaForInviter); + } + if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { + await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); + } + break; + case 'general': + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { + await updateOption('TopUpLink', inputs.TopUpLink); + } + if (originInputs['ChatLink'] !== inputs.ChatLink) { + await updateOption('ChatLink', inputs.ChatLink); + } + if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { + await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); + } + if (originInputs['RetryTimes'] !== inputs.RetryTimes) { + await updateOption('RetryTimes', inputs.RetryTimes); + } + break; + } + }; + + return ( + + +
+
+ 通用设置 +
+ + + + + + + + + + + + + { + submitConfig('general').then(); + }}>保存通用设置 + +
+ 监控设置 +
+ + + + + + + + { + submitConfig('monitor').then(); + }}>保存监控设置 + +
+ 通道设置 +
+ + + + + { + submitConfig('stable').then(); + }}>保存通道设置 + +
+ 额度设置 +
+ + + + + + + { + submitConfig('quota').then(); + }}>保存额度设置 + +
+ 倍率设置 +
+ + + + + + + { + submitConfig('ratio').then(); + }}>保存倍率设置 + +
+
+ ) + ; }; export default OperationSetting; diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 108655d2..2379774a 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -1,323 +1,410 @@ -import React, { useContext, useEffect, useState } from 'react'; -import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; -import { Link, useNavigate } from 'react-router-dom'; -import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; +import React, {useContext, useEffect, useState} from 'react'; +import {Button, Checkbox, Divider, Form, Header, Image, Input, Message, Modal} from 'semantic-ui-react'; +import {Link, useNavigate} from 'react-router-dom'; +import {API, copy, showError, showInfo, showNotice, showSuccess} from '../helpers'; import Turnstile from 'react-turnstile'; -import { UserContext } from '../context/User'; +import {UserContext} from '../context/User'; const PersonalSetting = () => { - const [userState, userDispatch] = useContext(UserContext); - let navigate = useNavigate(); + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); - const [inputs, setInputs] = useState({ - wechat_verification_code: '', - email_verification_code: '', - email: '', - self_account_deletion_confirmation: '' - }); - const [status, setStatus] = useState({}); - const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); - const [showEmailBindModal, setShowEmailBindModal] = useState(false); - const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); - const [turnstileEnabled, setTurnstileEnabled] = useState(false); - const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); - const [turnstileToken, setTurnstileToken] = useState(''); - const [loading, setLoading] = useState(false); - const [disableButton, setDisableButton] = useState(false); - const [countdown, setCountdown] = useState(30); + const [inputs, setInputs] = useState({ + wechat_verification_code: '', + email_verification_code: '', + email: '', + self_account_deletion_confirmation: '' + }); + const [stableMode, setStableMode] = useState({ + stableMode: false, + maxPrice: 7, + }); + const [status, setStatus] = useState({}); + const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); + const [showEmailBindModal, setShowEmailBindModal] = useState(false); + const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); - useEffect(() => { - let status = localStorage.getItem('status'); - if (status) { - status = JSON.parse(status); - setStatus(status); - if (status.turnstile_check) { - setTurnstileEnabled(true); - setTurnstileSiteKey(status.turnstile_site_key); - } - } - }, []); + // setStableMode(userState.user.stableMode, userState.user.maxPrice); + console.log(userState.user) - useEffect(() => { - let countdownInterval = null; - if (disableButton && countdown > 0) { - countdownInterval = setInterval(() => { - setCountdown(countdown - 1); - }, 1000); - } else if (countdown === 0) { - setDisableButton(false); - setCountdown(30); - } - return () => clearInterval(countdownInterval); // Clean up on unmount - }, [disableButton, countdown]); - const handleInputChange = (e, { name, value }) => { - setInputs((inputs) => ({ ...inputs, [name]: value })); - }; + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + // if (userState.user !== undefined) { + // setStableMode(userState.user.stable_mode, userState.user.max_price); + // } + }, []); - const generateAccessToken = async () => { - const res = await API.get('/api/user/token'); - const { success, message, data } = res.data; - if (success) { - await copy(data); - showSuccess(`令牌已重置并已复制到剪贴板:${data}`); - } else { - showError(message); - } - }; + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); // Clean up on unmount + }, [disableButton, countdown]); - const getAffLink = async () => { - const res = await API.get('/api/user/aff'); - const { success, message, data } = res.data; - if (success) { - let link = `${window.location.origin}/register?aff=${data}`; - await copy(link); - showNotice(`邀请链接已复制到剪切板:${link}`); - } else { - showError(message); - } - }; + useEffect(() => { + if (userState.user !== undefined) { + setStableMode({ + stableMode: userState.user.stable_mode, + maxPrice: userState.user.max_price + }) + // if (stableMode.localMaxPrice !== userState.user.max_price) { + // setStableMode({ + // localMaxPrice: userState.user.max_price + // }) + // } + } + }, [userState]); - const deleteAccount = async () => { - if (inputs.self_account_deletion_confirmation !== userState.user.username) { - showError('请输入你的账户名以确认删除!'); - return; - } + const handleInputChange = (e, {name, value}) => { + setInputs((inputs) => ({...inputs, [name]: value})); + }; - const res = await API.delete('/api/user/self'); - const { success, message } = res.data; + const generateAccessToken = async () => { + const res = await API.get('/api/user/token'); + const {success, message, data} = res.data; + if (success) { + await copy(data); + showSuccess(`令牌已重置并已复制到剪贴板:${data}`); + } else { + showError(message); + } + }; - if (success) { - showSuccess('账户已删除!'); - await API.get('/api/user/logout'); - userDispatch({ type: 'logout' }); - localStorage.removeItem('user'); - navigate('/login'); - } else { - showError(message); - } - }; + const getAffLink = async () => { + const res = await API.get('/api/user/aff'); + const {success, message, data} = res.data; + if (success) { + let link = `${window.location.origin}/register?aff=${data}`; + await copy(link); + showNotice(`邀请链接已复制到剪切板:${link}`); + } else { + showError(message); + } + }; - const bindWeChat = async () => { - if (inputs.wechat_verification_code === '') return; - const res = await API.get( - `/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}` - ); - const { success, message } = res.data; - if (success) { - showSuccess('微信账户绑定成功!'); - setShowWeChatBindModal(false); - } else { - showError(message); - } - }; + const deleteAccount = async () => { + if (inputs.self_account_deletion_confirmation !== userState.user.username) { + showError('请输入你的账户名以确认删除!'); + return; + } - const openGitHubOAuth = () => { - window.open( - `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` - ); - }; + const res = await API.delete('/api/user/self'); + const {success, message} = res.data; - const sendVerificationCode = async () => { - setDisableButton(true); - if (inputs.email === '') return; - if (turnstileEnabled && turnstileToken === '') { - showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); - return; - } - setLoading(true); - const res = await API.get( - `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` - ); - const { success, message } = res.data; - if (success) { - showSuccess('验证码发送成功,请检查邮箱!'); - } else { - showError(message); - } - setLoading(false); - }; + if (success) { + showSuccess('账户已删除!'); + await API.get('/api/user/logout'); + userDispatch({type: 'logout'}); + localStorage.removeItem('user'); + navigate('/login'); + } else { + showError(message); + } + }; - const bindEmail = async () => { - if (inputs.email_verification_code === '') return; - setLoading(true); - const res = await API.get( - `/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}` - ); - const { success, message } = res.data; - if (success) { - showSuccess('邮箱账户绑定成功!'); - setShowEmailBindModal(false); - } else { - showError(message); - } - setLoading(false); - }; + const bindWeChat = async () => { + if (inputs.wechat_verification_code === '') return; + const res = await API.get( + `/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}` + ); + const {success, message} = res.data; + if (success) { + showSuccess('微信账户绑定成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + }; - return ( -
-
通用设置
- - 注意,此处生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。 - - - - - - -
账号绑定
- { - status.wechat_login && ( - - ) - } - setShowWeChatBindModal(false)} - onOpen={() => setShowWeChatBindModal(true)} - open={showWeChatBindModal} - size={'mini'} - > - - - -
-

- 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) -

-
-
- - + const openGitHubOAuth = () => { + window.open( + `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` + ); + }; + + const sendVerificationCode = async () => { + setDisableButton(true); + if (inputs.email === '') return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const {success, message} = res.data; + if (success) { + showSuccess('验证码发送成功,请检查邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + const bindEmail = async () => { + if (inputs.email_verification_code === '') return; + setLoading(true); + const res = await API.get( + `/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}` + ); + const {success, message} = res.data; + if (success) { + showSuccess('邮箱账户绑定成功!'); + setShowEmailBindModal(false); + } else { + showError(message); + } + setLoading(false); + }; + + // const setStableMod = ; + + return ( +
+
通用设置
+ + 注意,此处生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。 + + + + + + +
GPT-4消费设置
+ + + { + setStableMode({ + ...stableMode, + stableMode: data.checked + }) + } + } + > + + { + setStableMode({ + ...stableMode, + maxPrice: data.value + }) + } + } + > + {/**/} + {/**/} + + - - - - { - status.github_oauth && ( - - ) - } - - setShowEmailBindModal(false)} - onOpen={() => setShowEmailBindModal(true)} - open={showEmailBindModal} - size={'tiny'} - style={{ maxWidth: '450px' }} - > - 绑定邮箱地址 - - -
- - {disableButton ? `重新发送(${countdown})` : '获取验证码'} - - } - /> - - {turnstileEnabled ? ( - { - setTurnstileToken(token); - }} - /> - ) : ( - <> - )} - - -
-
-
- setShowAccountDeleteModal(false)} - onOpen={() => setShowAccountDeleteModal(true)} - open={showAccountDeleteModal} - size={'tiny'} - style={{ maxWidth: '450px' }} - > - 确认删除自己的帐户 - - -
- - {turnstileEnabled ? ( - { - setTurnstileToken(token); - }} - /> - ) : ( - <> - )} - - -
-
-
-
- ); + {/* {*/} + {/* // if (inputs.email_verification_code === '') return;*/} + {/* console.log(data)*/} + {/* }*/} + {/*}>*/} + {/**/} + +
账号绑定
+ { + status.wechat_login && ( + + ) + } + setShowWeChatBindModal(false)} + onOpen={() => setShowWeChatBindModal(true)} + open={showWeChatBindModal} + size={'mini'} + > + + + +
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+
+ + + +
+
+
+ { + status.github_oauth && ( + + ) + } + + setShowEmailBindModal(false)} + onOpen={() => setShowEmailBindModal(true)} + open={showEmailBindModal} + size={'tiny'} + style={{maxWidth: '450px'}} + > + 绑定邮箱地址 + + +
+ + {disableButton ? `重新发送(${countdown})` : '获取验证码'} + + } + /> + + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+
+
+ setShowAccountDeleteModal(false)} + onOpen={() => setShowAccountDeleteModal(true)} + open={showAccountDeleteModal} + size={'tiny'} + style={{maxWidth: '450px'}} + > + 确认删除自己的帐户 + + +
+ + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+
+
+
+ ); }; export default PersonalSetting; diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index b42f7df8..7c62e025 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -235,21 +235,27 @@ const TokensTable = () => { {token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}
- + { + let key = "sk-" + token.key; + if (await copy(key)) { + showSuccess('已复制到剪贴板!'); + } else { + showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); + setSearchKeyword(key); + } + }} + > + 复制 + + } + on={'hover'} + content={"sk-" + token.key} + /> diff --git a/web/src/pages/Home/index.js b/web/src/pages/Home/index.js index 20d42104..67944f91 100644 --- a/web/src/pages/Home/index.js +++ b/web/src/pages/Home/index.js @@ -1,132 +1,121 @@ -import React, { useContext, useEffect, useState } from 'react'; -import { Card, Grid, Header, Segment } from 'semantic-ui-react'; -import { API, showError, showNotice, timestamp2string } from '../../helpers'; -import { StatusContext } from '../../context/Status'; -import { marked } from 'marked'; +import React, {useContext, useEffect, useState} from 'react'; +import {Card, Grid, Header, Segment} from 'semantic-ui-react'; +import {API, showError, showNotice, timestamp2string} from '../../helpers'; +import {StatusContext} from '../../context/Status'; +import {marked} from 'marked'; const Home = () => { - const [statusState, statusDispatch] = useContext(StatusContext); - const [homePageContentLoaded, setHomePageContentLoaded] = useState(false); - const [homePageContent, setHomePageContent] = useState(''); + const [statusState, statusDispatch] = useContext(StatusContext); + const [homePageContentLoaded, setHomePageContentLoaded] = useState(false); + const [homePageContent, setHomePageContent] = useState(''); - const displayNotice = async () => { - const res = await API.get('/api/notice'); - const { success, message, data } = res.data; - if (success) { - let oldNotice = localStorage.getItem('notice'); - if (data !== oldNotice && data !== '') { - showNotice(data); - localStorage.setItem('notice', data); - } - } else { - showError(message); - } - }; + const displayNotice = async () => { + const res = await API.get('/api/notice'); + const {success, message, data} = res.data; + if (success) { + let oldNotice = localStorage.getItem('notice'); + if (data !== oldNotice && data !== '') { + showNotice(data); + localStorage.setItem('notice', data); + } + } else { + showError(message); + } + }; - const displayHomePageContent = async () => { - setHomePageContent(localStorage.getItem('home_page_content') || ''); - const res = await API.get('/api/home_page_content'); - const { success, message, data } = res.data; - if (success) { - let content = data; - if (!data.startsWith('https://')) { - content = marked.parse(data); - } - setHomePageContent(content); - localStorage.setItem('home_page_content', content); - } else { - showError(message); - setHomePageContent('加载首页内容失败...'); - } - setHomePageContentLoaded(true); - }; + const displayHomePageContent = async () => { + setHomePageContent(localStorage.getItem('home_page_content') || ''); + const res = await API.get('/api/home_page_content'); + const {success, message, data} = res.data; + if (success) { + let content = data; + if (!data.startsWith('https://')) { + content = marked.parse(data); + } + setHomePageContent(content); + localStorage.setItem('home_page_content', content); + } else { + showError(message); + setHomePageContent('加载首页内容失败...'); + } + setHomePageContentLoaded(true); + }; - const getStartTimeString = () => { - const timestamp = statusState?.status?.start_time; - return timestamp2string(timestamp); - }; + const getStartTimeString = () => { + const timestamp = statusState?.status?.start_time; + return timestamp2string(timestamp); + }; + + useEffect(() => { + displayNotice().then(); + displayHomePageContent().then(); + }, []); + return ( + <> + { + // homePageContentLoaded && homePageContent === '' ? + <> + +
当前状态
+ + + + + GPT-3.5 + 信息总览 + +

通道:官方通道

+

状态:存活

+

价格:{statusState?.status?.base_price}R / 刀

+
+
+
+
+ + + + GPT-4 + 信息总览 + +

通道:官方通道|低价通道

+

+ 状态: + {statusState?.status?.stable_price===-1? + 不   可   用 + : + 可  用 + } + | + {statusState?.status?.normal_price===-1? + 不   可   用 + : + 可  用 + } +

+

+ 价格:{statusState?.status?.stable_price}R / 刀|{statusState?.status?.normal_price}R / 刀 +

+
+
+
+
+
+
+ { + homePageContent.startsWith('https://') ?