diff --git a/.github/workflows/docker-image-amd64-en.yml b/.github/workflows/docker-image-amd64-en.yml index 44dc0bc0..af488256 100644 --- a/.github/workflows/docker-image-amd64-en.yml +++ b/.github/workflows/docker-image-amd64-en.yml @@ -20,6 +20,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index e3b8439a..2079d31f 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -20,6 +20,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index d6449eb8..39d1a401 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -21,6 +21,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index e81ab09f..6f30a1d5 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -20,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 13415276..359c2c92 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -20,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 8b1160b4..4e99b75c 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -23,6 +23,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/Dockerfile b/Dockerfile index ec2f9d43..6743b139 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,10 @@ WORKDIR /web/berry RUN npm install RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build +WORKDIR /web/air +RUN npm install +RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build + FROM golang AS builder2 ENV GO111MODULE=on \ diff --git a/README.md b/README.md index 69bb10ef..0ba659c4 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [百川大模型](https://platform.baichuan-ai.com) + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) + [x] [MINIMAX](https://api.minimax.chat/) + + [x] [Groq](https://wow.groq.com/) + + [x] [Ollama](https://github.com/ollama/ollama) + + [x] [零一万物](https://platform.lingyiwanwu.com/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 @@ -105,6 +108,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 +24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 ## 部署 ### 基于 Docker 进行部署 @@ -374,6 +378,9 @@ graph LR 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/blacklist/main.go b/common/blacklist/main.go new file mode 100644 index 00000000..f84ce6ae --- /dev/null +++ b/common/blacklist/main.go @@ -0,0 +1,29 @@ +package blacklist + +import ( + "fmt" + "sync" +) + +var blackList sync.Map + +func init() { + blackList = sync.Map{} +} + +func userId2Key(id int) string { + return fmt.Sprintf("userid_%d", id) +} + +func BanUser(id int) { + blackList.Store(userId2Key(id), true) +} + +func UnbanUser(id int) { + blackList.Delete(userId2Key(id)) +} + +func IsUserBanned(id int) bool { + _, ok := blackList.Load(userId2Key(id)) + return ok +} diff --git a/common/config/config.go b/common/config/config.go index dd0236b4..a261523d 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,7 +1,7 @@ package config import ( - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/env" "os" "strconv" "sync" @@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{ } var DebugEnabled = os.Getenv("DEBUG") == "true" +var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" var LogConsumeEnabled = true @@ -69,17 +70,20 @@ var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" +var MessagePusherAddress = "" +var MessagePusherToken = "" + var TurnstileSiteKey = "" var TurnstileSecretKey = "" -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 +var QuotaForNewUser int64 = 0 +var QuotaForInviter int64 = 0 +var QuotaForInvitee int64 = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 +var QuotaRemindThreshold int64 = 1000 +var PreConsumedQuota int64 = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 @@ -90,28 +94,29 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second -var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") -var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var Theme = env.String("THEME", "default") var ValidThemes = map[string]bool{ "default": true, "berry": true, + "air": true, } // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -125,3 +130,9 @@ var ( ) var RateLimitKeyExpirationDuration = 20 * time.Minute + +var EnableMetric = env.Bool("ENABLE_METRIC", false) +var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) +var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) diff --git a/common/constants.go b/common/constants.go index ac901139..849bdce7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -15,6 +15,7 @@ const ( const ( UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 ) const ( @@ -38,35 +39,40 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 - ChannelTypeAIProxyLibrary = 21 - ChannelTypeFastGPT = 22 - ChannelTypeTencent = 23 - ChannelTypeGemini = 24 - ChannelTypeMoonshot = 25 - ChannelTypeBaichuan = 26 - ChannelTypeMinimax = 27 - ChannelTypeMistral = 28 + ChannelTypeUnknown = iota + ChannelTypeOpenAI + ChannelTypeAPI2D + ChannelTypeAzure + ChannelTypeCloseAI + ChannelTypeOpenAISB + ChannelTypeOpenAIMax + ChannelTypeOhMyGPT + ChannelTypeCustom + ChannelTypeAILS + ChannelTypeAIProxy + ChannelTypePaLM + ChannelTypeAPI2GPT + ChannelTypeAIGC2D + ChannelTypeAnthropic + ChannelTypeBaidu + ChannelTypeZhipu + ChannelTypeAli + ChannelTypeXunfei + ChannelType360 + ChannelTypeOpenRouter + ChannelTypeAIProxyLibrary + ChannelTypeFastGPT + ChannelTypeTencent + ChannelTypeGemini + ChannelTypeMoonshot + ChannelTypeBaichuan + ChannelTypeMinimax + ChannelTypeMistral + ChannelTypeGroq + ChannelTypeOllama + ChannelTypeLingYiWanWu + + ChannelTypeDummy ) var ChannelBaseURLs = []string{ @@ -99,6 +105,9 @@ var ChannelBaseURLs = []string{ "https://api.baichuan-ai.com", // 26 "https://api.minimax.chat", // 27 "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 } const ( diff --git a/common/database.go b/common/database.go index 9b52a0d5..f2db759f 100644 --- a/common/database.go +++ b/common/database.go @@ -1,9 +1,12 @@ package common -import "github.com/songquanpeng/one-api/common/helper" +import ( + "github.com/songquanpeng/one-api/common/env" +) var UsingSQLite = false var UsingPostgreSQL = false +var UsingMySQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/env/helper.go b/common/env/helper.go new file mode 100644 index 00000000..fdb9f827 --- /dev/null +++ b/common/env/helper.go @@ -0,0 +1,42 @@ +package env + +import ( + "os" + "strconv" +) + +func Bool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} + +func Int(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + return defaultValue + } + return num +} + +func Float64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + return defaultValue + } + return num +} + +func String(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/helper/helper.go b/common/helper/helper.go index babe422b..db41ac74 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -3,12 +3,10 @@ package helper import ( "fmt" "github.com/google/uuid" - "github.com/songquanpeng/one-api/common/logger" "html/template" "log" "math/rand" "net" - "os" "os/exec" "runtime" "strconv" @@ -187,6 +185,10 @@ func GetTimeString() string { return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) } +func GenRequestID() string { + return GetTimeString() + GetRandomNumberString(8) +} + func Max(a int, b int) int { if a >= b { return a @@ -195,25 +197,6 @@ func Max(a int, b int) int { } } -func GetOrDefaultEnvInt(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - func AssignOrDefault(value string, defaultValue string) string { if len(value) != 0 { return value diff --git a/common/logger/logger.go b/common/logger/logger.go index 8232b2fc..957d8a11 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" "io" "log" "os" @@ -19,9 +21,6 @@ const ( loggerError = "ERR" ) -const maxLogCount = 1000000 - -var logCount int var setupLogLock sync.Mutex var setupLogWorking bool @@ -57,7 +56,9 @@ func SysError(s string) { } func Debug(ctx context.Context, msg string) { - logHelper(ctx, loggerDEBUG, msg) + if config.DebugEnabled { + logHelper(ctx, loggerDEBUG, msg) + } } func Info(ctx context.Context, msg string) { @@ -94,11 +95,12 @@ func logHelper(ctx context.Context, level string, msg string) { writer = gin.DefaultWriter } id := ctx.Value(RequestIdKey) + if id == nil { + id = helper.GenRequestID() + } now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) - logCount++ // we don't need accurate count, so no lock here - if logCount > maxLogCount && !setupLogWorking { - logCount = 0 + if !setupLogWorking { setupLogWorking = true go func() { SetupLogger() diff --git a/common/email.go b/common/message/email.go similarity index 96% rename from common/email.go rename to common/message/email.go index 2689da6a..b06782db 100644 --- a/common/email.go +++ b/common/message/email.go @@ -1,4 +1,4 @@ -package common +package message import ( "crypto/rand" @@ -12,6 +12,9 @@ import ( ) func SendEmail(subject string, receiver string, content string) error { + if receiver == "" { + return fmt.Errorf("receiver is empty") + } if config.SMTPFrom == "" { // for compatibility config.SMTPFrom = config.SMTPAccount } diff --git a/common/message/main.go b/common/message/main.go new file mode 100644 index 00000000..5ce82a64 --- /dev/null +++ b/common/message/main.go @@ -0,0 +1,22 @@ +package message + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" +) + +const ( + ByAll = "all" + ByEmail = "email" + ByMessagePusher = "message_pusher" +) + +func Notify(by string, title string, description string, content string) error { + if by == ByEmail { + return SendEmail(title, config.RootUserEmail, content) + } + if by == ByMessagePusher { + return SendMessage(title, description, content) + } + return fmt.Errorf("unknown notify method: %s", by) +} diff --git a/common/message/message-pusher.go b/common/message/message-pusher.go new file mode 100644 index 00000000..69949b4b --- /dev/null +++ b/common/message/message-pusher.go @@ -0,0 +1,53 @@ +package message + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/songquanpeng/one-api/common/config" + "net/http" +) + +type request struct { + Title string `json:"title"` + Description string `json:"description"` + Content string `json:"content"` + URL string `json:"url"` + Channel string `json:"channel"` + Token string `json:"token"` +} + +type response struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +func SendMessage(title string, description string, content string) error { + if config.MessagePusherAddress == "" { + return errors.New("message pusher address is not set") + } + req := request{ + Title: title, + Description: description, + Content: content, + Token: config.MessagePusherToken, + } + data, err := json.Marshal(req) + if err != nil { + return err + } + resp, err := http.Post(config.MessagePusherAddress, + "application/json", bytes.NewBuffer(data)) + if err != nil { + return err + } + var res response + err = json.NewDecoder(resp.Body).Decode(&res) + if err != nil { + return err + } + if !res.Success { + return errors.New(res.Message) + } + return nil +} diff --git a/common/model-ratio.go b/common/model-ratio.go index 2e66ac0d..5e7d5729 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -4,7 +4,6 @@ import ( "encoding/json" "github.com/songquanpeng/one-api/common/logger" "strings" - "time" ) const ( @@ -31,7 +30,7 @@ var ModelRatio = map[string]float64{ "gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens "gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens @@ -63,18 +62,24 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image - "claude-instant-1": 0.815, // $1.63 / 1M tokens - "claude-2": 5.51, // $11.02 / 1M tokens - "claude-2.0": 5.51, // $11.02 / 1M tokens - "claude-2.1": 5.51, // $11.02 / 1M tokens + "dall-e-2": 8, // $0.016 - $0.020 / image + "dall-e-3": 20, // $0.040 - $0.120 / image + // https://www.anthropic.com/api#pricing + "claude-instant-1.2": 0.8 / 1000 * USD, + "claude-2.0": 8.0 / 1000 * USD, + "claude-2.1": 8.0 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-sonnet-20240229": 3.0 / 1000 * USD, + "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens "ERNIE-Bot-8k": 0.024 * RMB, "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "bge-large-8k": 0.002 * RMB, "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens @@ -122,6 +127,15 @@ var ModelRatio = map[string]float64{ "mistral-medium-latest": 2.7 / 1000 * USD, "mistral-large-latest": 8.0 / 1000 * USD, "mistral-embed": 0.1 / 1000 * USD, + // https://wow.groq.com/ + "llama2-70b-4096": 0.7 / 1000 * USD, + "llama2-7b-2048": 0.1 / 1000 * USD, + "mixtral-8x7b-32768": 0.27 / 1000 * USD, + "gemma-7b-it": 0.1 / 1000 * USD, + // https://platform.lingyiwanwu.com/docs#-计费单元 + "yi-34b-chat-0205": 2.5 / 1000000 * RMB, + "yi-34b-chat-200k": 12.0 / 1000000 * RMB, + "yi-vl-plus": 6.0 / 1000000 * RMB, } var CompletionRatio = map[string]float64{} @@ -140,6 +154,26 @@ func init() { } } +func AddNewMissingRatio(oldRatio string) string { + newRatio := make(map[string]float64) + err := json.Unmarshal([]byte(oldRatio), &newRatio) + if err != nil { + logger.SysError("error unmarshalling old ratio: " + err.Error()) + return oldRatio + } + for k, v := range DefaultModelRatio { + if _, ok := newRatio[k]; !ok { + newRatio[k] = v + } + } + jsonBytes, err := json.Marshal(newRatio) + if err != nil { + logger.SysError("error marshalling new ratio: " + err.Error()) + return oldRatio + } + return string(jsonBytes) +} + func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { @@ -189,7 +223,7 @@ func GetCompletionRatio(name string) float64 { return ratio } if strings.HasPrefix(name, "gpt-3.5") { - if strings.HasSuffix(name, "0125") { + if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates // Updated GPT-3.5 Turbo model and lower pricing return 3 @@ -197,16 +231,7 @@ func GetCompletionRatio(name string) float64 { if strings.HasSuffix(name, "1106") { return 2 } - if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { - // TODO: clear this after 2023-12-11 - now := time.Now() - // https://platform.openai.com/docs/models/continuous-model-upgrades - // if after 2023-12-11, use 2 - if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { - return 2 - } - } - return 1.333333 + return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") { if strings.HasSuffix(name, "preview") { @@ -214,14 +239,18 @@ func GetCompletionRatio(name string) float64 { } return 2 } - if strings.HasPrefix(name, "claude-instant-1") { - return 3.38 + if strings.HasPrefix(name, "claude-3") { + return 5 } - if strings.HasPrefix(name, "claude-2") { - return 2.965517 + if strings.HasPrefix(name, "claude-") { + return 3 } if strings.HasPrefix(name, "mistral-") { return 3 } + switch name { + case "llama2-70b-4096": + return 0.8 / 0.7 + } return 1 } diff --git a/common/utils.go b/common/utils.go index 24615225..ecee2c8e 100644 --- a/common/utils.go +++ b/common/utils.go @@ -5,7 +5,7 @@ import ( "github.com/songquanpeng/one-api/common/config" ) -func LogQuota(quota int) string { +func LogQuota(quota int64) string { if config.DisplayInCurrencyEnabled { return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { diff --git a/controller/billing.go b/controller/billing.go index 7317913d..dd518678 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -8,8 +8,8 @@ import ( ) func GetSubscription(c *gin.Context) { - var remainQuota int - var usedQuota int + var remainQuota int64 + var usedQuota int64 var err error var token *model.Token var expiredTime int64 @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { } func GetUsage(c *gin.Context) { - var quota int + var quota int64 var err error var token *model.Token if config.DisplayTokenStatEnabled { diff --git a/controller/channel-billing.go b/controller/channel-billing.go index abeab26a..03c97349 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -295,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) { } func updateAllChannelsBalance() error { - channels, err := model.GetAllChannels(0, 0, true) + channels, err := model.GetAllChannels(0, 0, "all") if err != nil { return err } @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { } else { // err is nil & balance <= 0 means quota is used up if balance <= 0 { - disableChannel(channel.Id, channel.Name, "余额不足") + monitor.DisableChannel(channel.Id, channel.Name, "余额不足") } } time.Sleep(config.RequestInterval) @@ -322,15 +323,14 @@ func updateAllChannelsBalance() error { } func UpdateAllChannelsBalance(c *gin.Context) { - // TODO: make it async - err := updateAllChannelsBalance() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + //err := updateAllChannelsBalance() + //if err != nil { + // c.JSON(http.StatusOK, gin.H{ + // "success": false, + // "message": err.Error(), + // }) + // return + //} c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/channel-test.go b/controller/channel-test.go index 7007e205..67ac91d0 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,8 +8,10 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -28,7 +30,7 @@ import ( func buildTestRequest() *relaymodel.GeneralOpenAIRequest { testRequest := &relaymodel.GeneralOpenAIRequest{ - MaxTokens: 1, + MaxTokens: 2, Stream: false, Model: "gpt-3.5-turbo", } @@ -148,33 +150,7 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -func notifyRootUser(subject string, content string) { - if config.RootUserEmail == "" { - config.RootUserEmail = model.GetRootUserEmail() - } - err := common.SendEmail(subject, config.RootUserEmail, content) - if err != nil { - logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } -} - -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) -} - -// enable & notify -func enableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) -} - -func testAllChannels(notify bool) error { +func testChannels(notify bool, scope string) error { if config.RootUserEmail == "" { config.RootUserEmail = model.GetRootUserEmail() } @@ -185,7 +161,7 @@ func testAllChannels(notify bool) error { } testAllChannelsRunning = true testAllChannelsLock.Unlock() - channels, err := model.GetAllChannels(0, 0, true) + channels, err := model.GetAllChannels(0, 0, scope) if err != nil { return err } @@ -202,13 +178,17 @@ func testAllChannels(notify bool) error { milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - disableChannel(channel.Id, channel.Name, err.Error()) + if config.AutomaticDisableChannelEnabled { + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) + } else { + _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) + } } if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { - disableChannel(channel.Id, channel.Name, err.Error()) + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { - enableChannel(channel.Id, channel.Name) + monitor.EnableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) time.Sleep(config.RequestInterval) @@ -217,7 +197,7 @@ func testAllChannels(notify bool) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := message.Notify(message.ByAll, "通道测试完成", "", "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } @@ -226,8 +206,12 @@ func testAllChannels(notify bool) error { return nil } -func TestAllChannels(c *gin.Context) { - err := testAllChannels(true) +func TestChannels(c *gin.Context) { + scope := c.Query("scope") + if scope == "" { + scope = "all" + } + err := testChannels(true, scope) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -246,7 +230,7 @@ func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) logger.SysLog("testing all channels") - _ = testAllChannels(false) + _ = testChannels(false, "all") logger.SysLog("channel test finished") } } diff --git a/controller/channel.go b/controller/channel.go index bdfa00d9..37bfb99d 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/misc.go b/controller/misc.go index 036bdbd1..f27fdb12 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -110,7 +111,7 @@ func SendEmailVerification(c *gin.Context) { content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ "

您的验证码为: %s

"+ "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, code, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := message.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -149,7 +150,7 @@ func SendPasswordResetEmail(c *gin.Context) { "

点击 此处 进行密码重置。

"+ "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, link, link, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := message.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/model.go b/controller/model.go index 0d0d2658..4c5476b4 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,14 +3,13 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/baichuan" - "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/channel/mistral" - "github.com/songquanpeng/one-api/relay/channel/moonshot" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "net/http" ) // https://platform.openai.com/docs/api-reference/models/list @@ -42,6 +41,7 @@ type OpenAIModels struct { var openAIModels []OpenAIModels var openAIModelsMap map[string]OpenAIModels +var channelId2Models map[int][]string func init() { var permission []OpenAIModelPermission @@ -79,65 +79,44 @@ func init() { }) } } - for _, modelName := range ai360.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "360", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range moonshot.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "moonshot", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range baichuan.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "baichuan", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range minimax.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "minimax", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range mistral.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "mistralai", - Permission: permission, - Root: modelName, - Parent: nil, - }) + for _, channelType := range openai.CompatibleChannels { + if channelType == common.ChannelTypeAzure { + continue + } + channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) + for _, modelName := range channelModelList { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, + Permission: permission, + Root: modelName, + Parent: nil, + }) + } } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model } + channelId2Models = make(map[int][]string) + for i := 1; i < common.ChannelTypeDummy; i++ { + adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) + meta := &util.RelayMeta{ + ChannelType: i, + } + adaptor.Init(meta) + channelId2Models[i] = adaptor.GetModelList() + } +} + +func DashboardListModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channelId2Models, + }) } func ListModels(c *gin.Context) { diff --git a/controller/relay.go b/controller/relay.go index 9b2d462c..b34768df 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" @@ -45,11 +46,12 @@ func Relay(c *gin.Context) { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } + channelId := c.GetInt("channel_id") bizErr := relay(c, relayMode) if bizErr == nil { + monitor.Emit(channelId, true) return } - channelId := c.GetInt("channel_id") lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") @@ -117,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if util.ShouldDisableChannel(&err.Error, err.StatusCode) { - disableChannel(channelId, channelName, err.Message) + monitor.DisableChannel(channelId, channelName, err.Message) + } else { + monitor.Emit(channelId, false) } } diff --git a/docker-compose.yml b/docker-compose.yml index fd121ae3..df427f60 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.4' services: one-api: - image: justsong/one-api:latest + image: "${REGISTRY:-docker.io}/justsong/one-api:latest" container_name: one-api restart: always command: --log-dir /app/logs @@ -30,12 +30,12 @@ services: retries: 3 redis: - image: redis:latest + image: "${REGISTRY:-docker.io}/redis:latest" container_name: redis restart: always db: - image: mysql:8.2.0 + image: "${REGISTRY:-docker.io}/mysql:8.2.0" restart: always container_name: mysql volumes: diff --git a/go.mod b/go.mod index 4ab23003..f9ed96d3 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,8 @@ require ( github.com/gorilla/sessions v1.2.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.3.1 // indirect + github.com/jackc/pgx/v5 v5.5.4 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -58,8 +59,9 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 21bcddc6..9cf056e5 100644 --- a/go.sum +++ b/go.sum @@ -73,8 +73,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= -github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= +github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -157,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -177,8 +181,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/main.go b/main.go index 1f43a45f..b20c6daf 100644 --- a/main.go +++ b/main.go @@ -30,11 +30,25 @@ func main() { if config.DebugEnabled { logger.SysLog("running in debug mode") } + var err error // Initialize SQL Database - err := model.InitDB() + model.DB, err = model.InitDB("SQL_DSN") if err != nil { logger.FatalLog("failed to initialize database: " + err.Error()) } + if os.Getenv("LOG_SQL_DSN") != "" { + logger.SysLog("using secondary database for table logs") + model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + } + } else { + model.LOG_DB = model.DB + } + err = model.CreateRootAccountIfNeed() + if err != nil { + logger.FatalLog("database init error: " + err.Error()) + } defer func() { err := model.CloseDB() if err != nil { @@ -64,13 +78,6 @@ func main() { go model.SyncOptions(config.SyncFrequency) go model.SyncChannelCache(config.SyncFrequency) } - if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) - if err != nil { - logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyUpdateChannels(frequency) - } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { @@ -83,6 +90,9 @@ func main() { logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") model.InitBatchUpdater() } + if config.EnableMetric { + logger.SysLog("metric enabled, will disable channel if too much request failed") + } openai.InitTokenEncoders() // Initialize HTTP server diff --git a/middleware/auth.go b/middleware/auth.go index 9d25f395..30997efd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,6 +4,7 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled { + if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", }) + session := sessions.Default(c) + session.Clear() + _ = session.Save() c.Abort() return } @@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } - if !userEnabled { + if !userEnabled || blacklist.IsUserBanned(token.UserId) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } diff --git a/middleware/recover.go b/middleware/recover.go index 02e3e3bb..cfc3f827 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -3,6 +3,7 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "net/http" "runtime/debug" @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) - logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + ctx := c.Request.Context() + logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) + logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) + body, _ := common.GetRequestBody(c) + logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ - "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), + "message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), "type": "one_api_panic", }, }) diff --git a/middleware/request-id.go b/middleware/request-id.go index 234a93d8..a4c49ddb 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -9,7 +9,7 @@ import ( func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := helper.GetTimeString() + helper.GetRandomNumberString(8) + id := helper.GenRequestID() c.Set(logger.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) diff --git a/model/cache.go b/model/cache.go index 3c3575b8..dd20d857 100644 --- a/model/cache.go +++ b/model/cache.go @@ -1,6 +1,7 @@ package model import ( + "context" "encoding/json" "errors" "fmt" @@ -70,31 +71,42 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } -func CacheGetUserQuota(id int) (quota int, err error) { +func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + if err != nil { + logger.Error(ctx, "Redis set user quota error: "+err.Error()) + } + return +} + +func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { if !common.RedisEnabled { return GetUserQuota(id) } quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) if err != nil { - quota, err = GetUserQuota(id) - if err != nil { - return 0, err - } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - if err != nil { - logger.SysError("Redis set user quota error: " + err.Error()) - } - return quota, err + return fetchAndUpdateUserQuota(ctx, id) } - quota, err = strconv.Atoi(quotaString) - return quota, err + quota, err = strconv.ParseInt(quotaString, 10, 64) + if err != nil { + return 0, nil + } + if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db + logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) + return fetchAndUpdateUserQuota(ctx, id) + } + return quota, nil } -func CacheUpdateUserQuota(id int) error { +func CacheUpdateUserQuota(ctx context.Context, id int) error { if !common.RedisEnabled { return nil } - quota, err := CacheGetUserQuota(id) + quota, err := CacheGetUserQuota(ctx, id) if err != nil { return err } @@ -102,7 +114,7 @@ func CacheUpdateUserQuota(id int) error { return err } -func CacheDecreaseUserQuota(id int, quota int) error { +func CacheDecreaseUserQuota(id int, quota int64) error { if !common.RedisEnabled { return nil } diff --git a/model/channel.go b/model/channel.go index 19af2263..fc4905b1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -13,7 +13,7 @@ import ( type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` - Key string `json:"key" gorm:"not null;index"` + Key string `json:"key" gorm:"type:text"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` @@ -32,23 +32,22 @@ type Channel struct { Config string `json:"config"` } -func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { +func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { var channels []*Channel var err error - if selectAll { + switch scope { + case "all": err = DB.Order("id desc").Find(&channels).Error - } else { + case "disabled": + err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error + default: err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } return channels, err } func SearchChannels(keyword string) (channels []*Channel, err error) { - keyCol := "`key`" - if common.UsingPostgreSQL { - keyCol = `"key"` - } - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error return channels, err } @@ -179,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) { } } -func UpdateChannelUsedQuota(id int, quota int) { +func UpdateChannelUsedQuota(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return @@ -187,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) { updateChannelUsedQuota(id, quota) } -func updateChannelUsedQuota(id int, quota int) { +func updateChannelUsedQuota(id int, quota int64) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { logger.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/log.go b/model/log.go index 9615c237..4409f73e 100644 --- a/model/log.go +++ b/model/log.go @@ -45,13 +45,13 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.SysError("failed to record log: " + err.Error()) } } -func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { return @@ -66,10 +66,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke CompletionTokens: completionTokens, TokenName: tokenName, ModelName: modelName, - Quota: quota, + Quota: int(quota), ChannelId: channelId, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.Error(ctx, "failed to record log: "+err.Error()) } @@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB + tx = LOG_DB } else { - tx = DB.Where("type = ?", logType) + tx = LOG_DB.Where("type = ?", logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB.Where("user_id = ?", userId) + tx = LOG_DB.Where("user_id = ?", userId) } else { - tx = DB.Where("user_id = ? and type = ?", userId, logType) + tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error + err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error + err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { - tx := DB.Table("logs").Select("ifnull(sum(quota),0)") +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { + tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa } func DeleteOldLog(targetTimestamp int64) (int64, error) { - result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error } @@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" } - err = DB.Raw(` + err = LOG_DB.Raw(` SELECT `+groupSelect+`, model_name, count(1) as request_count, sum(quota) as quota, diff --git a/model/main.go b/model/main.go index 18ed01d0..ca7a35b2 100644 --- a/model/main.go +++ b/model/main.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/driver/mysql" @@ -16,8 +17,9 @@ import ( ) var DB *gorm.DB +var LOG_DB *gorm.DB -func createRootAccountIfNeed() error { +func CreateRootAccountIfNeed() error { var user User //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { @@ -40,9 +42,9 @@ func createRootAccountIfNeed() error { return nil } -func chooseDB() (*gorm.DB, error) { - if os.Getenv("SQL_DSN") != "" { - dsn := os.Getenv("SQL_DSN") +func chooseDB(envName string) (*gorm.DB, error) { + if os.Getenv(envName) != "" { + dsn := os.Getenv(envName) if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL logger.SysLog("using PostgreSQL as database") @@ -56,6 +58,7 @@ func chooseDB() (*gorm.DB, error) { } // Use MySQL logger.SysLog("using MySQL as database") + common.UsingMySQL = true return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) @@ -69,67 +72,78 @@ func chooseDB() (*gorm.DB, error) { }) } -func InitDB() (err error) { - db, err := chooseDB() +func InitDB(envName string) (db *gorm.DB, err error) { + db, err = chooseDB(envName) if err == nil { - if config.DebugEnabled { + if config.DebugSQLEnabled { db = db.Debug() } - DB = db - sqlDB, err := DB.DB() + sqlDB, err := db.DB() if err != nil { - return err + return nil, err } - sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) if !config.IsMasterNode { - return nil + return db, err + } + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded } logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Token{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&User{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Option{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Redemption{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Ability{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Log{}) if err != nil { - return err + return nil, err } logger.SysLog("database migrated") - err = createRootAccountIfNeed() - return err + return db, err } else { logger.FatalLog(err) } - return err + return db, err } -func CloseDB() error { - sqlDB, err := DB.DB() +func closeDB(db *gorm.DB) error { + sqlDB, err := db.DB() if err != nil { return err } err = sqlDB.Close() return err } + +func CloseDB() error { + if LOG_DB != DB { + err := closeDB(LOG_DB) + if err != nil { + return err + } + } + return closeDB(DB) +} diff --git a/model/option.go b/model/option.go index 6002c795..1d1c28b4 100644 --- a/model/option.go +++ b/model/option.go @@ -57,13 +57,15 @@ func InitOptionMap() { config.OptionMap["WeChatServerAddress"] = "" config.OptionMap["WeChatServerToken"] = "" config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["MessagePusherAddress"] = "" + config.OptionMap["MessagePusherToken"] = "" config.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSecretKey"] = "" - config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) - config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) - config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) - config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) - config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) + config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) + config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) + config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) + config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() @@ -79,6 +81,9 @@ func InitOptionMap() { func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { + if option.Key == "ModelRatio" { + option.Value = common.AddNewMissingRatio(option.Value) + } err := updateOptionMap(option.Key, option.Value) if err != nil { logger.SysError("failed to update option map: " + err.Error()) @@ -179,20 +184,24 @@ func updateOptionMap(key string, value string) (err error) { config.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": config.WeChatAccountQRCodeImageURL = value + case "MessagePusherAddress": + config.MessagePusherAddress = value + case "MessagePusherToken": + config.MessagePusherToken = value case "TurnstileSiteKey": config.TurnstileSiteKey = value case "TurnstileSecretKey": config.TurnstileSecretKey = value case "QuotaForNewUser": - config.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInviter": - config.QuotaForInviter, _ = strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInvitee": - config.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) case "QuotaRemindThreshold": - config.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) case "PreConsumedQuota": - config.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": diff --git a/model/redemption.go b/model/redemption.go index 2c5a4141..e0ae68e2 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -14,7 +14,7 @@ type Redemption struct { Key string `json:"key" gorm:"type:char(32);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` - Quota int `json:"quota" gorm:"default:100"` + Quota int64 `json:"quota" gorm:"default:100"` CreatedTime int64 `json:"created_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` Count int `json:"count" gorm:"-:all"` // only for api request @@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) { return &redemption, err } -func Redeem(key string, userId int) (quota int, err error) { +func Redeem(key string, userId int) (quota int64, err error) { if key == "" { return 0, errors.New("未提供兑换码") } diff --git a/model/token.go b/model/token.go index d0a0648a..40d0eb8f 100644 --- a/model/token.go +++ b/model/token.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" "gorm.io/gorm" ) @@ -19,9 +20,9 @@ type Token struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` AccessedTime int64 `json:"accessed_time" gorm:"bigint"` ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int `json:"remain_quota" gorm:"default:0"` + RemainQuota int64 `json:"remain_quota" gorm:"default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota } func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { @@ -137,7 +138,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, quota int) (err error) { +func IncreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -148,7 +149,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { return increaseTokenQuota(id, quota) } -func increaseTokenQuota(id int, quota int) (err error) { +func increaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -159,7 +160,7 @@ func increaseTokenQuota(id int, quota int) (err error) { return err } -func DecreaseTokenQuota(id int, quota int) (err error) { +func DecreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -170,7 +171,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { return decreaseTokenQuota(id, quota) } -func decreaseTokenQuota(id int, quota int) (err error) { +func decreaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), @@ -181,7 +182,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { return err } -func PreConsumeTokenQuota(tokenId int, quota int) (err error) { +func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -213,7 +214,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { } if email != "" { topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) - err = common.SendEmail(prompt, email, + err = message.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { logger.SysError("failed to send email" + err.Error()) @@ -231,7 +232,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { return err } -func PostConsumeTokenQuota(tokenId int, quota int) (err error) { +func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) diff --git a/model/user.go b/model/user.go index 6979c70b..e325394b 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -25,8 +26,8 @@ type User struct { WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management - Quota int `json:"quota" gorm:"type:int;default:0"` - UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota + Quota int64 `json:"quota" gorm:"type:int;default:0"` + UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number Group string `json:"group" gorm:"type:varchar(32);default:'default'"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` @@ -40,7 +41,7 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int) (users []*User, err error) { - err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error return users, err } @@ -123,6 +124,11 @@ func (user *User) Update(updatePassword bool) error { return err } } + if user.Status == common.UserStatusDisabled { + blacklist.BanUser(user.Id) + } else if user.Status == common.UserStatusEnabled { + blacklist.UnbanUser(user.Id) + } err = DB.Model(user).Updates(user).Error return err } @@ -131,7 +137,10 @@ func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") } - err := DB.Delete(user).Error + blacklist.BanUser(user.Id) + user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) + user.Status = common.UserStatusDeleted + err := DB.Model(user).Updates(user).Error return err } @@ -265,12 +274,12 @@ func ValidateAccessToken(token string) (user *User) { return nil } -func GetUserQuota(id int) (quota int, err error) { +func GetUserQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error return quota, err } -func GetUserUsedQuota(id int) (quota int, err error) { +func GetUserUsedQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error return quota, err } @@ -290,7 +299,7 @@ func GetUserGroup(id int) (group string, err error) { return group, err } -func IncreaseUserQuota(id int, quota int) (err error) { +func IncreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -301,12 +310,12 @@ func IncreaseUserQuota(id int, quota int) (err error) { return increaseUserQuota(id, quota) } -func increaseUserQuota(id int, quota int) (err error) { +func increaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } -func DecreaseUserQuota(id int, quota int) (err error) { +func DecreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -317,7 +326,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { return decreaseUserQuota(id, quota) } -func decreaseUserQuota(id int, quota int) (err error) { +func decreaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -327,7 +336,7 @@ func GetRootUserEmail() (email string) { return email } -func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { +func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) @@ -336,7 +345,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { updateUserUsedQuotaAndRequestCount(id, quota, 1) } -func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { +func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), @@ -348,7 +357,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { } } -func updateUserUsedQuota(id int, quota int) { +func updateUserUsedQuota(id int, quota int64) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), diff --git a/model/utils.go b/model/utils.go index d481973a..a55eb4b6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -16,12 +16,12 @@ const ( BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock ) -var batchUpdateStores []map[int]int +var batchUpdateStores []map[int]int64 var batchUpdateLocks []sync.Mutex func init() { for i := 0; i < BatchUpdateTypeCount; i++ { - batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) } } @@ -35,7 +35,7 @@ func InitBatchUpdater() { }() } -func addNewRecord(type_ int, id int, value int) { +func addNewRecord(type_ int, id int, value int64) { batchUpdateLocks[type_].Lock() defer batchUpdateLocks[type_].Unlock() if _, ok := batchUpdateStores[type_][id]; !ok { @@ -50,7 +50,7 @@ func batchUpdate() { for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] - batchUpdateStores[i] = make(map[int]int) + batchUpdateStores[i] = make(map[int]int64) batchUpdateLocks[i].Unlock() // TODO: maybe we can combine updates with same key? for key, value := range store { @@ -68,7 +68,7 @@ func batchUpdate() { case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) case BatchUpdateTypeRequestCount: - updateUserRequestCount(key, value) + updateUserRequestCount(key, int(value)) case BatchUpdateTypeChannelUsedQuota: updateChannelUsedQuota(key, value) } diff --git a/monitor/channel.go b/monitor/channel.go new file mode 100644 index 00000000..597ab11a --- /dev/null +++ b/monitor/channel.go @@ -0,0 +1,55 @@ +package monitor + +import ( + "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/model" +) + +func notifyRootUser(subject string, content string) { + if config.MessagePusherAddress != "" { + err := message.SendMessage(subject, content, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) + } else { + return + } + } + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() + } + err := message.SendEmail(subject, config.RootUserEmail, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +} + +// DisableChannel disable & notify +func DisableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +func MetricDisableChannel(channelId int, successRate float64) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) + subject := fmt.Sprintf("通道 #%d 已被禁用", channelId) + content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", + config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) + notifyRootUser(subject, content) +} + +// EnableChannel enable & notify +func EnableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + notifyRootUser(subject, content) +} diff --git a/monitor/metric.go b/monitor/metric.go new file mode 100644 index 00000000..98bc546e --- /dev/null +++ b/monitor/metric.go @@ -0,0 +1,79 @@ +package monitor + +import ( + "github.com/songquanpeng/one-api/common/config" +) + +var store = make(map[int][]bool) +var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) +var metricFailChan = make(chan int, config.MetricFailChanSize) + +func consumeSuccess(channelId int) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], true) +} + +func consumeFail(channelId int) (bool, float64) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], false) + successCount := 0 + for _, success := range store[channelId] { + if success { + successCount++ + } + } + successRate := float64(successCount) / float64(len(store[channelId])) + if len(store[channelId]) < config.MetricQueueSize { + return false, successRate + } + if successRate < config.MetricSuccessRateThreshold { + store[channelId] = make([]bool, 0) + return true, successRate + } + return false, successRate +} + +func metricSuccessConsumer() { + for { + select { + case channelId := <-metricSuccessChan: + consumeSuccess(channelId) + } + } +} + +func metricFailConsumer() { + for { + select { + case channelId := <-metricFailChan: + disable, successRate := consumeFail(channelId) + if disable { + go MetricDisableChannel(channelId, successRate) + } + } + } +} + +func init() { + if config.EnableMetric { + go metricSuccessConsumer() + go metricFailConsumer() + } +} + +func Emit(channelId int, success bool) { + if !config.EnableMetric { + return + } + go func() { + if success { + metricSuccessChan <- channelId + } else { + metricFailChan <- channelId + } + }() +} diff --git a/pull_request_template.md b/pull_request_template.md index a313004f..c6301343 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,9 +1,10 @@ [//]: # (请按照以下格式关联 issue) -[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) +[//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢) [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) [//]: # (开发者交流群:910657413) [//]: # (请在提交 PR 之前删除上面的注释) close #issue_number -我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file +我已确认该 PR 已自测通过,相关截图如下: +(此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 6c6f433e..6a3245ad 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -32,6 +32,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { channel.SetupCommonRequestHeader(c, req, meta) + if meta.IsStream { + req.Header.Set("Accept", "text/event-stream") + } req.Header.Set("Authorization", "Bearer "+meta.APIKey) if meta.IsStream { req.Header.Set("X-DashScope-SSE", "enable") diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go index 4b873715..a165b35c 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/channel/anthropic/adaptor.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -20,7 +19,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { } func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil + return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { @@ -31,6 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut anthropicVersion = "2023-06-01" } req.Header.Set("anthropic-version", anthropicVersion) + req.Header.Set("anthropic-beta", "messages-2023-12-15") return nil } @@ -47,9 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { - var responseText string - err, responseText = StreamHandler(c, resp) - usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + err, usage = StreamHandler(c, resp) } else { err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } @@ -61,5 +59,5 @@ func (a *Adaptor) GetModelList() []string { } func (a *Adaptor) GetChannelName() string { - return "authropic" + return "anthropic" } diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go index b98c15c2..cadcedc8 100644 --- a/relay/channel/anthropic/constants.go +++ b/relay/channel/anthropic/constants.go @@ -1,5 +1,8 @@ package anthropic var ModelList = []string{ - "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", + "claude-instant-1.2", "claude-2.0", "claude-2.1", + "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", } diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go index e2c575fa..3eeb0b2c 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/channel/anthropic/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/model" @@ -15,73 +16,135 @@ import ( "strings" ) -func stopReasonClaude2OpenAI(reason string) string { - switch reason { +func stopReasonClaude2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "end_turn": + return "stop" case "stop_sequence": return "stop" case "max_tokens": return "length" default: - return reason + return *reason } } func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { claudeRequest := Request{ - Model: textRequest.Model, - Prompt: "", - MaxTokensToSample: textRequest.MaxTokens, - StopSequences: nil, - Temperature: textRequest.Temperature, - TopP: textRequest.TopP, - Stream: textRequest.Stream, + Model: textRequest.Model, + MaxTokens: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + Stream: textRequest.Stream, } - if claudeRequest.MaxTokensToSample == 0 { - claudeRequest.MaxTokensToSample = 1000000 + if claudeRequest.MaxTokens == 0 { + claudeRequest.MaxTokens = 4096 + } + // legacy model name mapping + if claudeRequest.Model == "claude-instant-1" { + claudeRequest.Model = "claude-instant-1.1" + } else if claudeRequest.Model == "claude-2" { + claudeRequest.Model = "claude-2.1" } - prompt := "" for _, message := range textRequest.Messages { - if message.Role == "user" { - prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) - } else if message.Role == "assistant" { - prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) - } else if message.Role == "system" { - if prompt == "" { - prompt = message.StringContent() - } + if message.Role == "system" && claudeRequest.System == "" { + claudeRequest.System = message.StringContent() + continue } + claudeMessage := Message{ + Role: message.Role, + } + var content Content + if message.IsStringContent() { + content.Type = "text" + content.Text = message.StringContent() + claudeMessage.Content = append(claudeMessage.Content, content) + claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) + continue + } + var contents []Content + openaiContent := message.ParseContent() + for _, part := range openaiContent { + var content Content + if part.Type == model.ContentTypeText { + content.Type = "text" + content.Text = part.Text + } else if part.Type == model.ContentTypeImageURL { + content.Type = "image" + content.Source = &ImageSource{ + Type: "base64", + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + content.Source.MediaType = mimeType + content.Source.Data = data + } + contents = append(contents, content) + } + claudeMessage.Content = contents + claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) } - prompt += "\n\nAssistant:" - claudeRequest.Prompt = prompt return &claudeRequest } -func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { +// https://docs.anthropic.com/claude/reference/messages-streaming +func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var stopReason string + switch claudeResponse.Type { + case "message_start": + return nil, claudeResponse.Message + case "content_block_start": + if claudeResponse.ContentBlock != nil { + responseText = claudeResponse.ContentBlock.Text + } + case "content_block_delta": + if claudeResponse.Delta != nil { + responseText = claudeResponse.Delta.Text + } + case "message_delta": + if claudeResponse.Usage != nil { + response = &Response{ + Usage: *claudeResponse.Usage, + } + } + if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { + stopReason = *claudeResponse.Delta.StopReason + } + } var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = claudeResponse.Completion - finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) + choice.Delta.Content = responseText + choice.Delta.Role = "assistant" + finishReason := stopReasonClaude2OpenAI(&stopReason) if finishReason != "null" { choice.FinishReason = &finishReason } - var response openai.ChatCompletionsStreamResponse - response.Object = "chat.completion.chunk" - response.Model = claudeResponse.Model - response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} - return &response + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response } func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ Role: "assistant", - Content: strings.TrimPrefix(claudeResponse.Completion, " "), + Content: responseText, Name: nil, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), + Model: claudeResponse.Model, Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -89,17 +152,15 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { return &fullTextResponse } -func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { - responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } - if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { - return i + 4, data[0:i], nil + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil } if atEOF { return len(data), data, nil @@ -111,29 +172,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC go func() { for scanner.Scan() { data := scanner.Text() - if !strings.HasPrefix(data, "event: completion") { + if len(data) < 6 { continue } - data = strings.TrimPrefix(data, "event: completion\r\ndata: ") + if !strings.HasPrefix(data, "data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") dataChan <- data } stopChan <- true }() common.SetEventStreamHeaders(c) + var usage model.Usage + var modelName string + var id string c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: // some implementations may add \r at the end of data data = strings.TrimSuffix(data, "\r") - var claudeResponse Response + var claudeResponse StreamResponse err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { logger.SysError("error unmarshalling stream response: " + err.Error()) return true } - responseText += claudeResponse.Completion - response := streamResponseClaude2OpenAI(&claudeResponse) - response.Id = responseId + response, meta := streamResponseClaude2OpenAI(&claudeResponse) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } + if response == nil { + return true + } + response.Id = id + response.Model = modelName response.Created = createdTime jsonStr, err := json.Marshal(response) if err != nil { @@ -147,11 +224,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return false } }) - err := resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } - return nil, responseText + _ = resp.Body.Close() + return nil, &usage } func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { @@ -181,11 +255,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st } fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse.Model = modelName - completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName) usage := model.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, } fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) diff --git a/relay/channel/anthropic/model.go b/relay/channel/anthropic/model.go index 70fc9430..32b187cd 100644 --- a/relay/channel/anthropic/model.go +++ b/relay/channel/anthropic/model.go @@ -1,19 +1,44 @@ package anthropic +// https://docs.anthropic.com/claude/reference/messages_post + type Metadata struct { UserId string `json:"user_id"` } +type ImageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type Content struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ImageSource `json:"source,omitempty"` +} + +type Message struct { + Role string `json:"role"` + Content []Content `json:"content"` +} + type Request struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokensToSample int `json:"max_tokens_to_sample"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` + Model string `json:"model"` + Messages []Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` //Metadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` } type Error struct { @@ -22,8 +47,29 @@ type Error struct { } type Response struct { - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error Error `json:"error"` + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []Content `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage Usage `json:"usage"` + Error Error `json:"error"` +} + +type Delta struct { + Type string `json:"type"` + Text string `json:"text"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` +} + +type StreamResponse struct { + Type string `json:"type"` + Message *Response `json:"message"` + Index int `json:"index"` + ContentBlock *Content `json:"content_block"` + Delta *Delta `json:"delta"` + Usage *Usage `json:"usage"` } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 066a8107..2d2e24f6 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -2,13 +2,16 @@ package baidu import ( "errors" + "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" ) type Adaptor struct { @@ -20,25 +23,45 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t - var fullRequestURL string - switch meta.ActualModelName { - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "ERNIE-Bot-8K": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Speed": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - default: - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName + suffix := "chat/" + if strings.HasPrefix(meta.ActualModelName, "Embedding") { + suffix = "embeddings/" } + if strings.HasPrefix(meta.ActualModelName, "bge-large") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "tao-8k") { + suffix = "embeddings/" + } + switch meta.ActualModelName { + case "ERNIE-4.0": + suffix += "completions_pro" + case "ERNIE-Bot-4": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += "completions" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-Bot": + suffix += "completions" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-Bot-turbo": + suffix += "eb-instant" + case "BLOOMZ-7B": + suffix += "bloomz_7b1" + case "Embedding-V1": + suffix += "embedding-v1" + case "bge-large-zh": + suffix += "bge_large_zh" + case "bge-large-en": + suffix += "bge_large_en" + case "tao-8k": + suffix += "tao_8k" + default: + suffix += meta.ActualModelName + } + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) var accessToken string var err error if accessToken, err = GetAccessToken(meta.APIKey); err != nil { diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go index 0fa8f2d6..45a4e901 100644 --- a/relay/channel/baidu/constants.go +++ b/relay/channel/baidu/constants.go @@ -7,4 +7,7 @@ var ModelList = []string{ "ERNIE-Speed", "ERNIE-Bot-turbo", "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", } diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go index 4f2b13fc..9ca9e47d 100644 --- a/relay/channel/baidu/main.go +++ b/relay/channel/baidu/main.go @@ -32,9 +32,16 @@ type Message struct { } type ChatRequest struct { - Messages []Message `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` + Messages []Message `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PenaltyScore float64 `json:"penalty_score,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + DisableSearch bool `json:"disable_search,omitempty"` + EnableCitation bool `json:"enable_citation,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + UserId string `json:"user_id,omitempty"` } type Error struct { @@ -45,28 +52,28 @@ type Error struct { var baiduTokenStore sync.Map func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { - messages := make([]Message, 0, len(request.Messages)) + baiduRequest := ChatRequest{ + Messages: make([]Message, 0, len(request.Messages)), + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + DisableSearch: false, + EnableCitation: false, + MaxOutputTokens: request.MaxTokens, + UserId: request.User, + } for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) + baiduRequest.System = message.StringContent() } else { - messages = append(messages, Message{ + baiduRequest.Messages = append(baiduRequest.Messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &ChatRequest{ - Messages: messages, - Stream: request.Stream, - } + return &baiduRequest } func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { diff --git a/relay/channel/groq/constants.go b/relay/channel/groq/constants.go new file mode 100644 index 00000000..fc9a9ebd --- /dev/null +++ b/relay/channel/groq/constants.go @@ -0,0 +1,10 @@ +package groq + +// https://console.groq.com/docs/models + +var ModelList = []string{ + "gemma-7b-it", + "llama2-7b-2048", + "llama2-70b-4096", + "mixtral-8x7b-32768", +} diff --git a/relay/channel/lingyiwanwu/constants.go b/relay/channel/lingyiwanwu/constants.go new file mode 100644 index 00000000..30000e9d --- /dev/null +++ b/relay/channel/lingyiwanwu/constants.go @@ -0,0 +1,9 @@ +package lingyiwanwu + +// https://platform.lingyiwanwu.com/docs + +var ModelList = []string{ + "yi-34b-chat-0205", + "yi-34b-chat-200k", + "yi-vl-plus", +} diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go new file mode 100644 index 00000000..06c66101 --- /dev/null +++ b/relay/channel/ollama/adaptor.go @@ -0,0 +1,65 @@ +package ollama + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *util.RelayMeta) { + +} + +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + // https://github.com/ollama/ollama/blob/main/docs/api.md + fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + return nil, errors.New("not supported") + default: + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ollama" +} diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go new file mode 100644 index 00000000..32f82b2a --- /dev/null +++ b/relay/channel/ollama/constants.go @@ -0,0 +1,5 @@ +package ollama + +var ModelList = []string{ + "qwen:0.5b-chat", +} diff --git a/relay/channel/ollama/main.go b/relay/channel/ollama/main.go new file mode 100644 index 00000000..7ec646a3 --- /dev/null +++ b/relay/channel/ollama/main.go @@ -0,0 +1,178 @@ +package ollama + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + ollamaRequest := ChatRequest{ + Model: request.Model, + Options: &Options{ + Seed: int(request.Seed), + Temperature: request.Temperature, + TopP: request.TopP, + FrequencyPenalty: request.FrequencyPenalty, + PresencePenalty: request.PresencePenalty, + }, + Stream: request.Stream, + } + for _, message := range request.Messages { + ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) + } + return &ollamaRequest +} + +func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: response.Message.Role, + Content: response.Message.Content, + }, + } + if response.Done { + choice.FinishReason = "stop" + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + Usage: model.Usage{ + PromptTokens: response.PromptEvalCount, + CompletionTokens: response.EvalCount, + TotalTokens: response.PromptEvalCount + response.EvalCount, + }, + } + return &fullTextResponse +} + +func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Role = ollamaResponse.Message.Role + choice.Delta.Content = ollamaResponse.Message.Content + if ollamaResponse.Done { + choice.FinishReason = &constant.StopFinishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: ollamaResponse.Model, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "}\n"); i >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + dataChan <- data + "}" + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + response := streamResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := context.TODO() + var ollamaResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + logger.Debugf(ctx, "ollama response: %s", string(responseBody)) + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go new file mode 100644 index 00000000..a8ef1ffc --- /dev/null +++ b/relay/channel/ollama/model.go @@ -0,0 +1,37 @@ +package ollama + +type Options struct { + Seed int `json:"seed,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` +} + +type Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Stream bool `json:"stream"` + Options *Options `json:"options,omitempty"` +} + +type ChatResponse struct { + Model string `json:"model,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + Message Message `json:"message,omitempty"` + Response string `json:"response,omitempty"` // for stream response + Done bool `json:"done,omitempty"` + TotalDuration int `json:"total_duration,omitempty"` + LoadDuration int `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` + Error string `json:"error,omitempty"` +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 5a04a768..47594030 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -6,11 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/channel/mistral" - "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -86,37 +82,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel } func (a *Adaptor) GetModelList() []string { - switch a.ChannelType { - case common.ChannelType360: - return ai360.ModelList - case common.ChannelTypeMoonshot: - return moonshot.ModelList - case common.ChannelTypeBaichuan: - return baichuan.ModelList - case common.ChannelTypeMinimax: - return minimax.ModelList - case common.ChannelTypeMistral: - return mistral.ModelList - default: - return ModelList - } + _, modelList := GetCompatibleChannelMeta(a.ChannelType) + return modelList } func (a *Adaptor) GetChannelName() string { - switch a.ChannelType { - case common.ChannelTypeAzure: - return "azure" - case common.ChannelType360: - return "360" - case common.ChannelTypeMoonshot: - return "moonshot" - case common.ChannelTypeBaichuan: - return "baichuan" - case common.ChannelTypeMinimax: - return "minimax" - case common.ChannelTypeMistral: - return "mistralai" - default: - return "openai" - } + channelName, _ := GetCompatibleChannelMeta(a.ChannelType) + return channelName } diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go new file mode 100644 index 00000000..e4951a34 --- /dev/null +++ b/relay/channel/openai/compatible.go @@ -0,0 +1,46 @@ +package openai + +import ( + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/relay/channel/ai360" + "github.com/songquanpeng/one-api/relay/channel/baichuan" + "github.com/songquanpeng/one-api/relay/channel/groq" + "github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" + "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/channel/mistral" + "github.com/songquanpeng/one-api/relay/channel/moonshot" +) + +var CompatibleChannels = []int{ + common.ChannelTypeAzure, + common.ChannelType360, + common.ChannelTypeMoonshot, + common.ChannelTypeBaichuan, + common.ChannelTypeMinimax, + common.ChannelTypeMistral, + common.ChannelTypeGroq, + common.ChannelTypeLingYiWanWu, +} + +func GetCompatibleChannelMeta(channelType int) (string, []string) { + switch channelType { + case common.ChannelTypeAzure: + return "azure", ModelList + case common.ChannelType360: + return "360", ai360.ModelList + case common.ChannelTypeMoonshot: + return "moonshot", moonshot.ModelList + case common.ChannelTypeBaichuan: + return "baichuan", baichuan.ModelList + case common.ChannelTypeMinimax: + return "minimax", minimax.ModelList + case common.ChannelTypeMistral: + return "mistralai", mistral.ModelList + case common.ChannelTypeGroq: + return "groq", groq.ModelList + case common.ChannelTypeLingYiWanWu: + return "lingyiwanwu", lingyiwanwu.ModelList + default: + return "openai", ModelList + } +} diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index fa26651b..cfdc0bfd 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -28,17 +28,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) - continue - } messages = append(messages, Message{ Content: message.StringContent(), Role: message.Role, diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 620e808f..f89aea2b 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -27,21 +27,10 @@ import ( func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) - } else { - messages = append(messages, Message{ - Role: message.Role, - Content: message.StringContent(), - }) - } + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) } xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 90cc79d3..0ca23d59 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" + "math" "net/http" "strings" ) @@ -52,9 +53,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - if request.TopP >= 1 { - request.TopP = 0.99 - } + // TopP (0.0, 1.0) + request.TopP = math.Min(0.99, request.TopP) + request.TopP = math.Max(0.01, request.TopP) + + // Temperature (0.0, 1.0) + request.Temperature = math.Min(0.99, request.Temperature) + request.Temperature = math.Max(0.01, request.Temperature) a.SetVersionByModeName(request.Model) if a.APIVersion == "v4" { return request, nil diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index 7c3e83f3..a46fd537 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -76,21 +76,10 @@ func GetToken(apikey string) string { func ConvertRequest(request model.GeneralOpenAIRequest) *Request { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, Message{ - Role: "system", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "user", - Content: "Okay", - }) - } else { - messages = append(messages, Message{ - Role: message.Role, - Content: message.StringContent(), - }) - } + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) } return &Request{ Prompt: messages, diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index d2184dac..b249f6a2 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -15,6 +15,7 @@ const ( APITypeAIProxyLibrary APITypeTencent APITypeGemini + APITypeOllama APITypeDummy // this one is only for count, do not add any channel after this ) @@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeTencent case common.ChannelTypeGemini: apiType = APITypeGemini + case common.ChannelTypeOllama: + apiType = APITypeOllama } return apiType } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index ee8771c9..155954d2 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -22,6 +22,7 @@ import ( ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() audioModel := "whisper-1" tokenId := c.GetInt("token_id") @@ -49,16 +50,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus modelRatio := common.GetModelRatio(audioModel) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio - var quota int - var preConsumedQuota int + var quota int64 + var preConsumedQuota int64 switch relayMode { case constant.RelayModeAudioSpeech: - preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: - preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) + preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) } - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -183,7 +184,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - quota = openai.CountTokenText(text, audioModel) + quota = int64(openai.CountTokenText(text, audioModel)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 89fc69ce..600a8d65 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -107,18 +107,18 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int return 0 } -func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { +func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { preConsumedTokens := config.PreConsumedQuota if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens + preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) } - return int(float64(preConsumedTokens) * ratio) + return int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) if err != nil { return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } - quota := 0 + var quota int64 completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } @@ -168,7 +168,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R if err != nil { logger.Error(ctx, "error consuming token remain quota: "+err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } diff --git a/relay/controller/image.go b/relay/controller/image.go index 3ce3809b..20ea0a4c 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -79,9 +79,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus modelRatio := common.GetModelRatio(imageRequest.Model) groupRatio := common.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int(ratio*imageCostRatio*1000) * imageRequest.N + quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -125,7 +125,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } diff --git a/relay/controller/text.go b/relay/controller/text.go index 59c5f637..ba008713 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } + logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) requestBody = bytes.NewBuffer(jsonData) } @@ -83,11 +84,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { + errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") + if errorHappened { util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return util.RelayErrorHandler(resp) } + meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") // do response usage, respErr := adaptor.DoResponse(c, resp, meta) diff --git a/relay/helper/main.go b/relay/helper/main.go index c2b6e6af..e7342329 100644 --- a/relay/helper/main.go +++ b/relay/helper/main.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/anthropic" "github.com/songquanpeng/one-api/relay/channel/baidu" "github.com/songquanpeng/one-api/relay/channel/gemini" + "github.com/songquanpeng/one-api/relay/channel/ollama" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/palm" "github.com/songquanpeng/one-api/relay/channel/tencent" @@ -37,6 +38,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &xunfei.Adaptor{} case constant.APITypeZhipu: return &zhipu.Adaptor{} + case constant.APITypeOllama: + return &ollama.Adaptor{} } return nil } diff --git a/relay/util/billing.go b/relay/util/billing.go index 1e2b09ea..495d011e 100644 --- a/relay/util/billing.go +++ b/relay/util/billing.go @@ -6,7 +6,7 @@ import ( "github.com/songquanpeng/one-api/model" ) -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { if preConsumedQuota != 0 { go func(ctx context.Context) { // return pre-consumed quota diff --git a/relay/util/common.go b/relay/util/common.go index 6d993378..535ef680 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -27,7 +27,23 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { if statusCode == http.StatusUnauthorized { return true } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true } return false @@ -101,6 +117,9 @@ func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.Err if err != nil { return } + if config.DebugEnabled { + logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) + } err = resp.Body.Close() if err != nil { return @@ -136,20 +155,20 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { // quotaDelta is remaining quota to be consumed err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(ctx, userId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } // totalQuota is total quota consumed if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) model.UpdateChannelUsedQuota(channelId, totalQuota) } diff --git a/router/api-router.go b/router/api-router.go index 6d143da7..5b755ede 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.Use(middleware.GlobalAPIRateLimit()) { apiRouter.GET("/status", controller.GetStatus) + apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/home_page_content", controller.GetHomePageContent) @@ -69,7 +70,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/:id", controller.GetChannel) - channelRoute.GET("/test", controller.TestAllChannels) + channelRoute.GET("/test", controller.TestChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) diff --git a/router/dashboard.go b/router/dashboard.go index 0b539d44..5952d698 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -9,6 +9,7 @@ import ( func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") + apiRouter.Use(middleware.CORS()) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.TokenAuth()) diff --git a/web/README.md b/web/README.md index 86486085..59d91424 100644 --- a/web/README.md +++ b/web/README.md @@ -33,6 +33,12 @@ |![image](https://github.com/songquanpeng/one-api/assets/42402987/fb2b1c64-ef24-4027-9b80-0cd9d945a47f)|![image](https://github.com/songquanpeng/one-api/assets/42402987/b6b649ec-2888-4324-8b2d-d5e11554eed6)| |![image](https://github.com/songquanpeng/one-api/assets/42402987/6d3b22e0-436b-4e26-8911-bcc993c6a2bd)|![image](https://github.com/songquanpeng/one-api/assets/42402987/eef1e224-7245-44d7-804e-9d1c8fa3f29c)| +### 主题:air +由 [Calon](https://github.com/Calcium-Ion) 开发。 +|![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1ddb274b-a715-4e81-858b-857d520b6ff4)|![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/163b0b8e-1f73-49cb-b632-3dcb986b56d5)| +|:---:|:---:| + + #### 开发说明 请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md) diff --git a/web/THEMES b/web/THEMES index 6b0157cb..149e8698 100644 --- a/web/THEMES +++ b/web/THEMES @@ -1,2 +1,3 @@ default berry +air diff --git a/web/air/.gitignore b/web/air/.gitignore new file mode 100644 index 00000000..2b5bba76 --- /dev/null +++ b/web/air/.gitignore @@ -0,0 +1,26 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# production +/build + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.idea +package-lock.json +yarn.lock \ No newline at end of file diff --git a/web/air/README.md b/web/air/README.md new file mode 100644 index 00000000..1b1031a3 --- /dev/null +++ b/web/air/README.md @@ -0,0 +1,21 @@ +# React Template + +## Basic Usages + +```shell +# Runs the app in the development mode +npm start + +# Builds the app for production to the `build` folder +npm run build +``` + +If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build, +for example: `REACT_APP_SERVER=http://your.domain.com`. + +Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled. + +## Reference + +1. https://github.com/OIerDb-ng/OIerDb +2. https://github.com/cornflourblue/react-hooks-redux-registration-login-example \ No newline at end of file diff --git a/web/air/package.json b/web/air/package.json new file mode 100644 index 00000000..3bdf3952 --- /dev/null +++ b/web/air/package.json @@ -0,0 +1,60 @@ +{ + "name": "react-template", + "version": "0.1.0", + "private": true, + "dependencies": { + "@douyinfe/semi-icons": "^2.46.1", + "@douyinfe/semi-ui": "^2.46.1", + "@visactor/react-vchart": "~1.8.8", + "@visactor/vchart": "~1.8.8", + "@visactor/vchart-semi-theme": "~1.8.8", + "axios": "^0.27.2", + "history": "^5.3.0", + "marked": "^4.1.1", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-dropzone": "^14.2.3", + "react-fireworks": "^1.0.4", + "react-router-dom": "^6.3.0", + "react-scripts": "5.0.1", + "react-telegram-login": "^1.1.2", + "react-toastify": "^9.0.8", + "react-turnstile": "^1.0.5", + "semantic-ui-css": "^2.5.0", + "semantic-ui-react": "^2.1.3", + "usehooks-ts": "^2.9.1" + }, + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build && mv -f build ../build/air", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "devDependencies": { + "prettier": "2.8.8", + "typescript": "4.4.2" + }, + "prettier": { + "singleQuote": true, + "jsxSingleQuote": true + }, + "proxy": "http://localhost:3000" +} diff --git a/web/air/public/favicon.ico b/web/air/public/favicon.ico new file mode 100644 index 00000000..c2c8de0c Binary files /dev/null and b/web/air/public/favicon.ico differ diff --git a/web/air/public/index.html b/web/air/public/index.html new file mode 100644 index 00000000..36365c7e --- /dev/null +++ b/web/air/public/index.html @@ -0,0 +1,18 @@ + + + + + + + + + One API + + + +
+ + diff --git a/web/air/public/logo.png b/web/air/public/logo.png new file mode 100644 index 00000000..0f237a22 Binary files /dev/null and b/web/air/public/logo.png differ diff --git a/web/air/public/robots.txt b/web/air/public/robots.txt new file mode 100644 index 00000000..e9e57dc4 --- /dev/null +++ b/web/air/public/robots.txt @@ -0,0 +1,3 @@ +# https://www.robotstxt.org/robotstxt.html +User-agent: * +Disallow: diff --git a/web/air/src/App.js b/web/air/src/App.js new file mode 100644 index 00000000..5a673187 --- /dev/null +++ b/web/air/src/App.js @@ -0,0 +1,242 @@ +import React, { lazy, Suspense, useContext, useEffect } from 'react'; +import { Route, Routes } from 'react-router-dom'; +import Loading from './components/Loading'; +import User from './pages/User'; +import { PrivateRoute } from './components/PrivateRoute'; +import RegisterForm from './components/RegisterForm'; +import LoginForm from './components/LoginForm'; +import NotFound from './pages/NotFound'; +import Setting from './pages/Setting'; +import EditUser from './pages/User/EditUser'; +import { getLogo, getSystemName } from './helpers'; +import PasswordResetForm from './components/PasswordResetForm'; +import GitHubOAuth from './components/GitHubOAuth'; +import PasswordResetConfirm from './components/PasswordResetConfirm'; +import { UserContext } from './context/User'; +import Channel from './pages/Channel'; +import Token from './pages/Token'; +import EditChannel from './pages/Channel/EditChannel'; +import Redemption from './pages/Redemption'; +import TopUp from './pages/TopUp'; +import Log from './pages/Log'; +import Chat from './pages/Chat'; +import { Layout } from '@douyinfe/semi-ui'; +import Midjourney from './pages/Midjourney'; +import Detail from './pages/Detail'; + +const Home = lazy(() => import('./pages/Home')); +const About = lazy(() => import('./pages/About')); + +function App() { + const [userState, userDispatch] = useContext(UserContext); + // const [statusState, statusDispatch] = useContext(StatusContext); + + const loadUser = () => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + userDispatch({ type: 'login', payload: data }); + } + }; + + useEffect(() => { + loadUser(); + let systemName = getSystemName(); + if (systemName) { + document.title = systemName; + } + let logo = getLogo(); + if (logo) { + let linkElement = document.querySelector('link[rel~=\'icon\']'); + if (linkElement) { + linkElement.href = logo; + } + } + }, []); + + return ( + + + + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + }> + + + + } + /> + + }> + + + + } + /> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + } /> + + + + ); +} + +export default App; diff --git a/web/air/src/components/ChannelsTable.js b/web/air/src/components/ChannelsTable.js new file mode 100644 index 00000000..dee21a01 --- /dev/null +++ b/web/air/src/components/ChannelsTable.js @@ -0,0 +1,738 @@ +import React, { useEffect, useState } from 'react'; +import { API, isMobile, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; + +import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumberWithPoint, renderQuota } from '../helpers/render'; +import { + Button, + Dropdown, + Form, + InputNumber, + Popconfirm, + Space, + SplitButtonGroup, + Switch, + Table, + Tag, + Tooltip, + Typography +} from '@douyinfe/semi-ui'; +import EditChannel from '../pages/Channel/EditChannel'; +import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +let type2label = undefined; + +function renderType(type) { + if (!type2label) { + type2label = new Map(); + for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { + type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; + } + type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; + } + return {type2label[type]?.text}; +} + +const ChannelsTable = () => { + const columns = [ + // { + // title: '', + // dataIndex: 'checkbox', + // className: 'checkbox', + // }, + { + title: 'ID', + dataIndex: 'id' + }, + { + title: '名称', + dataIndex: 'name' + }, + // { + // title: '分组', + // dataIndex: 'group', + // render: (text, record, index) => { + // return ( + //
+ // + // { + // text.split(',').map((item, index) => { + // return (renderGroup(item)); + // }) + // } + // + //
+ // ); + // } + // }, + { + title: '类型', + dataIndex: 'type', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + } + }, + { + title: '状态', + dataIndex: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '响应时间', + dataIndex: 'response_time', + render: (text, record, index) => { + return ( +
+ {renderResponseTime(text)} +
+ ); + } + }, + { + title: '已用/剩余', + dataIndex: 'expired_time', + render: (text, record, index) => { + return ( +
+ + + {renderQuota(record.used_quota)} + + + { + updateChannelBalance(record); + }}>${renderNumberWithPoint(record.balance)} + + +
+ ); + } + }, + { + title: '优先级', + dataIndex: 'priority', + render: (text, record, index) => { + return ( +
+ { + manageChannel(record.id, 'priority', record, e.target.value); + }} + keepFocus={true} + innerButtons + defaultValue={record.priority} + min={-999} + /> +
+ ); + } + }, + // { + // title: '权重', + // dataIndex: 'weight', + // render: (text, record, index) => { + // return ( + //
+ // { + // manageChannel(record.id, 'weight', record, e.target.value); + // }} + // keepFocus={true} + // innerButtons + // defaultValue={record.weight} + // min={0} + // /> + //
+ // ); + // } + // }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ {/* + + + + + */} + + { + manageChannel(record.id, 'delete', record).then( + () => { + removeRecord(record.id); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [channels, setChannels] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [idSort, setIdSort] = useState(false); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searchGroup, setSearchGroup] = useState(''); + const [searchModel, setSearchModel] = useState(''); + const [searching, setSearching] = useState(false); + const [updatingBalance, setUpdatingBalance] = useState(false); + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showPrompt, setShowPrompt] = useState(shouldShowPrompt('channel-test')); + const [channelCount, setChannelCount] = useState(pageSize); + const [groupOptions, setGroupOptions] = useState([]); + const [showEdit, setShowEdit] = useState(false); + const [enableBatchDelete, setEnableBatchDelete] = useState(false); + const [editingChannel, setEditingChannel] = useState({ + id: undefined + }); + const [selectedChannels, setSelectedChannels] = useState([]); + + const removeRecord = id => { + let newDataSource = [...channels]; + if (id != null) { + let idx = newDataSource.findIndex(data => data.id === id); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setChannels(newDataSource); + } + } + }; + + const setChannelFormat = (channels) => { + for (let i = 0; i < channels.length; i++) { + channels[i].key = '' + channels[i].id; + let test_models = []; + channels[i].models.split(',').forEach((item, index) => { + test_models.push({ + node: 'item', + name: item, + onClick: () => { + testChannel(channels[i], item); + } + }); + }); + channels[i].test_models = test_models; + } + // data.key = '' + data.id + setChannels(channels); + if (channels.length >= pageSize) { + setChannelCount(channels.length + pageSize); + } else { + setChannelCount(channels.length); + } + }; + + const loadChannels = async (startIdx, pageSize, idSort) => { + setLoading(true); + const res = await API.get(`/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setChannelFormat(data); + } else { + let newChannels = [...channels]; + newChannels.splice(startIdx * pageSize, data.length, ...data); + setChannelFormat(newChannels); + } + } else { + showError(message); + } + setLoading(false); + }; + + const refresh = async () => { + await loadChannels(activePage - 1, pageSize, idSort); + }; + + useEffect(() => { + // console.log('default effect') + const localIdSort = localStorage.getItem('id-sort') === 'true'; + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + setIdSort(localIdSort); + setPageSize(localPageSize); + loadChannels(0, localPageSize, localIdSort) + .then() + .catch((reason) => { + showError(reason); + }); + fetchGroups().then(); + }, []); + + const manageChannel = async (id, action, record, value) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/channel/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/channel/', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/channel/', data); + break; + case 'priority': + if (value === '') { + return; + } + data.priority = parseInt(value); + res = await API.put('/api/channel/', data); + break; + case 'weight': + if (value === '') { + return; + } + data.weight = parseInt(value); + if (data.weight < 0) { + data.weight = 0; + } + res = await API.put('/api/channel/', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let channel = res.data.data; + let newChannels = [...channels]; + if (action === 'delete') { + + } else { + record.status = channel.status; + } + setChannels(newChannels); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return 已启用; + case 2: + return ( + + 已禁用 + + ); + case 3: + return ( + + 自动禁用 + + ); + default: + return ( + + 未知状态 + + ); + } + }; + + const renderResponseTime = (responseTime) => { + let time = responseTime / 1000; + time = time.toFixed(2) + ' 秒'; + if (responseTime === 0) { + return 未测试; + } else if (responseTime <= 1000) { + return {time}; + } else if (responseTime <= 3000) { + return {time}; + } else if (responseTime <= 5000) { + return {time}; + } else { + return {time}; + } + }; + + const searchChannels = async (searchKeyword, searchGroup, searchModel) => { + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + // if keyword is blank, load files instead. + await loadChannels(0, pageSize, idSort); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}`); + const { success, message, data } = res.data; + if (success) { + setChannels(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const testChannel = async (record, model) => { + const res = await API.get(`/api/channel/test/${record.id}?model=${model}`); + const { success, message, time } = res.data; + if (success) { + record.response_time = time * 1000; + record.test_time = Date.now() / 1000; + showInfo(`通道 ${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + } else { + showError(message); + } + }; + + const testChannels = async (scope) => { + const res = await API.get(`/api/channel/test?scope=${scope}`); + const { success, message } = res.data; + if (success) { + showInfo('已成功开始测试通道,请刷新页面查看结果。'); + } else { + showError(message); + } + }; + + const deleteAllDisabledChannels = async () => { + const res = await API.delete(`/api/channel/disabled`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); + await refresh(); + } else { + showError(message); + } + }; + + const updateChannelBalance = async (record) => { + const res = await API.get(`/api/channel/update_balance/${record.id}/`); + const { success, message, balance } = res.data; + if (success) { + record.balance = balance; + record.balance_updated_time = Date.now() / 1000; + showInfo(`通道 ${record.name} 余额更新成功!`); + } else { + showError(message); + } + }; + + const updateAllChannelsBalance = async () => { + setUpdatingBalance(true); + const res = await API.get(`/api/channel/update_balance`); + const { success, message } = res.data; + if (success) { + showInfo('已更新完毕所有已启用通道余额!'); + } else { + showError(message); + } + setUpdatingBalance(false); + }; + + const batchDeleteChannels = async () => { + if (selectedChannels.length === 0) { + showError('请先选择要删除的通道!'); + return; + } + setLoading(true); + let ids = []; + selectedChannels.forEach((channel) => { + ids.push(channel.id); + }); + const res = await API.post(`/api/channel/batch`, { ids: ids }); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除 ${data} 个通道!`); + await refresh(); + } else { + showError(message); + } + setLoading(false); + }; + + const fixChannelsAbilities = async () => { + const res = await API.post(`/api/channel/fix`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已修复 ${data} 个通道!`); + await refresh(); + } else { + showError(message); + } + }; + + let pageData = channels.slice((activePage - 1) * pageSize, activePage * pageSize); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(channels.length / pageSize) + 1) { + // In this case we have to load more data and then append them. + loadChannels(page - 1, pageSize, idSort).then(r => { + }); + } + }; + + const handlePageSizeChange = async (size) => { + localStorage.setItem('page-size', size + ''); + setPageSize(size); + setActivePage(1); + loadChannels(0, size, idSort) + .then() + .catch((reason) => { + showError(reason); + }); + }; + + const fetchGroups = async () => { + try { + let res = await API.get(`/api/group/`); + // add 'all' option + // res.data.data.unshift('all'); + setGroupOptions(res.data.data.map((group) => ({ + label: group, + value: group + }))); + } catch (error) { + showError(error.message); + } + }; + + const closeEdit = () => { + setShowEdit(false); + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + + return ( + <> + +
+
{ + searchChannels(searchKeyword, searchGroup, searchModel); + }} labelPosition="left"> +
+ + { + setSearchKeyword(v.trim()); + }} + /> + {/* { + setSearchModel(v.trim()); + }} + /> + { + setSearchGroup(v); + searchChannels(searchKeyword, v, searchModel); + }} /> */} + + +
+
+
+ + + { testChannels("all") }} + position={isMobile() ? 'top' : 'left'} + > + + + { testChannels("disabled") }} + position={isMobile() ? 'top' : 'left'} + > + + + {/* + + */} + + + + + + + {/*
*/} + + {/*
*/} +
+ {/*
+ + 开启批量删除 + { + setEnableBatchDelete(v); + }}> + + + + + + + +
+
+ + + 使用ID排序 + { + localStorage.setItem('id-sort', v + ''); + setIdSort(v); + loadChannels(0, pageSize, v) + .then() + .catch((reason) => { + showError(reason); + }); + }}> + + +
*/} +
+ '', + onPageSizeChange: (size) => { + handlePageSizeChange(size).then(); + }, + onPageChange: handlePageChange + }} loading={loading} onRow={handleRow} rowSelection={ + enableBatchDelete ? + { + onChange: (selectedRowKeys, selectedRows) => { + // console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows); + setSelectedChannels(selectedRows); + } + } : null + } /> + + ); +}; + +export default ChannelsTable; diff --git a/web/air/src/components/Footer.js b/web/air/src/components/Footer.js new file mode 100644 index 00000000..6fd0fa54 --- /dev/null +++ b/web/air/src/components/Footer.js @@ -0,0 +1,64 @@ +import React, { useEffect, useState } from 'react'; + +import { Container, Segment } from 'semantic-ui-react'; +import { getFooterHTML, getSystemName } from '../helpers'; + +const Footer = () => { + const systemName = getSystemName(); + const [footer, setFooter] = useState(getFooterHTML()); + let remainCheckTimes = 5; + + const loadFooter = () => { + let footer_html = localStorage.getItem('footer_html'); + if (footer_html) { + setFooter(footer_html); + } + }; + + useEffect(() => { + const timer = setInterval(() => { + if (remainCheckTimes <= 0) { + clearInterval(timer); + return; + } + remainCheckTimes--; + loadFooter(); + }, 200); + return () => clearTimeout(timer); + }, []); + + return ( + + + {footer ? ( +
+ ) : ( +
+ + {systemName} {process.env.REACT_APP_VERSION}{' '} + + 由{' '} + + JustSong + {' '} + 构建,主题 air 来自{' '} + + Calon + {' '},源代码遵循{' '} + + MIT 协议 + +
+ )} +
+
+ ); +}; + +export default Footer; diff --git a/web/air/src/components/GitHubOAuth.js b/web/air/src/components/GitHubOAuth.js new file mode 100644 index 00000000..4e3b93ba --- /dev/null +++ b/web/air/src/components/GitHubOAuth.js @@ -0,0 +1,58 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; +import { useNavigate, useSearchParams } from 'react-router-dom'; +import { API, showError, showSuccess } from '../helpers'; +import { UserContext } from '../context/User'; + +const GitHubOAuth = () => { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default GitHubOAuth; diff --git a/web/air/src/components/HeaderBar.js b/web/air/src/components/HeaderBar.js new file mode 100644 index 00000000..eaf36c48 --- /dev/null +++ b/web/air/src/components/HeaderBar.js @@ -0,0 +1,161 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; + +import { API, getLogo, getSystemName, showSuccess } from '../helpers'; +import '../index.css'; + +import fireworks from 'react-fireworks'; + +import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons'; +import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui'; +import { stringToColor } from '../helpers/render'; + +// HeaderBar Buttons +let headerButtons = [ + { + text: '关于', + itemKey: 'about', + to: '/about', + icon: + } +]; + +if (localStorage.getItem('chat_link')) { + headerButtons.splice(1, 0, { + name: '聊天', + to: '/chat', + icon: 'comments' + }); +} + +const HeaderBar = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [showSidebar, setShowSidebar] = useState(false); + const [dark, setDark] = useState(false); + const systemName = getSystemName(); + const logo = getLogo(); + var themeMode = localStorage.getItem('theme-mode'); + const currentDate = new Date(); + // enable fireworks on new year(1.1 and 2.9-2.24) + const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24); + + async function logout() { + setShowSidebar(false); + await API.get('/api/user/logout'); + showSuccess('注销成功!'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } + + const handleNewYearClick = () => { + fireworks.init('root', {}); + fireworks.start(); + setTimeout(() => { + fireworks.stop(); + setTimeout(() => { + window.location.reload(); + }, 10000); + }, 3000); + }; + + useEffect(() => { + if (themeMode === 'dark') { + switchMode(true); + } + if (isNewYear) { + console.log('Happy New Year!'); + } + }, []); + + const switchMode = (model) => { + const body = document.body; + if (!model) { + body.removeAttribute('theme-mode'); + localStorage.setItem('theme-mode', 'light'); + } else { + body.setAttribute('theme-mode', 'dark'); + localStorage.setItem('theme-mode', 'dark'); + } + setDark(model); + }; + return ( + <> + +
+ +
+
+ + ); +}; + +export default HeaderBar; diff --git a/web/air/src/components/Loading.js b/web/air/src/components/Loading.js new file mode 100644 index 00000000..bacb53b3 --- /dev/null +++ b/web/air/src/components/Loading.js @@ -0,0 +1,14 @@ +import React from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; + +const Loading = ({ prompt: name = 'page' }) => { + return ( + + + 加载{name}中... + + + ); +}; + +export default Loading; diff --git a/web/air/src/components/LoginForm.js b/web/air/src/components/LoginForm.js new file mode 100644 index 00000000..3cbeb52c --- /dev/null +++ b/web/air/src/components/LoginForm.js @@ -0,0 +1,254 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import { onGitHubOAuthClicked } from './utils'; +import Turnstile from 'react-turnstile'; +import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi-ui'; +import Title from '@douyinfe/semi-ui/lib/es/typography/title'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; +import TelegramLoginButton from 'react-telegram-login'; + +import { IconGithubLogo } from '@douyinfe/semi-icons'; +import WeChatIcon from './WeChatIcon'; + +const LoginForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + wechat_verification_code: '' + }); + const [searchParams, setSearchParams] = useSearchParams(); + const [submitted, setSubmitted] = useState(false); + const { username, password } = inputs; + const [userState, userDispatch] = useContext(UserContext); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + let navigate = useNavigate(); + const [status, setStatus] = useState({}); + const logo = getLogo(); + + useEffect(() => { + if (searchParams.get('expired')) { + showError('未登录或登录已过期,请重新登录!'); + } + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }, []); + + const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); + + const onWeChatLoginClicked = () => { + setShowWeChatLoginModal(true); + }; + + const onSubmitWeChatVerificationCode = async () => { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + const res = await API.get( + `/api/oauth/wechat?code=${inputs.wechat_verification_code}` + ); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + navigate('/'); + showSuccess('登录成功!'); + setShowWeChatLoginModal(false); + } else { + showError(message); + } + }; + + function handleChange(name, value) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setSubmitted(true); + if (username && password) { + const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, { + username, + password + }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + if (username === 'root' && password === '123456') { + Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true }); + } + navigate('/token'); + } else { + showError(message); + } + } else { + showError('请输入用户名和密码!'); + } + } + + // 添加Telegram登录处理函数 + const onTelegramLoginClicked = async (response) => { + const fields = ['id', 'first_name', 'last_name', 'username', 'photo_url', 'auth_date', 'hash', 'lang']; + const params = {}; + fields.forEach((field) => { + if (response[field]) { + params[field] = response[field]; + } + }); + const res = await API.get(`/api/oauth/telegram/login`, { params }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } else { + showError(message); + } + }; + + return ( +
+ + + + +
+
+ + + 用户登录 + +
+ handleChange('username', value)} + /> + handleChange('password', value)} + /> + + + +
+ + 没有账号请先 注册账号 + + + 忘记密码 点击重置 + +
+ {status.github_oauth || status.wechat_login || status.telegram_oauth ? ( + <> + + 第三方登录 + +
+ {status.github_oauth ? ( +
+ + ) : ( + <> + )} + setShowWeChatLoginModal(false)} + okText={'登录'} + size={'small'} + centered={true} + > +
+ +
+
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+
+ handleChange('wechat_verification_code', value)} + /> + +
+
+ {turnstileEnabled ? ( +
+ { + setTurnstileToken(token); + }} + /> +
+ ) : ( + <> + )} +
+
+ +
+
+
+ ); +}; + +export default LoginForm; diff --git a/web/air/src/components/LogsTable.js b/web/air/src/components/LogsTable.js new file mode 100644 index 00000000..004188c3 --- /dev/null +++ b/web/air/src/components/LogsTable.js @@ -0,0 +1,401 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { Avatar, Button, Form, Layout, Modal, Select, Space, Spin, Table, Tag } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; +import { renderNumber, renderQuota, stringToColor } from '../helpers/render'; +import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; + +const { Header } = Layout; + +function renderTimestamp(timestamp) { + return (<> + {timestamp2string(timestamp)} + ); +} + +const MODE_OPTIONS = [{ key: 'all', text: '全部用户', value: 'all' }, { key: 'self', text: '当前用户', value: 'self' }]; + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', 'light-blue', 'lime', 'orange', 'pink', 'purple', 'red', 'teal', 'violet', 'yellow']; + +function renderType(type) { + switch (type) { + case 1: + return 充值 ; + case 2: + return 消费 ; + case 3: + return 管理 ; + case 4: + return 系统 ; + default: + return 未知 ; + } +} + +function renderIsStream(bool) { + if (bool) { + return ; + } else { + return 非流; + } +} + +function renderUseTime(type) { + const time = parseInt(type); + if (time < 101) { + return {time} s ; + } else if (time < 300) { + return {time} s ; + } else { + return {time} s ; + } +} + +const LogsTable = () => { + const columns = [{ + title: '时间', dataIndex: 'timestamp2string' + }, { + title: '渠道', + dataIndex: 'channel', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return (isAdminUser ? record.type === 0 || record.type === 2 ?
+ { {text} } +
: <> : <>); + } + }, { + title: '用户', + dataIndex: 'username', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return (isAdminUser ?
+ showUserInfo(record.user_id)}> + {typeof text === 'string' && text.slice(0, 1)} + + {text} +
: <>); + } + }, { + title: '令牌', dataIndex: 'token_name', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { + copyText(text); + }}> {text} +
: <>); + } + }, { + title: '类型', dataIndex: 'type', render: (text, record, index) => { + return (
+ {renderType(text)} +
); + } + }, { + title: '模型', dataIndex: 'model_name', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { + copyText(text); + }}> {text} +
: <>); + } + }, + // { + // title: '用时', dataIndex: 'use_time', render: (text, record, index) => { + // return (
+ // + // {renderUseTime(text)} + // {renderIsStream(record.is_stream)} + // + //
); + // } + // }, + { + title: '提示', dataIndex: 'prompt_tokens', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { {text} } +
: <>); + } + }, { + title: '补全', dataIndex: 'completion_tokens', render: (text, record, index) => { + return (parseInt(text) > 0 && (record.type === 0 || record.type === 2) ?
+ { {text} } +
: <>); + } + }, { + title: '花费', dataIndex: 'quota', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ {renderQuota(text, 6)} +
: <>); + } + }, { + title: '详情', dataIndex: 'content', render: (text, record, index) => { + return + {text} + ; + } + }]; + + const [logs, setLogs] = useState([]); + const [showStat, setShowStat] = useState(false); + const [loading, setLoading] = useState(false); + const [loadingStat, setLoadingStat] = useState(false); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + let now = new Date(); + // 初始化start_timestamp为前一天 + const [inputs, setInputs] = useState({ + username: '', + token_name: '', + model_name: '', + start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), + channel: '' + }); + const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; + + const [stat, setStat] = useState({ + quota: 0, token: 0 + }); + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const getLogSelfStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const getLogStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const handleEyeClick = async () => { + setLoadingStat(true); + if (isAdminUser) { + await getLogStat(); + } else { + await getLogSelfStat(); + } + setShowStat(true); + setLoadingStat(false); + }; + + const showUserInfo = async (userId) => { + if (!isAdminUser) { + return; + } + const res = await API.get(`/api/user/${userId}`); + const { success, message, data } = res.data; + if (success) { + Modal.info({ + title: '用户信息', content:
+

用户名: {data.username}

+

余额: {renderQuota(data.quota)}

+

已用额度:{renderQuota(data.used_quota)}

+

请求次数:{renderNumber(data.request_count)}

+
, centered: true + }); + } else { + showError(message); + } + }; + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + }; + + const loadLogs = async (startIdx, pageSize, logType = 0) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + if (isAdminUser) { + url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; + } else { + url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * pageSize, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * pageSize, activePage * pageSize); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / pageSize) + 1) { + // In this case we have to load more data and then append them. + loadLogs(page - 1, pageSize).then(r => { + }); + } + }; + + const handlePageSizeChange = async (size) => { + localStorage.setItem('page-size', size + ''); + setPageSize(size); + setActivePage(1); + loadLogs(0, size) + .then() + .catch((reason) => { + showError(reason); + }); + }; + + const refresh = async (localLogType) => { + // setLoading(true); + setActivePage(1); + await loadLogs(0, pageSize, localLogType); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + // console.log('default effect') + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + setPageSize(localPageSize); + loadLogs(0, localPageSize) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const searchLogs = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadLogs(0, pageSize); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setLogs(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + return (<> + +
+ +

使用明细(总消耗额度: + {showStat ? renderQuota(stat.quota) : '点击查看'} + ) +

+
+
+
+ <> + handleInputChange(value, 'token_name')} /> + handleInputChange(value, 'model_name')} /> + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + {isAdminUser && <> + handleInputChange(value, 'channel')} /> + handleInputChange(value, 'username')} /> + } + + + + + +
{ + handlePageSizeChange(size).then(); + }, + onPageChange: handlePageChange + }} /> + + + ); +}; + +export default LogsTable; diff --git a/web/air/src/components/MjLogsTable.js b/web/air/src/components/MjLogsTable.js new file mode 100644 index 00000000..6a6fbd95 --- /dev/null +++ b/web/air/src/components/MjLogsTable.js @@ -0,0 +1,454 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { Banner, Button, Form, ImagePreview, Layout, Modal, Progress, Table, Tag, Typography } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; + + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +]; + +function renderType(type) { + switch (type) { + case 'IMAGINE': + return 绘图; + case 'UPSCALE': + return 放大; + case 'VARIATION': + return 变换; + case 'HIGH_VARIATION': + return 强变换; + case 'LOW_VARIATION': + return 弱变换; + case 'PAN': + return 平移; + case 'DESCRIBE': + return 图生文; + case 'BLEND': + return 图混合; + case 'SHORTEN': + return 缩词; + case 'REROLL': + return 重绘; + case 'INPAINT': + return 局部重绘-提交; + case 'ZOOM': + return 变焦; + case 'CUSTOM_ZOOM': + return 自定义变焦-提交; + case 'MODAL': + return 窗口处理; + case 'SWAP_FACE': + return 换脸; + default: + return 未知; + } +} + + +function renderCode(code) { + switch (code) { + case 1: + return 已提交; + case 21: + return 等待中; + case 22: + return 重复提交; + case 0: + return 未提交; + default: + return 未知; + } +} + + +function renderStatus(type) { + // Ensure all cases are string literals by adding quotes. + switch (type) { + case 'SUCCESS': + return 成功; + case 'NOT_START': + return 未启动; + case 'SUBMITTED': + return 队列中; + case 'IN_PROGRESS': + return 执行中; + case 'FAILURE': + return 失败; + case 'MODAL': + return 窗口等待; + default: + return 未知; + } +} + +const renderTimestamp = (timestampInSeconds) => { + const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒 + + const year = date.getFullYear(); // 获取年份 + const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数 + const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数 + const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数 + const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数 + const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数 + + return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出 +}; + + +const LogsTable = () => { + const [isModalOpen, setIsModalOpen] = useState(false); + const [modalContent, setModalContent] = useState(''); + const columns = [ + { + title: '提交时间', + dataIndex: 'submit_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text / 1000)} +
+ ); + } + }, + { + title: '渠道', + dataIndex: 'channel_id', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( + +
+ { + copyText(text); // 假设copyText是用于文本复制的函数 + }}> {text} +
+ + ); + } + }, + { + title: '类型', + dataIndex: 'action', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + } + }, + { + title: '任务ID', + dataIndex: 'mj_id', + render: (text, record, index) => { + return ( +
+ {text} +
+ ); + } + }, + { + title: '提交结果', + dataIndex: 'code', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ {renderCode(text)} +
+ ); + } + }, + { + title: '任务状态', + dataIndex: 'status', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '进度', + dataIndex: 'progress', + render: (text, record, index) => { + return ( +
+ { + // 转换例如100%为数字100,如果text未定义,返回0 + + } +
+ ); + } + }, + { + title: '结果图片', + dataIndex: 'image_url', + render: (text, record, index) => { + if (!text) { + return '无'; + } + return ( + + ); + } + }, + { + title: 'Prompt', + dataIndex: 'prompt', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + }, + { + title: 'PromptEn', + dataIndex: 'prompt_en', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + }, + { + title: '失败原因', + dataIndex: 'fail_reason', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + } + + ]; + + const [logs, setLogs] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + const [isModalOpenurl, setIsModalOpenurl] = useState(false); + const [showBanner, setShowBanner] = useState(false); + + // 定义模态框图片URL的状态和更新函数 + const [modalImageUrl, setModalImageUrl] = useState(''); + let now = new Date(); + // 初始化start_timestamp为前一天 + const [inputs, setInputs] = useState({ + channel_id: '', + mj_id: '', + start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + }); + const { channel_id, mj_id, start_timestamp, end_timestamp } = inputs; + + const [stat, setStat] = useState({ + quota: 0, + token: 0 + }); + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + }; + + const loadLogs = async (startIdx) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp); + let localEndTimestamp = Date.parse(end_timestamp); + if (isAdminUser) { + url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadLogs(page - 1).then(r => { + }); + } + }; + + const refresh = async () => { + // setLoading(true); + setActivePage(1); + await loadLogs(0); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + refresh().then(); + }, [logType]); + + useEffect(() => { + const mjNotifyEnabled = localStorage.getItem('mj_notify_enabled'); + if (mjNotifyEnabled !== 'true') { + setShowBanner(true); + } + }, []); + + return ( + <> + + + {isAdminUser && showBanner ? : <> + } +
+ <> + handleInputChange(value, 'channel_id')} /> + handleInputChange(value, 'mj_id')} /> + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + + + + + + +
+ setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式 + width={800} // 设置模态框宽度 + > +

{modalContent}

+
+ setIsModalOpenurl(visible)} + /> + + + + ); +}; + +export default LogsTable; diff --git a/web/air/src/components/OperationSetting.js b/web/air/src/components/OperationSetting.js new file mode 100644 index 00000000..b823bb28 --- /dev/null +++ b/web/air/src/components/OperationSetting.js @@ -0,0 +1,389 @@ +import React, { useEffect, useState } from 'react'; +import { Divider, Form, Grid, Header } from 'semantic-ui-react'; +import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; + +const OperationSetting = () => { + let now = new Date(); + let [inputs, setInputs] = useState({ + QuotaForNewUser: 0, + QuotaForInviter: 0, + QuotaForInvitee: 0, + QuotaRemindThreshold: 0, + PreConsumedQuota: 0, + ModelRatio: '', + CompletionRatio: '', + GroupRatio: '', + TopUpLink: '', + ChatLink: '', + QuotaPerUnit: 0, + AutomaticDisableChannelEnabled: '', + AutomaticEnableChannelEnabled: '', + ChannelDisableThreshold: 0, + LogConsumeEnabled: '', + DisplayInCurrencyEnabled: '', + DisplayTokenStatEnabled: '', + ApproximateTokenEnabled: '', + RetryTimes: 0 + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'CompletionRatio') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + if (item.value === '{}') { + item.value = ''; + } + newInputs[item.key] = item.value; + }); + setInputs(newInputs); + setOriginInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + if (key.endsWith('Enabled')) { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name.endsWith('Enabled')) { + await updateOption(name, value); + } else { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + }; + + const submitConfig = async (group) => { + switch (group) { + case 'monitor': + if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { + await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); + } + if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { + await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); + } + break; + case 'ratio': + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); + } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } + if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) { + if (!verifyJSON(inputs.CompletionRatio)) { + showError('补全倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('CompletionRatio', inputs.CompletionRatio); + } + break; + case 'quota': + if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { + await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); + } + if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { + await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); + } + if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { + await updateOption('QuotaForInviter', inputs.QuotaForInviter); + } + if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { + await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); + } + break; + case 'general': + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { + await updateOption('TopUpLink', inputs.TopUpLink); + } + if (originInputs['ChatLink'] !== inputs.ChatLink) { + await updateOption('ChatLink', inputs.ChatLink); + } + if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { + await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); + } + if (originInputs['RetryTimes'] !== inputs.RetryTimes) { + await updateOption('RetryTimes', inputs.RetryTimes); + } + break; + } + }; + + const deleteHistoryLogs = async () => { + console.log(inputs); + const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`${data} 条日志已清理!`); + return; + } + showError('日志清理失败:' + message); + }; + + return ( + + +
+
+ 通用设置 +
+ + + + + + + + + + + + { + submitConfig('general').then(); + }}>保存通用设置 + +
+ 日志设置 +
+ + + + + { + setHistoryTimestamp(value); + }} /> + + { + deleteHistoryLogs().then(); + }}>清理历史日志 + +
+ 监控设置 +
+ + + + + + + + + { + submitConfig('monitor').then(); + }}>保存监控设置 + +
+ 额度设置 +
+ + + + + + + { + submitConfig('quota').then(); + }}>保存额度设置 + +
+ 倍率设置 +
+ + + + + + + + + + { + submitConfig('ratio').then(); + }}>保存倍率设置 + +
+
+ ); +}; + +export default OperationSetting; diff --git a/web/air/src/components/OtherSetting.js b/web/air/src/components/OtherSetting.js new file mode 100644 index 00000000..ae924d9f --- /dev/null +++ b/web/air/src/components/OtherSetting.js @@ -0,0 +1,225 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react'; +import { API, showError, showSuccess } from '../helpers'; +import { marked } from 'marked'; +import { Link } from 'react-router-dom'; + +const OtherSetting = () => { + let [inputs, setInputs] = useState({ + Footer: '', + Notice: '', + About: '', + SystemName: '', + Logo: '', + HomePageContent: '', + Theme: '' + }); + let [loading, setLoading] = useState(false); + const [showUpdateModal, setShowUpdateModal] = useState(false); + const [updateData, setUpdateData] = useState({ + tag_name: '', + content: '' + }); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key in inputs) { + newInputs[item.key] = item.value; + } + }); + setInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const submitNotice = async () => { + await updateOption('Notice', inputs.Notice); + }; + + const submitFooter = async () => { + await updateOption('Footer', inputs.Footer); + }; + + const submitSystemName = async () => { + await updateOption('SystemName', inputs.SystemName); + }; + + const submitTheme = async () => { + await updateOption('Theme', inputs.Theme); + }; + + const submitLogo = async () => { + await updateOption('Logo', inputs.Logo); + }; + + const submitAbout = async () => { + await updateOption('About', inputs.About); + }; + + const submitOption = async (key) => { + await updateOption(key, inputs[key]); + }; + + const openGitHubRelease = () => { + window.location = + 'https://github.com/songquanpeng/one-api/releases/latest'; + }; + + const checkUpdate = async () => { + const res = await API.get( + 'https://api.github.com/repos/songquanpeng/one-api/releases/latest' + ); + const { tag_name, body } = res.data; + if (tag_name === process.env.REACT_APP_VERSION) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body) + }); + setShowUpdateModal(true); + } + }; + + return ( + + +
+
通用设置
+ 检查更新 + + + + 保存公告 + +
个性化设置
+ + + + 设置系统名称 + + 主题名称(当前可用主题)} + placeholder='请输入主题名称' + value={inputs.Theme} + name='Theme' + onChange={handleInputChange} + /> + + 设置主题(重启生效) + + + + 设置 Logo + + + + submitOption('HomePageContent')}>保存首页内容 + + + + 保存关于 + 移除 One API + 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。 + + + + 设置页脚 + +
+ setShowUpdateModal(false)} + onOpen={() => setShowUpdateModal(true)} + open={showUpdateModal} + > + 新版本:{updateData.tag_name} + + +
+
+
+ + + + + + +
+ ); +}; + +export default PasswordResetConfirm; diff --git a/web/air/src/components/PasswordResetForm.js b/web/air/src/components/PasswordResetForm.js new file mode 100644 index 00000000..ff3eaadb --- /dev/null +++ b/web/air/src/components/PasswordResetForm.js @@ -0,0 +1,102 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react'; +import { API, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const PasswordResetForm = () => { + const [inputs, setInputs] = useState({ + email: '' + }); + const { email } = inputs; + + const [loading, setLoading] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); + }, [disableButton, countdown]); + + function handleChange(e) { + const { name, value } = e.target; + setInputs(inputs => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + setDisableButton(true); + if (!email) return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/reset_password?email=${email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('重置邮件发送成功,请检查邮箱!'); + setInputs({ ...inputs, email: '' }); + } else { + showError(message); + } + setLoading(false); + } + + return ( + + +
+ 密码重置 +
+
+ + + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + + +
+
+ ); +}; + +export default PasswordResetForm; diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js new file mode 100644 index 00000000..45a5b776 --- /dev/null +++ b/web/air/src/components/PersonalSetting.js @@ -0,0 +1,653 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { API, copy, isRoot, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; +import { UserContext } from '../context/User'; +import { onGitHubOAuthClicked } from './utils'; +import { + Avatar, + Banner, + Button, + Card, + Descriptions, + Image, + Input, + InputNumber, + Layout, + Modal, + Space, + Tag, + Typography +} from '@douyinfe/semi-ui'; +import { getQuotaPerUnit, renderQuota, renderQuotaWithPrompt, stringToColor } from '../helpers/render'; +import TelegramLoginButton from 'react-telegram-login'; + +const PersonalSetting = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [inputs, setInputs] = useState({ + wechat_verification_code: '', + email_verification_code: '', + email: '', + self_account_deletion_confirmation: '', + set_new_password: '', + set_new_password_confirmation: '' + }); + const [status, setStatus] = useState({}); + const [showChangePasswordModal, setShowChangePasswordModal] = useState(false); + const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); + const [showEmailBindModal, setShowEmailBindModal] = useState(false); + const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + const [affLink, setAffLink] = useState(''); + const [systemToken, setSystemToken] = useState(''); + // const [models, setModels] = useState([]); + const [openTransfer, setOpenTransfer] = useState(false); + const [transferAmount, setTransferAmount] = useState(0); + + useEffect(() => { + // let user = localStorage.getItem('user'); + // if (user) { + // userDispatch({ type: 'login', payload: user }); + // } + // console.log(localStorage.getItem('user')) + + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + getUserData().then( + (res) => { + console.log(userState); + } + ); + // loadModels().then(); + getAffLink().then(); + setTransferAmount(getQuotaPerUnit()); + }, []); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); // Clean up on unmount + }, [disableButton, countdown]); + + const handleInputChange = (name, value) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const generateAccessToken = async () => { + const res = await API.get('/api/user/token'); + const { success, message, data } = res.data; + if (success) { + setSystemToken(data); + await copy(data); + showSuccess(`令牌已重置并已复制到剪贴板`); + } else { + showError(message); + } + }; + + const getAffLink = async () => { + const res = await API.get('/api/user/aff'); + const { success, message, data } = res.data; + if (success) { + let link = `${window.location.origin}/register?aff=${data}`; + setAffLink(link); + } else { + showError(message); + } + }; + + const getUserData = async () => { + let res = await API.get(`/api/user/self`); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + } else { + showError(message); + } + }; + + // const loadModels = async () => { + // let res = await API.get(`/api/user/models`); + // const { success, message, data } = res.data; + // if (success) { + // setModels(data); + // console.log(data); + // } else { + // showError(message); + // } + // }; + + const handleAffLinkClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`邀请链接已复制到剪切板`); + }; + + const handleSystemTokenClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`系统令牌已复制到剪切板`); + }; + + const deleteAccount = async () => { + if (inputs.self_account_deletion_confirmation !== userState.user.username) { + showError('请输入你的账户名以确认删除!'); + return; + } + + const res = await API.delete('/api/user/self'); + const { success, message } = res.data; + + if (success) { + showSuccess('账户已删除!'); + await API.get('/api/user/logout'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } else { + showError(message); + } + }; + + const bindWeChat = async () => { + if (inputs.wechat_verification_code === '') return; + const res = await API.get( + `/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('微信账户绑定成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + }; + + const changePassword = async () => { + if (inputs.set_new_password !== inputs.set_new_password_confirmation) { + showError('两次输入的密码不一致!'); + return; + } + const res = await API.put( + `/api/user/self`, + { + password: inputs.set_new_password + } + ); + const { success, message } = res.data; + if (success) { + showSuccess('密码修改成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + setShowChangePasswordModal(false); + }; + + const transfer = async () => { + if (transferAmount < getQuotaPerUnit()) { + showError('划转金额最低为' + renderQuota(getQuotaPerUnit())); + return; + } + const res = await API.post( + `/api/user/aff_transfer`, + { + quota: transferAmount + } + ); + const { success, message } = res.data; + if (success) { + showSuccess(message); + setOpenTransfer(false); + getUserData().then(); + } else { + showError(message); + } + }; + + const sendVerificationCode = async () => { + if (inputs.email === '') { + showError('请输入邮箱!'); + return; + } + setDisableButton(true); + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + const bindEmail = async () => { + if (inputs.email_verification_code === '') { + showError('请输入邮箱验证码!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('邮箱账户绑定成功!'); + setShowEmailBindModal(false); + userState.user.email = inputs.email; + } else { + showError(message); + } + setLoading(false); + }; + + const getUsername = () => { + if (userState.user) { + return userState.user.username; + } else { + return 'null'; + } + }; + + const handleCancel = () => { + setOpenTransfer(false); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + return ( +
+ + + +
+ {`可用额度${renderQuotaWithPrompt(userState?.user?.aff_quota)}`} + +
+
+ {`划转额度${renderQuotaWithPrompt(transferAmount)} 最低` + renderQuota(getQuotaPerUnit())} +
+ setTransferAmount(value)} disabled={false}> +
+
+
+
+ + {typeof getUsername() === 'string' && getUsername().slice(0, 1)} + } + title={{getUsername()}} + description={isRoot() ? 管理员 : 普通用户} + > + } + headerExtraContent={ + <> + + {'ID: ' + userState?.user?.id} + {userState?.user?.group} + + + } + footer={ + + {renderQuota(userState?.user?.quota)} + {renderQuota(userState?.user?.used_quota)} + {userState.user?.request_count} + + } + > + 调用信息 + {/* 可用模型 +
+ + {models.map((model) => ( + { + copyText(model); + }}> + {model} + + ))} + +
*/} +
+ {/* + 邀请链接 + +
+ } + > + 邀请信息 +
+ + + + { + renderQuota(userState?.user?.aff_quota) + } + + + + {renderQuota(userState?.user?.aff_history_quota)} + {userState?.user?.aff_count} + +
+ */} + + 邀请链接 + + + + 个人信息 +
+ 邮箱 +
+
+ +
+
+ +
+
+
+
+ 微信 +
+
+ +
+
+ +
+
+
+
+ GitHub +
+
+ +
+
+ +
+
+
+ + {/*
+ Telegram +
+
+ +
+
+ {status.telegram_oauth ? + userState.user.telegram_id !== '' ? + : + : + } +
+
+
*/} + +
+ + + + + + + {systemToken && ( + + )} + { + status.wechat_login && ( + + ) + } + setShowWeChatBindModal(false)} + // onOpen={() => setShowWeChatBindModal(true)} + visible={showWeChatBindModal} + size={'mini'} + > + +
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+ handleInputChange('wechat_verification_code', v)} + /> + +
+
+
+ setShowEmailBindModal(false)} + // onOpen={() => setShowEmailBindModal(true)} + onOk={bindEmail} + visible={showEmailBindModal} + size={'small'} + centered={true} + maskClosable={false} + > + 绑定邮箱地址 +
+ handleInputChange('email', value)} + name="email" + type="email" + /> + +
+
+ handleInputChange('email_verification_code', value)} + /> +
+ {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+ setShowAccountDeleteModal(false)} + visible={showAccountDeleteModal} + size={'small'} + centered={true} + onOk={deleteAccount} + > +
+ +
+
+ handleInputChange('self_account_deletion_confirmation', value)} + /> + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+
+ setShowChangePasswordModal(false)} + visible={showChangePasswordModal} + size={'small'} + centered={true} + onOk={changePassword} + > +
+ handleInputChange('set_new_password', value)} + /> + handleInputChange('set_new_password_confirmation', value)} + /> + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+
+
+ + + + + ); +}; + +export default PersonalSetting; diff --git a/web/air/src/components/PrivateRoute.js b/web/air/src/components/PrivateRoute.js new file mode 100644 index 00000000..9ef826c1 --- /dev/null +++ b/web/air/src/components/PrivateRoute.js @@ -0,0 +1,13 @@ +import { Navigate } from 'react-router-dom'; + +import { history } from '../helpers'; + + +function PrivateRoute({ children }) { + if (!localStorage.getItem('user')) { + return ; + } + return children; +} + +export { PrivateRoute }; \ No newline at end of file diff --git a/web/air/src/components/RedemptionsTable.js b/web/air/src/components/RedemptionsTable.js new file mode 100644 index 00000000..89e4ce20 --- /dev/null +++ b/web/air/src/components/RedemptionsTable.js @@ -0,0 +1,406 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; +import { Button, Form, Modal, Popconfirm, Popover, Table, Tag } from '@douyinfe/semi-ui'; +import EditRedemption from '../pages/Redemption/EditRedemption'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status) { + switch (status) { + case 1: + return 未使用; + case 2: + return 已禁用 ; + case 3: + return 已使用 ; + default: + return 未知状态 ; + } +} + +const RedemptionsTable = () => { + const columns = [ + { + title: 'ID', + dataIndex: 'id' + }, + { + title: '名称', + dataIndex: 'name' + }, + { + title: '状态', + dataIndex: 'status', + key: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '额度', + dataIndex: 'quota', + render: (text, record, index) => { + return ( +
+ {renderQuota(parseInt(text))} +
+ ); + } + }, + { + title: '创建时间', + dataIndex: 'created_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text)} +
+ ); + } + }, + // { + // title: '兑换人ID', + // dataIndex: 'used_user_id', + // render: (text, record, index) => { + // return ( + //
+ // {text === 0 ? '无' : text} + //
+ // ); + // } + // }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ + + + + { + manageRedemption(record.id, 'delete', record).then( + () => { + removeRecord(record.key); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [redemptions, setRedemptions] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [tokenCount, setTokenCount] = useState(ITEMS_PER_PAGE); + const [selectedKeys, setSelectedKeys] = useState([]); + const [editingRedemption, setEditingRedemption] = useState({ + id: undefined + }); + const [showEdit, setShowEdit] = useState(false); + + const closeEdit = () => { + setShowEdit(false); + }; + + // const setCount = (data) => { + // if (data.length >= (activePage) * ITEMS_PER_PAGE) { + // setTokenCount(data.length + 1); + // } else { + // setTokenCount(data.length); + // } + // } + + const setRedemptionFormat = (redeptions) => { + // for (let i = 0; i < redeptions.length; i++) { + // redeptions[i].key = '' + redeptions[i].id; + // } + // data.key = '' + data.id + setRedemptions(redeptions); + if (redeptions.length >= (activePage) * ITEMS_PER_PAGE) { + setTokenCount(redeptions.length + 1); + } else { + setTokenCount(redeptions.length); + } + }; + + const loadRedemptions = async (startIdx) => { + const res = await API.get(`/api/redemption/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setRedemptionFormat(data); + } else { + let newRedemptions = redemptions; + newRedemptions.push(...data); + setRedemptionFormat(newRedemptions); + } + } else { + showError(message); + } + setLoading(false); + }; + + const removeRecord = key => { + let newDataSource = [...redemptions]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.key === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setRedemptions(newDataSource); + } + } + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制到剪贴板!'); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadRedemptions(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadRedemptions(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const refresh = async () => { + await loadRedemptions(activePage - 1); + }; + + const manageRedemption = async (id, action, record) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/redemption/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/redemption/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/redemption/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let redemption = res.data.data; + let newRedemptions = [...redemptions]; + // let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + + } else { + record.status = redemption.status; + } + setRedemptions(newRedemptions); + } else { + showError(message); + } + }; + + const searchRedemptions = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadRedemptions(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/redemption/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setRedemptions(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const sortRedemption = (key) => { + if (redemptions.length === 0) return; + setLoading(true); + let sortedRedemptions = [...redemptions]; + sortedRedemptions.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedRedemptions[0].id === redemptions[0].id) { + sortedRedemptions.reverse(); + } + setRedemptions(sortedRedemptions); + setLoading(false); + }; + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadRedemptions(page - 1).then(r => { + }); + } + }; + + let pageData = redemptions.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + const rowSelection = { + onSelect: (record, selected) => { + }, + onSelectAll: (selected, selectedRows) => { + }, + onChange: (selectedRowKeys, selectedRows) => { + setSelectedKeys(selectedRows); + } + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + return ( + <> + +
+ + + +
`第 ${page.currentStart} - ${page.currentEnd} 条,共 ${redemptions.length} 条`, + // onPageSizeChange: (size) => { + // setPageSize(size); + // setActivePage(1); + // }, + onPageChange: handlePageChange + }} loading={loading} rowSelection={rowSelection} onRow={handleRow}> +
+ + + + ); +}; + +export default RedemptionsTable; diff --git a/web/air/src/components/RegisterForm.js b/web/air/src/components/RegisterForm.js new file mode 100644 index 00000000..1f26b63f --- /dev/null +++ b/web/air/src/components/RegisterForm.js @@ -0,0 +1,194 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; +import { Link, useNavigate } from 'react-router-dom'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const RegisterForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + password2: '', + email: '', + verification_code: '' + }); + const { username, password, password2 } = inputs; + const [showEmailVerification, setShowEmailVerification] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const logo = getLogo(); + let affCode = new URLSearchParams(window.location.search).get('aff'); + if (affCode) { + localStorage.setItem('aff', affCode); + } + + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setShowEmailVerification(status.email_verification); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }); + + let navigate = useNavigate(); + + function handleChange(e) { + const { name, value } = e.target; + console.log(name, value); + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (password.length < 8) { + showInfo('密码长度不得小于 8 位!'); + return; + } + if (password !== password2) { + showInfo('两次输入的密码不一致'); + return; + } + if (username && password) { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + if (!affCode) { + affCode = localStorage.getItem('aff'); + } + inputs.aff_code = affCode; + const res = await API.post( + `/api/user/register?turnstile=${turnstileToken}`, + inputs + ); + const { success, message } = res.data; + if (success) { + navigate('/login'); + showSuccess('注册成功!'); + } else { + showError(message); + } + setLoading(false); + } + } + + const sendVerificationCode = async () => { + if (inputs.email === '') return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查你的邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + return ( + + +
+ 新用户注册 +
+
+ + + + + {showEmailVerification ? ( + <> + + 获取验证码 + + } + /> + + + ) : ( + <> + )} + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+ + 已有账户? + + 点击登录 + + +
+
+ ); +}; + +export default RegisterForm; diff --git a/web/air/src/components/SiderBar.js b/web/air/src/components/SiderBar.js new file mode 100644 index 00000000..b3da272f --- /dev/null +++ b/web/air/src/components/SiderBar.js @@ -0,0 +1,214 @@ +import React, { useContext, useEffect, useMemo, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { StatusContext } from '../context/Status'; + +import { API, getLogo, getSystemName, isAdmin, isMobile, showError } from '../helpers'; +import '../index.css'; + +import { + IconCalendarClock, + IconComment, + IconCreditCard, + IconGift, + IconHistogram, + IconHome, + IconImage, + IconKey, + IconLayers, + IconSetting, + IconUser +} from '@douyinfe/semi-icons'; +import { Layout, Nav } from '@douyinfe/semi-ui'; + +// HeaderBar Buttons + +const SiderBar = () => { + const [userState, userDispatch] = useContext(UserContext); + const [statusState, statusDispatch] = useContext(StatusContext); + const defaultIsCollapsed = isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'; + + let navigate = useNavigate(); + const [selectedKeys, setSelectedKeys] = useState(['home']); + const systemName = getSystemName(); + const logo = getLogo(); + const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed); + + const headerButtons = useMemo(() => [ + { + text: '首页', + itemKey: 'home', + to: '/', + icon: + }, + { + text: '渠道', + itemKey: 'channel', + to: '/channel', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '聊天', + itemKey: 'chat', + to: '/chat', + icon: , + className: localStorage.getItem('chat_link') ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '令牌', + itemKey: 'token', + to: '/token', + icon: + }, + { + text: '兑换', + itemKey: 'redemption', + to: '/redemption', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '充值', + itemKey: 'topup', + to: '/topup', + icon: + }, + { + text: '用户', + itemKey: 'user', + to: '/user', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '日志', + itemKey: 'log', + to: '/log', + icon: + }, + { + text: '数据看板', + itemKey: 'detail', + to: '/detail', + icon: , + className: localStorage.getItem('enable_data_export') === 'true' ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '绘图', + itemKey: 'midjourney', + to: '/midjourney', + icon: , + className: localStorage.getItem('enable_drawing') === 'true' ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '设置', + itemKey: 'setting', + to: '/setting', + icon: + } + // { + // text: '关于', + // itemKey: 'about', + // to: '/about', + // icon: + // } + ], [localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), localStorage.getItem('chat_link'), isAdmin()]); + + const loadStatus = async () => { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + statusDispatch({ type: 'set', payload: data }); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_data_export', data.enable_data_export); + localStorage.setItem('data_export_default_time', data.data_export_default_time); + localStorage.setItem('default_collapse_sidebar', data.default_collapse_sidebar); + localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + localStorage.removeItem('chat_link'); + } + if (data.chat_link2) { + localStorage.setItem('chat_link2', data.chat_link2); + } else { + localStorage.removeItem('chat_link2'); + } + } else { + showError('无法正常连接至服务器!'); + } + }; + + useEffect(() => { + loadStatus().then(() => { + setIsCollapsed(isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'); + }); + }, []); + + return ( + <> + +
+ +
+
+ + ); +}; + +export default SiderBar; diff --git a/web/air/src/components/SystemSetting.js b/web/air/src/components/SystemSetting.js new file mode 100644 index 00000000..09b98665 --- /dev/null +++ b/web/air/src/components/SystemSetting.js @@ -0,0 +1,590 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Modal, Message } from 'semantic-ui-react'; +import { API, removeTrailingSlash, showError } from '../helpers'; + +const SystemSetting = () => { + let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', + PasswordRegisterEnabled: '', + EmailVerificationEnabled: '', + GitHubOAuthEnabled: '', + GitHubClientId: '', + GitHubClientSecret: '', + Notice: '', + SMTPServer: '', + SMTPPort: '', + SMTPAccount: '', + SMTPFrom: '', + SMTPToken: '', + ServerAddress: '', + Footer: '', + WeChatAuthEnabled: '', + WeChatServerAddress: '', + WeChatServerToken: '', + WeChatAccountQRCodeImageURL: '', + MessagePusherAddress: '', + MessagePusherToken: '', + TurnstileCheckEnabled: '', + TurnstileSiteKey: '', + TurnstileSecretKey: '', + RegisterEnabled: '', + EmailDomainRestrictionEnabled: '', + EmailDomainWhitelist: '' + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); + const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); + const [showPasswordWarningModal, setShowPasswordWarningModal] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + newInputs[item.key] = item.value; + }); + setInputs({ + ...newInputs, + EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') + }); + setOriginInputs(newInputs); + + setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { + return { key: item, text: item, value: item }; + })); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + switch (key) { + case 'PasswordLoginEnabled': + case 'PasswordRegisterEnabled': + case 'EmailVerificationEnabled': + case 'GitHubOAuthEnabled': + case 'WeChatAuthEnabled': + case 'TurnstileCheckEnabled': + case 'EmailDomainRestrictionEnabled': + case 'RegisterEnabled': + value = inputs[key] === 'true' ? 'false' : 'true'; + break; + default: + break; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + if (key === 'EmailDomainWhitelist') { + value = value.split(','); + } + setInputs((inputs) => ({ + ...inputs, [key]: value + })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') { + // block disabling password login + setShowPasswordWarningModal(true); + return; + } + if ( + name === 'Notice' || + name.startsWith('SMTP') || + name === 'ServerAddress' || + name === 'GitHubClientId' || + name === 'GitHubClientSecret' || + name === 'WeChatServerAddress' || + name === 'WeChatServerToken' || + name === 'WeChatAccountQRCodeImageURL' || + name === 'TurnstileSiteKey' || + name === 'TurnstileSecretKey' || + name === 'EmailDomainWhitelist' + ) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } else { + await updateOption(name, value); + } + }; + + const submitServerAddress = async () => { + let ServerAddress = removeTrailingSlash(inputs.ServerAddress); + await updateOption('ServerAddress', ServerAddress); + }; + + const submitSMTP = async () => { + if (originInputs['SMTPServer'] !== inputs.SMTPServer) { + await updateOption('SMTPServer', inputs.SMTPServer); + } + if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) { + await updateOption('SMTPAccount', inputs.SMTPAccount); + } + if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) { + await updateOption('SMTPFrom', inputs.SMTPFrom); + } + if ( + originInputs['SMTPPort'] !== inputs.SMTPPort && + inputs.SMTPPort !== '' + ) { + await updateOption('SMTPPort', inputs.SMTPPort); + } + if ( + originInputs['SMTPToken'] !== inputs.SMTPToken && + inputs.SMTPToken !== '' + ) { + await updateOption('SMTPToken', inputs.SMTPToken); + } + }; + + + const submitEmailDomainWhitelist = async () => { + if ( + originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && + inputs.SMTPToken !== '' + ) { + await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); + } + }; + + const submitWeChat = async () => { + if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { + await updateOption( + 'WeChatServerAddress', + removeTrailingSlash(inputs.WeChatServerAddress) + ); + } + if ( + originInputs['WeChatAccountQRCodeImageURL'] !== + inputs.WeChatAccountQRCodeImageURL + ) { + await updateOption( + 'WeChatAccountQRCodeImageURL', + inputs.WeChatAccountQRCodeImageURL + ); + } + if ( + originInputs['WeChatServerToken'] !== inputs.WeChatServerToken && + inputs.WeChatServerToken !== '' + ) { + await updateOption('WeChatServerToken', inputs.WeChatServerToken); + } + }; + + const submitMessagePusher = async () => { + if (originInputs['MessagePusherAddress'] !== inputs.MessagePusherAddress) { + await updateOption( + 'MessagePusherAddress', + removeTrailingSlash(inputs.MessagePusherAddress) + ); + } + if ( + originInputs['MessagePusherToken'] !== inputs.MessagePusherToken && + inputs.MessagePusherToken !== '' + ) { + await updateOption('MessagePusherToken', inputs.MessagePusherToken); + } + }; + + const submitGitHubOAuth = async () => { + if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { + await updateOption('GitHubClientId', inputs.GitHubClientId); + } + if ( + originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret && + inputs.GitHubClientSecret !== '' + ) { + await updateOption('GitHubClientSecret', inputs.GitHubClientSecret); + } + }; + + const submitTurnstile = async () => { + if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { + await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); + } + if ( + originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey && + inputs.TurnstileSecretKey !== '' + ) { + await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey); + } + }; + + const submitNewRestrictedDomain = () => { + const localDomainList = inputs.EmailDomainWhitelist; + if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { + setRestrictedDomainInput(''); + setInputs({ + ...inputs, + EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], + }); + setEmailDomainWhitelist([...EmailDomainWhitelist, { + key: restrictedDomainInput, + text: restrictedDomainInput, + value: restrictedDomainInput, + }]); + } + } + + return ( + + +
+
通用设置
+ + + + + 更新服务器地址 + + +
配置登录注册
+ + + { + showPasswordWarningModal && + setShowPasswordWarningModal(false)} + size={'tiny'} + style={{ maxWidth: '450px' }} + > + 警告 + +

取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?

+
+ + + + +
+ } + + + + +
+ + + + + +
+ 配置邮箱域名白名单 + 用以防止恶意用户利用临时邮箱批量注册 +
+ + + + + + { + submitNewRestrictedDomain(); + }}>填入 + } + onKeyDown={(e) => { + if (e.key === 'Enter') { + submitNewRestrictedDomain(); + } + }} + autoComplete='new-password' + placeholder='输入新的允许的邮箱域名' + value={restrictedDomainInput} + onChange={(e, { value }) => { + setRestrictedDomainInput(value); + }} + /> + + 保存邮箱域名白名单设置 + +
+ 配置 SMTP + 用以支持系统的邮件发送 +
+ + + + + + + + + + 保存 SMTP 设置 + +
+ 配置 GitHub OAuth App + + 用以支持通过 GitHub 进行登录注册, + + 点击此处 + + 管理你的 GitHub OAuth App + +
+ + Homepage URL 填 {inputs.ServerAddress} + ,Authorization callback URL 填{' '} + {`${inputs.ServerAddress}/oauth/github`} + + + + + + + 保存 GitHub OAuth 设置 + + +
+ 配置 WeChat Server + + 用以支持通过微信进行登录注册, + + 点击此处 + + 了解 WeChat Server + +
+ + + + + + + 保存 WeChat Server 设置 + + +
+ 配置 Message Pusher + + 用以推送报警信息, + + 点击此处 + + 了解 Message Pusher + +
+ + + + + + 保存 Message Pusher 设置 + + +
+ 配置 Turnstile + + 用以支持用户校验, + + 点击此处 + + 管理你的 Turnstile Sites,推荐选择 Invisible Widget Type + +
+ + + + + + 保存 Turnstile 设置 + + +
+
+ ); +}; + +export default SystemSetting; diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js new file mode 100644 index 00000000..9c4deb6e --- /dev/null +++ b/web/air/src/components/TokensTable.js @@ -0,0 +1,586 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; +import { Button, Dropdown, Form, Modal, Popconfirm, Popover, SplitButtonGroup, Table, Tag } from '@douyinfe/semi-ui'; + +import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; +import EditToken from '../pages/Token/EditToken'; + +const COPY_OPTIONS = [ + { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, + { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' } +]; + +const OPEN_LINK_OPTIONS = [ + { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' } +]; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status, model_limits_enabled = false) { + switch (status) { + case 1: + if (model_limits_enabled) { + return 已启用:限制模型; + } else { + return 已启用; + } + case 2: + return 已禁用 ; + case 3: + return 已过期 ; + case 4: + return 已耗尽 ; + default: + return 未知状态 ; + } +} + +const TokensTable = () => { + + const link_menu = [ + { + node: 'item', key: 'next', name: 'ChatGPT Next Web', onClick: () => { + onOpenLink('next'); + } + }, + { node: 'item', key: 'ama', name: 'AMA 问天', value: 'ama' }, + { + node: 'item', key: 'next-mj', name: 'ChatGPT Web & Midjourney', value: 'next-mj', onClick: () => { + onOpenLink('next-mj'); + } + }, + { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' } + ]; + + const columns = [ + { + title: '名称', + dataIndex: 'name' + }, + { + title: '状态', + dataIndex: 'status', + key: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text, record.model_limits_enabled)} +
+ ); + } + }, + { + title: '已用额度', + dataIndex: 'used_quota', + render: (text, record, index) => { + return ( +
+ {renderQuota(parseInt(text))} +
+ ); + } + }, + { + title: '剩余额度', + dataIndex: 'remain_quota', + render: (text, record, index) => { + return ( +
+ {record.unlimited_quota ? 无限制 : + {renderQuota(parseInt(text))}} +
+ ); + } + }, + { + title: '创建时间', + dataIndex: 'created_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text)} +
+ ); + } + }, + { + title: '过期时间', + dataIndex: 'expired_time', + render: (text, record, index) => { + return ( +
+ {record.expired_time === -1 ? '永不过期' : renderTimestamp(text)} +
+ ); + } + }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ + + + + + + { + onOpenLink('next', record.key); + } + }, + { + node: 'item', + key: 'next-mj', + disabled: !localStorage.getItem('chat_link2'), + name: 'ChatGPT Web & Midjourney', + onClick: () => { + onOpenLink('next-mj', record.key); + } + }, + { + node: 'item', key: 'ama', name: 'AMA 问天(BotGem)', onClick: () => { + onOpenLink('ama', record.key); + } + }, + { + node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { + onOpenLink('opencat', record.key); + } + } + ] + } + > + + + + { + manageToken(record.id, 'delete', record).then( + () => { + removeRecord(record.key); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showEdit, setShowEdit] = useState(false); + const [tokens, setTokens] = useState([]); + const [selectedKeys, setSelectedKeys] = useState([]); + const [tokenCount, setTokenCount] = useState(pageSize); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searchToken, setSearchToken] = useState(''); + const [searching, setSearching] = useState(false); + const [showTopUpModal, setShowTopUpModal] = useState(false); + const [targetTokenIdx, setTargetTokenIdx] = useState(0); + const [editingToken, setEditingToken] = useState({ + id: undefined + }); + + const closeEdit = () => { + setShowEdit(false); + setTimeout(() => { + setEditingToken({ + id: undefined + }); + }, 500); + }; + + const setTokensFormat = (tokens) => { + setTokens(tokens); + if (tokens.length >= pageSize) { + setTokenCount(tokens.length + pageSize); + } else { + setTokenCount(tokens.length); + } + }; + + let pageData = tokens.slice((activePage - 1) * pageSize, activePage * pageSize); + const loadTokens = async (startIdx) => { + setLoading(true); + const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setTokensFormat(data); + } else { + let newTokens = [...tokens]; + newTokens.splice(startIdx * pageSize, data.length, ...data); + setTokensFormat(newTokens); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(tokens.length / pageSize) + 1) { + // In this case we have to load more data and then append them. + await loadTokens(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + await loadTokens(activePage - 1); + }; + + const onCopy = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const nextLink = localStorage.getItem('chat_link'); + const mjLink = localStorage.getItem('chat_link2'); + let nextUrl; + + if (nextLink) { + nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } else { + nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + + let url; + switch (type) { + case 'ama': + url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next': + url = nextUrl; + break; + default: + url = `sk-${key}`; + } + // if (await copy(url)) { + // showSuccess('已复制到剪贴板!'); + // } else { + // showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); + // setSearchKeyword(url); + // } + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制到剪贴板!'); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + const onOpenLink = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const chatLink = localStorage.getItem('chat_link'); + const mjLink = localStorage.getItem('chat_link2'); + let defaultUrl; + + if (chatLink) { + defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + let url; + switch (type) { + case 'ama': + url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next-mj': + url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + break; + default: + if (!chatLink) { + showError('管理员未设置聊天链接'); + return; + } + url = defaultUrl; + } + + window.open(url, '_blank'); + }; + + useEffect(() => { + loadTokens(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, [pageSize]); + + const removeRecord = key => { + let newDataSource = [...tokens]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.key === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setTokensFormat(newDataSource); + } + } + }; + + const manageToken = async (id, action, record) => { + setLoading(true); + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/token/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/token/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/token/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let token = res.data.data; + let newTokens = [...tokens]; + // let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + + } else { + record.status = token.status; + // newTokens[realIdx].status = token.status; + } + setTokensFormat(newTokens); + } else { + showError(message); + } + setLoading(false); + }; + + const searchTokens = async () => { + if (searchKeyword === '' && searchToken === '') { + // if keyword is blank, load files instead. + await loadTokens(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`); + const { success, message, data } = res.data; + if (success) { + setTokensFormat(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const handleSearchTokenChange = async (value) => { + setSearchToken(value.trim()); + }; + + const sortToken = (key) => { + if (tokens.length === 0) return; + setLoading(true); + let sortedTokens = [...tokens]; + sortedTokens.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedTokens[0].id === tokens[0].id) { + sortedTokens.reverse(); + } + setTokens(sortedTokens); + setLoading(false); + }; + + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(tokens.length / pageSize) + 1) { + // In this case we have to load more data and then append them. + loadTokens(page - 1).then(r => { + }); + } + }; + + const rowSelection = { + onSelect: (record, selected) => { + }, + onSelectAll: (selected, selectedRows) => { + }, + onChange: (selectedRowKeys, selectedRows) => { + setSelectedKeys(selectedRows); + } + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + return ( + <> + +
+ + {/* */} + + + + `第 ${page.currentStart} - ${page.currentEnd} 条,共 ${tokens.length} 条`, + onPageSizeChange: (size) => { + setPageSize(size); + setActivePage(1); + }, + onPageChange: handlePageChange + }} loading={loading} rowSelection={rowSelection} onRow={handleRow}> +
+ + + + ); +}; + +export default TokensTable; diff --git a/web/air/src/components/UsersTable.js b/web/air/src/components/UsersTable.js new file mode 100644 index 00000000..f3de46d6 --- /dev/null +++ b/web/air/src/components/UsersTable.js @@ -0,0 +1,338 @@ +import React, { useEffect, useState } from 'react'; +import { API, showError, showSuccess } from '../helpers'; +import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumber, renderQuota } from '../helpers/render'; +import AddUser from '../pages/User/AddUser'; +import EditUser from '../pages/User/EditUser'; + +function renderRole(role) { + switch (role) { + case 1: + return 普通用户; + case 10: + return 管理员; + case 100: + return 超级管理员; + default: + return 未知身份; + } +} + +const UsersTable = () => { + const columns = [{ + title: 'ID', dataIndex: 'id' + }, { + title: '用户名', dataIndex: 'username' + }, { + title: '分组', dataIndex: 'group', render: (text, record, index) => { + return (
+ {renderGroup(text)} +
); + } + }, { + title: '统计信息', dataIndex: 'info', render: (text, record, index) => { + return (
+ + + {renderQuota(record.quota)} + + + {renderQuota(record.used_quota)} + + + {renderNumber(record.request_count)} + + +
); + } + }, + // { + // title: '邀请信息', dataIndex: 'invite', render: (text, record, index) => { + // return (
+ // + // + // {renderNumber(record.aff_count)} + // + // + // {renderQuota(record.aff_history_quota)} + // + // + // {record.inviter_id === 0 ? : + // {record.inviter_id}} + // + // + //
); + // } + // }, + { + title: '角色', dataIndex: 'role', render: (text, record, index) => { + return (
+ {renderRole(text)} +
); + } + }, + { + title: '状态', dataIndex: 'status', render: (text, record, index) => { + return (
+ {renderStatus(text)} +
); + } + }, + { + title: '', dataIndex: 'operate', render: (text, record, index) => (
+ <> + { + manageUser(record.username, 'promote', record); + }} + > + + + { + manageUser(record.username, 'demote', record); + }} + > + + + {record.status === 1 ? + : + } + + + { + manageUser(record.username, 'delete', record).then(() => { + removeRecord(record.id); + }); + }} + > + + +
) + }]; + + const [users, setUsers] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [userCount, setUserCount] = useState(ITEMS_PER_PAGE); + const [showAddUser, setShowAddUser] = useState(false); + const [showEditUser, setShowEditUser] = useState(false); + const [editingUser, setEditingUser] = useState({ + id: undefined + }); + + const setCount = (data) => { + if (data.length >= (activePage) * ITEMS_PER_PAGE) { + setUserCount(data.length + 1); + } else { + setUserCount(data.length); + } + }; + + const removeRecord = key => { + console.log(key); + let newDataSource = [...users]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.id === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setUsers(newDataSource); + } + } + }; + + const loadUsers = async (startIdx) => { + const res = await API.get(`/api/user/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setUsers(data); + setCount(data); + } else { + let newUsers = users; + newUsers.push(...data); + setUsers(newUsers); + setCount(newUsers); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadUsers(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadUsers(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const manageUser = async (username, action, record) => { + const res = await API.post('/api/user/manage', { + username, action + }); + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let user = res.data.data; + let newUsers = [...users]; + if (action === 'delete') { + + } else { + record.status = user.status; + record.role = user.role; + } + setUsers(newUsers); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return 已激活; + case 2: + return ( + 已封禁 + ); + default: + return ( + 未知状态 + ); + } + }; + + const searchUsers = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadUsers(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/user/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setUsers(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const sortUser = (key) => { + if (users.length === 0) return; + setLoading(true); + let sortedUsers = [...users]; + sortedUsers.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedUsers[0].id === users[0].id) { + sortedUsers.reverse(); + } + setUsers(sortedUsers); + setLoading(false); + }; + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadUsers(page - 1).then(r => { + }); + } + }; + + const pageData = users.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const closeAddUser = () => { + setShowAddUser(false); + }; + + const closeEditUser = () => { + setShowEditUser(false); + setEditingUser({ + id: undefined + }); + }; + + const refresh = async () => { + if (searchKeyword === '') { + await loadUsers(activePage - 1); + } else { + await searchUsers(); + } + }; + + return ( + <> + + +
+ handleKeywordChange(value)} + /> + + + + + + ); +}; + +export default UsersTable; diff --git a/web/air/src/components/WeChatIcon.js b/web/air/src/components/WeChatIcon.js new file mode 100644 index 00000000..22210d95 --- /dev/null +++ b/web/air/src/components/WeChatIcon.js @@ -0,0 +1,24 @@ +import React from 'react'; +import { Icon } from '@douyinfe/semi-ui'; + +const WeChatIcon = () => { + function CustomIcon() { + return + + + ; + } + + return ( +
+ } /> +
+ ); +}; + +export default WeChatIcon; diff --git a/web/air/src/components/utils.js b/web/air/src/components/utils.js new file mode 100644 index 00000000..5363ba5e --- /dev/null +++ b/web/air/src/components/utils.js @@ -0,0 +1,20 @@ +import { API, showError } from '../helpers'; + +export async function getOAuthState() { + const res = await API.get('/api/oauth/state'); + const { success, message, data } = res.data; + if (success) { + return data; + } else { + showError(message); + return ''; + } +} + +export async function onGitHubOAuthClicked(github_client_id) { + const state = await getOAuthState(); + if (!state) return; + window.open( + `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` + ); +} \ No newline at end of file diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js new file mode 100644 index 00000000..4bf035f9 --- /dev/null +++ b/web/air/src/constants/channel.constants.js @@ -0,0 +1,37 @@ +export const CHANNEL_OPTIONS = [ + { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, + { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, + { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, + { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, + { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, + { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, + { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, + { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, + { key: 26, text: '百川大模型', value: 26, color: 'orange' }, + { key: 27, text: 'MiniMax', value: 27, color: 'red' }, + { key: 29, text: 'Groq', value: 29, color: 'orange' }, + { key: 30, text: 'Ollama', value: 30, color: 'black' }, + { key: 31, text: '零一万物', value: 31, color: 'green' }, + { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, + { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, + { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, + { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, + { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, + { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, + { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, + { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, + { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, + { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, + { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, + { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } +]; + +for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { + CHANNEL_OPTIONS[i].label = CHANNEL_OPTIONS[i].text; +} \ No newline at end of file diff --git a/web/air/src/constants/common.constant.js b/web/air/src/constants/common.constant.js new file mode 100644 index 00000000..1a37d5f6 --- /dev/null +++ b/web/air/src/constants/common.constant.js @@ -0,0 +1 @@ +export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! diff --git a/web/air/src/constants/index.js b/web/air/src/constants/index.js new file mode 100644 index 00000000..e83152bc --- /dev/null +++ b/web/air/src/constants/index.js @@ -0,0 +1,4 @@ +export * from './toast.constants'; +export * from './user.constants'; +export * from './common.constant'; +export * from './channel.constants'; \ No newline at end of file diff --git a/web/air/src/constants/toast.constants.js b/web/air/src/constants/toast.constants.js new file mode 100644 index 00000000..50684722 --- /dev/null +++ b/web/air/src/constants/toast.constants.js @@ -0,0 +1,7 @@ +export const toastConstants = { + SUCCESS_TIMEOUT: 1500, + INFO_TIMEOUT: 3000, + ERROR_TIMEOUT: 5000, + WARNING_TIMEOUT: 10000, + NOTICE_TIMEOUT: 20000 +}; diff --git a/web/air/src/constants/user.constants.js b/web/air/src/constants/user.constants.js new file mode 100644 index 00000000..2680d8ef --- /dev/null +++ b/web/air/src/constants/user.constants.js @@ -0,0 +1,19 @@ +export const userConstants = { + REGISTER_REQUEST: 'USERS_REGISTER_REQUEST', + REGISTER_SUCCESS: 'USERS_REGISTER_SUCCESS', + REGISTER_FAILURE: 'USERS_REGISTER_FAILURE', + + LOGIN_REQUEST: 'USERS_LOGIN_REQUEST', + LOGIN_SUCCESS: 'USERS_LOGIN_SUCCESS', + LOGIN_FAILURE: 'USERS_LOGIN_FAILURE', + + LOGOUT: 'USERS_LOGOUT', + + GETALL_REQUEST: 'USERS_GETALL_REQUEST', + GETALL_SUCCESS: 'USERS_GETALL_SUCCESS', + GETALL_FAILURE: 'USERS_GETALL_FAILURE', + + DELETE_REQUEST: 'USERS_DELETE_REQUEST', + DELETE_SUCCESS: 'USERS_DELETE_SUCCESS', + DELETE_FAILURE: 'USERS_DELETE_FAILURE' +}; diff --git a/web/air/src/context/Status/index.js b/web/air/src/context/Status/index.js new file mode 100644 index 00000000..71f0682b --- /dev/null +++ b/web/air/src/context/Status/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from 'react'; +import { initialState, reducer } from './reducer'; + +export const StatusContext = React.createContext({ + state: initialState, + dispatch: () => null, +}); + +export const StatusProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState); + + return ( + + {children} + + ); +}; \ No newline at end of file diff --git a/web/air/src/context/Status/reducer.js b/web/air/src/context/Status/reducer.js new file mode 100644 index 00000000..ec9ac6ae --- /dev/null +++ b/web/air/src/context/Status/reducer.js @@ -0,0 +1,20 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'set': + return { + ...state, + status: action.payload, + }; + case 'unset': + return { + ...state, + status: undefined, + }; + default: + return state; + } +}; + +export const initialState = { + status: undefined, +}; diff --git a/web/air/src/context/User/index.js b/web/air/src/context/User/index.js new file mode 100644 index 00000000..c6671591 --- /dev/null +++ b/web/air/src/context/User/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from "react" +import { reducer, initialState } from "./reducer" + +export const UserContext = React.createContext({ + state: initialState, + dispatch: () => null +}) + +export const UserProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState) + + return ( + + { children } + + ) +} \ No newline at end of file diff --git a/web/air/src/context/User/reducer.js b/web/air/src/context/User/reducer.js new file mode 100644 index 00000000..9ed1d809 --- /dev/null +++ b/web/air/src/context/User/reducer.js @@ -0,0 +1,21 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'login': + return { + ...state, + user: action.payload + }; + case 'logout': + return { + ...state, + user: undefined + }; + + default: + return state; + } +}; + +export const initialState = { + user: undefined +}; \ No newline at end of file diff --git a/web/air/src/helpers/api.js b/web/air/src/helpers/api.js new file mode 100644 index 00000000..35fdb1e9 --- /dev/null +++ b/web/air/src/helpers/api.js @@ -0,0 +1,13 @@ +import { showError } from './utils'; +import axios from 'axios'; + +export const API = axios.create({ + baseURL: process.env.REACT_APP_SERVER ? process.env.REACT_APP_SERVER : '', +}); + +API.interceptors.response.use( + (response) => response, + (error) => { + showError(error); + } +); diff --git a/web/air/src/helpers/auth-header.js b/web/air/src/helpers/auth-header.js new file mode 100644 index 00000000..a8fe5f5a --- /dev/null +++ b/web/air/src/helpers/auth-header.js @@ -0,0 +1,10 @@ +export function authHeader() { + // return authorization header with jwt token + let user = JSON.parse(localStorage.getItem('user')); + + if (user && user.token) { + return { 'Authorization': 'Bearer ' + user.token }; + } else { + return {}; + } +} \ No newline at end of file diff --git a/web/air/src/helpers/history.js b/web/air/src/helpers/history.js new file mode 100644 index 00000000..629039e5 --- /dev/null +++ b/web/air/src/helpers/history.js @@ -0,0 +1,3 @@ +import { createBrowserHistory } from 'history'; + +export const history = createBrowserHistory(); \ No newline at end of file diff --git a/web/air/src/helpers/index.js b/web/air/src/helpers/index.js new file mode 100644 index 00000000..505a8cf9 --- /dev/null +++ b/web/air/src/helpers/index.js @@ -0,0 +1,4 @@ +export * from './history'; +export * from './auth-header'; +export * from './utils'; +export * from './api'; \ No newline at end of file diff --git a/web/air/src/helpers/render.js b/web/air/src/helpers/render.js new file mode 100644 index 00000000..62fb0dcd --- /dev/null +++ b/web/air/src/helpers/render.js @@ -0,0 +1,170 @@ +import {Label} from 'semantic-ui-react'; +import {Tag} from "@douyinfe/semi-ui"; + +export function renderText(text, limit) { + if (text.length > limit) { + return text.slice(0, limit - 3) + '...'; + } + return text; +} + +export function renderGroup(group) { + if (group === '') { + return default; + } + let groups = group.split(','); + groups.sort(); + return <> + {groups.map((group) => { + if (group === 'vip' || group === 'pro') { + return {group}; + } else if (group === 'svip' || group === 'premium') { + return {group}; + } + if (group === 'default') { + return {group}; + } else { + return {group}; + } + })} + ; +} + +export function renderNumber(num) { + if (num >= 1000000000) { + return (num / 1000000000).toFixed(1) + 'B'; + } else if (num >= 1000000) { + return (num / 1000000).toFixed(1) + 'M'; + } else if (num >= 10000) { + return (num / 1000).toFixed(1) + 'k'; + } else { + return num; + } +} + +export function renderQuotaNumberWithDigit(num, digits = 2) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + num = num.toFixed(digits); + if (displayInCurrency) { + return '$' + num; + } + return num; +} + +export function renderNumberWithPoint(num) { + num = num.toFixed(2); + if (num >= 100000) { + // Convert number to string to manipulate it + let numStr = num.toString(); + // Find the position of the decimal point + let decimalPointIndex = numStr.indexOf('.'); + + let wholePart = numStr; + let decimalPart = ''; + + // If there is a decimal point, split the number into whole and decimal parts + if (decimalPointIndex !== -1) { + wholePart = numStr.slice(0, decimalPointIndex); + decimalPart = numStr.slice(decimalPointIndex); + } + + // Take the first two and last two digits of the whole number part + let shortenedWholePart = wholePart.slice(0, 2) + '..' + wholePart.slice(-2); + + // Return the formatted number + return shortenedWholePart + decimalPart; + } + + // If the number is less than 100,000, return it unmodified + return num; +} + +export function getQuotaPerUnit() { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + return quotaPerUnit; +} + +export function getQuotaWithUnit(quota, digits = 6) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + return (quota / quotaPerUnit).toFixed(digits); +} + +export function renderQuota(quota, digits = 2) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + let displayInCurrency = localStorage.getItem('display_in_currency'); + quotaPerUnit = parseFloat(quotaPerUnit); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return '$' + (quota / quotaPerUnit).toFixed(digits); + } + return renderNumber(quota); +} + +export function renderQuotaWithPrompt(quota, digits) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return `(等价金额:${renderQuota(quota, digits)})`; + } + return ''; +} + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +] + +export const modelColorMap = { + 'dall-e': 'rgb(147,112,219)', // 深紫色 + 'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调 + 'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调 + 'midjourney': 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调 + 'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色 + 'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色 + 'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿 + 'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿 + 'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色 + 'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色 + 'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色 + 'gpt-4': 'rgb(135,206,235)', // 天蓝色 + 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色 + 'gpt-4-0613': 'rgb(100,149,237)', // 矢车菊蓝 + 'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝 + 'gpt-4-0125-preview': 'rgb(2,177,236)', // 深天蓝 + 'gpt-4-turbo-preview': 'rgb(2,177,255)', // 深天蓝 + 'gpt-4-32k': 'rgb(104,111,238)', // 中紫色 + 'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色 + 'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色 + 'gpt-4-all': 'rgb(65,105,225)', // 皇家蓝 + 'gpt-4-gizmo-*': 'rgb(0,0,255)', // 纯蓝色 + 'gpt-4-vision-preview': 'rgb(25,25,112)', // 午夜蓝 + 'text-ada-001': 'rgb(255,192,203)', // 粉红色 + 'text-babbage-001': 'rgb(255,160,122)', // 浅珊瑚色 + 'text-curie-001': 'rgb(219,112,147)', // 苍紫罗兰色 + 'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色 + 'text-davinci-003': 'rgb(219,112,147)', // 苍紫罗兰色(与Curie相同,表示同一个系列) + 'text-davinci-edit-001': 'rgb(255,105,180)', // 热粉色 + 'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红 + 'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别) + 'text-moderation-latest': 'rgb(255,130,171)', // 强粉色 + 'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(与Babbage相同,表示同一类功能) + 'tts-1': 'rgb(255,140,0)', // 深橙色 + 'tts-1-1106': 'rgb(255,165,0)', // 橙色 + 'tts-1-hd': 'rgb(255,215,0)', // 金色 + 'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别) + 'whisper-1': 'rgb(245,245,220)' // 米色 +} + +export function stringToColor(str) { + let sum = 0; + // 对字符串中的每个字符进行操作 + for (let i = 0; i < str.length; i++) { + // 将字符的ASCII值加到sum中 + sum += str.charCodeAt(i); + } + // 使用模运算得到个位数 + let i = sum % colors.length; + return colors[i]; +} \ No newline at end of file diff --git a/web/air/src/helpers/utils.js b/web/air/src/helpers/utils.js new file mode 100644 index 00000000..580c77ce --- /dev/null +++ b/web/air/src/helpers/utils.js @@ -0,0 +1,233 @@ +import { Toast } from '@douyinfe/semi-ui'; +import { toastConstants } from '../constants'; +import React from 'react'; +import {toast} from "react-toastify"; + +const HTMLToastContent = ({ htmlContent }) => { + return
; +}; +export default HTMLToastContent; +export function isAdmin() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 10; +} + +export function isRoot() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 100; +} + +export function getSystemName() { + let system_name = localStorage.getItem('system_name'); + if (!system_name) return 'One API'; + return system_name; +} + +export function getLogo() { + let logo = localStorage.getItem('logo'); + if (!logo) return '/logo.png'; + return logo +} + +export function getFooterHTML() { + return localStorage.getItem('footer_html'); +} + +export async function copy(text) { + let okay = true; + try { + await navigator.clipboard.writeText(text); + } catch (e) { + okay = false; + console.error(e); + } + return okay; +} + +export function isMobile() { + return window.innerWidth <= 600; +} + +let showErrorOptions = { autoClose: toastConstants.ERROR_TIMEOUT }; +let showWarningOptions = { autoClose: toastConstants.WARNING_TIMEOUT }; +let showSuccessOptions = { autoClose: toastConstants.SUCCESS_TIMEOUT }; +let showInfoOptions = { autoClose: toastConstants.INFO_TIMEOUT }; +let showNoticeOptions = { autoClose: false }; + +if (isMobile()) { + showErrorOptions.position = 'top-center'; + // showErrorOptions.transition = 'flip'; + + showSuccessOptions.position = 'top-center'; + // showSuccessOptions.transition = 'flip'; + + showInfoOptions.position = 'top-center'; + // showInfoOptions.transition = 'flip'; + + showNoticeOptions.position = 'top-center'; + // showNoticeOptions.transition = 'flip'; +} + +export function showError(error) { + console.error(error); + if (error.message) { + if (error.name === 'AxiosError') { + switch (error.response.status) { + case 401: + // toast.error('错误:未登录或登录已过期,请重新登录!', showErrorOptions); + window.location.href = '/login?expired=true'; + break; + case 429: + Toast.error('错误:请求次数过多,请稍后再试!'); + break; + case 500: + Toast.error('错误:服务器内部错误,请联系管理员!'); + break; + case 405: + Toast.info('本站仅作演示之用,无服务端!'); + break; + default: + Toast.error('错误:' + error.message); + } + return; + } + Toast.error('错误:' + error.message); + } else { + Toast.error('错误:' + error); + } +} + +export function showWarning(message) { + Toast.warning(message); +} + +export function showSuccess(message) { + Toast.success(message); +} + +export function showInfo(message) { + Toast.info(message); +} + +export function showNotice(message, isHTML = false) { + if (isHTML) { + toast(, showNoticeOptions); + } else { + Toast.info(message); + } +} + +export function openPage(url) { + window.open(url); +} + +export function removeTrailingSlash(url) { + if (url.endsWith('/')) { + return url.slice(0, -1); + } else { + return url; + } +} + +export function timestamp2string(timestamp) { + let date = new Date(timestamp * 1000); + let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + let minute = date.getMinutes().toString(); + let second = date.getSeconds().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if (hour.length === 1) { + hour = '0' + hour; + } + if (minute.length === 1) { + minute = '0' + minute; + } + if (second.length === 1) { + second = '0' + second; + } + return ( + year + + '-' + + month + + '-' + + day + + ' ' + + hour + + ':' + + minute + + ':' + + second + ); +} + +export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') { + let date = new Date(timestamp * 1000); + // let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if (hour.length === 1) { + hour = '0' + hour; + } + let str = month + '-' + day + if (dataExportDefaultTime === 'hour') { + str += ' ' + hour + ":00" + } else if (dataExportDefaultTime === 'week') { + let nextWeek = new Date(timestamp * 1000 + 6 * 24 * 60 * 60 * 1000); + let nextMonth = (nextWeek.getMonth() + 1).toString(); + let nextDay = nextWeek.getDate().toString(); + if (nextMonth.length === 1) { + nextMonth = '0' + nextMonth; + } + if (nextDay.length === 1) { + nextDay = '0' + nextDay; + } + str += ' - ' + nextMonth + '-' + nextDay + } + return str; +} + +export function downloadTextAsFile(text, filename) { + let blob = new Blob([text], { type: 'text/plain;charset=utf-8' }); + let url = URL.createObjectURL(blob); + let a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); +} + +export const verifyJSON = (str) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; + +export function shouldShowPrompt(id) { + let prompt = localStorage.getItem(`prompt-${id}`); + return !prompt; + +} + +export function setPromptShown(id) { + localStorage.setItem(`prompt-${id}`, 'true'); +} \ No newline at end of file diff --git a/web/air/src/index.css b/web/air/src/index.css new file mode 100644 index 00000000..271f14e2 --- /dev/null +++ b/web/air/src/index.css @@ -0,0 +1,116 @@ +body { + margin: 0; + padding-top: 55px; + overflow-y: scroll; + font-family: Lato, 'Helvetica Neue', Arial, Helvetica, "Microsoft YaHei", sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + scrollbar-width: none; + color: var(--semi-color-text-0) !important; + background-color: var( --semi-color-bg-0) !important; + height: 100%; +} + +#root { + height: 100%; +} + +@media only screen and (max-width: 767px) { + .semi-table-tbody, .semi-table-row, .semi-table-row-cell { + display: block!important; + width: auto!important; + padding: 2px!important; + } + .semi-table-row-cell { + border-bottom: 0!important; + } + .semi-table-tbody>.semi-table-row { + border-bottom: 1px solid rgba(0,0,0,.1); + } + .semi-space { + /*display: block!important;*/ + display: flex; + flex-direction: row; + flex-wrap: wrap; + row-gap: 3px; + column-gap: 10px; + } +} + +.semi-table-tbody > .semi-table-row > .semi-table-row-cell { + padding: 16px 14px; +} + +.channel-table { + .semi-table-tbody > .semi-table-row > .semi-table-row-cell { + padding: 16px 8px; + } +} + +.semi-layout { + height: 100%; +} + +.tableShow { + display: revert; +} + +.tableHiddle { + display: none !important; +} + +body::-webkit-scrollbar { + display: none; +} + +code { + font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; +} + +.semi-navigation-vertical { + /*display: flex;*/ + /*flex-direction: column;*/ +} + +.semi-navigation-item { + margin-bottom: 0; +} + +.semi-navigation-vertical { + /*flex: 0 0 auto;*/ + /*display: flex;*/ + /*flex-direction: column;*/ + /*width: 100%;*/ + height: 100%; + overflow: hidden; +} + +.main-content { + padding: 4px; + height: 100%; +} + +.small-icon .icon { + font-size: 1em !important; +} + +.custom-footer { + font-size: 1.1em; +} + +@media only screen and (max-width: 600px) { + .hide-on-mobile { + display: none !important; + } +} + + +/* 隐藏浏览器默认的滚动条 */ +body { + overflow: hidden; +} + +/* 自定义滚动条样式 */ +body::-webkit-scrollbar { + width: 0; /* 隐藏滚动条的宽度 */ +} \ No newline at end of file diff --git a/web/air/src/index.js b/web/air/src/index.js new file mode 100644 index 00000000..25b1d39e --- /dev/null +++ b/web/air/src/index.js @@ -0,0 +1,54 @@ +import { initVChartSemiTheme } from '@visactor/vchart-semi-theme'; +import React from 'react'; +import ReactDOM from 'react-dom/client'; +import {BrowserRouter} from 'react-router-dom'; +import App from './App'; +import HeaderBar from './components/HeaderBar'; +import Footer from './components/Footer'; +import 'semantic-ui-css/semantic.min.css'; +import './index.css'; +import {UserProvider} from './context/User'; +import {ToastContainer} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import {StatusProvider} from './context/Status'; +import {Layout} from "@douyinfe/semi-ui"; +import SiderBar from "./components/SiderBar"; + +// initialization +initVChartSemiTheme({ + isWatchingThemeSwitch: true, +}); + +const root = ReactDOM.createRoot(document.getElementById('root')); +const {Sider, Content, Header} = Layout; +root.render( + + + + + + + + + +
+ +
+ + + + +
+
+
+ +
+
+
+
+
+); diff --git a/web/air/src/pages/About/index.js b/web/air/src/pages/About/index.js new file mode 100644 index 00000000..ec13f151 --- /dev/null +++ b/web/air/src/pages/About/index.js @@ -0,0 +1,58 @@ +import React, { useEffect, useState } from 'react'; +import { Header, Segment } from 'semantic-ui-react'; +import { API, showError } from '../../helpers'; +import { marked } from 'marked'; + +const About = () => { + const [about, setAbout] = useState(''); + const [aboutLoaded, setAboutLoaded] = useState(false); + + const displayAbout = async () => { + setAbout(localStorage.getItem('about') || ''); + const res = await API.get('/api/about'); + const { success, message, data } = res.data; + if (success) { + let aboutContent = data; + if (!data.startsWith('https://')) { + aboutContent = marked.parse(data); + } + setAbout(aboutContent); + localStorage.setItem('about', aboutContent); + } else { + showError(message); + setAbout('加载关于内容失败...'); + } + setAboutLoaded(true); + }; + + useEffect(() => { + displayAbout().then(); + }, []); + + return ( + <> + { + aboutLoaded && about === '' ? <> + +
关于
+

可在设置页面设置关于内容,支持 HTML & Markdown

+ 项目仓库地址: + + https://github.com/songquanpeng/one-api + +
+ : <> + { + about.startsWith('https://') ?