diff --git a/.github/workflows/docker-image-amd64-en.yml b/.github/workflows/docker-image-amd64-en.yml index 44dc0bc0..af488256 100644 --- a/.github/workflows/docker-image-amd64-en.yml +++ b/.github/workflows/docker-image-amd64-en.yml @@ -20,6 +20,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index e3b8439a..2079d31f 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -20,6 +20,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index d6449eb8..39d1a401 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -21,6 +21,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index 04782864..1418917e 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -20,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 
@@ -38,7 +44,7 @@ jobs: - name: Build Backend (amd64) run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api + go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api - name: Build Backend (arm64) run: | diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 9142609f..359c2c92 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -20,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 @@ -38,7 +44,7 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos + go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index c058f41d..4e99b75c 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -23,6 +23,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 @@ -41,7 +47,7 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe + go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe - name: Release uses: 
softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') diff --git a/.gitignore b/.gitignore index 974fcf63..2a8ae16e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build logs data /web/node_modules +cmd.md \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index ec2f9d43..6743b139 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,10 @@ WORKDIR /web/berry RUN npm install RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build +WORKDIR /web/air +RUN npm install +RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build + FROM golang AS builder2 ENV GO111MODULE=on \ diff --git a/README.en.md b/README.en.md index eec0047b..bce47353 100644 --- a/README.en.md +++ b/README.en.md @@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the + Example: `SESSION_SECRET=random_string` 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` -4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. +4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL. + + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` +5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. +6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. + Example: `SYNC_FREQUENCY=60` -6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. 
If not set, it defaults to `master`. +7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. + Example: `NODE_TYPE=slave` -7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. +8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. + Example: `CHANNEL_UPDATE_FREQUENCY=1440` -8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. +9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. + Example: `CHANNEL_TEST_FREQUENCY=1440` -9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. +10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. + Example: `POLLING_INTERVAL=5` ### Command Line Parameters diff --git a/README.ja.md b/README.ja.md index e9149d71..c15915ec 100644 --- a/README.ja.md +++ b/README.ja.md @@ -242,17 +242,18 @@ graph LR + 例: `SESSION_SECRET=random_string` 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` -4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 +4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。 +5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 +6. 
`SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 + 例: `SYNC_FREQUENCY=60` -6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 +7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 + 例: `NODE_TYPE=slave` -7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 +8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 + 例: `CHANNEL_UPDATE_FREQUENCY=1440` -8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 +9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 + 例: `CHANNEL_TEST_FREQUENCY=1440` -9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 +10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 + 例: `POLLING_INTERVAL=5` ### コマンドラインパラメータ diff --git a/README.md b/README.md index ff1fffd2..d42c6b49 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + + [x] [Mistral 系列模型](https://mistral.ai/) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) @@ -74,15 +75,20 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [360 智脑](https://ai.360.cn) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) + [x] [Moonshot AI](https://platform.moonshot.cn/) + + [x] [百川大模型](https://platform.baichuan-ai.com) + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) - + [ ] [MINIMAX](https://api.minimax.chat/) (WIP) + + [x] 
[MINIMAX](https://api.minimax.chat/) + + [x] [Groq](https://wow.groq.com/) + + [x] [Ollama](https://github.com/ollama/ollama) + + [x] [零一万物](https://platform.lingyiwanwu.com/) + + [x] [阶跃星辰](https://platform.stepfun.com/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 -6. 支持**令牌管理**,设置令牌的过期时间和额度。 +6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 -8. 支持**通道管理**,批量创建通道。 +8. 支持**渠道管理**,批量创建渠道。 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 10. 支持渠道**设置模型列表**。 11. 支持**查看额度明细**。 @@ -96,13 +102,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 19. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 +20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 21. 支持 Cloudflare Turnstile 用户校验。 22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + + 支持使用飞书进行授权登录。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 +24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 ## 部署 ### 基于 Docker 进行部署 @@ -343,35 +351,41 @@ graph LR + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 -4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 +4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 +5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 +6. 
`MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`MEMORY_CACHE_ENABLED=true` -6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 +7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 + 例子:`SYNC_FREQUENCY=60` -7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 +8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + 例子:`NODE_TYPE=slave` -8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 +9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 - + 例子:`CHANNEL_TEST_FREQUENCY=1440` -10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 +10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 +11. 例子:`CHANNEL_TEST_FREQUENCY=1440` +12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` -11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 +13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 -12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 +14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` -13. 请求频率限制: +15. 请求频率限制: + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 -14. 编码器缓存设置: +16. 编码器缓存设置: + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 -15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 -16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 -17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 -18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +18. 
`SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +19. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 +20. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 +21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 @@ -410,7 +424,7 @@ https://openai.justsong.cn + 检查你的接口地址和 API Key 有没有填对。 + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 报错:`当前分组负载已饱和,请稍后再试` - + 上游通道 429 了。 + + 上游渠道 429 了。 7. 升级之后我的数据会丢失吗? + 如果使用 MySQL,不会。 + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 @@ -418,8 +432,8 @@ https://openai.justsong.cn + 一般情况下不需要,系统将在初始化的时候自动调整。 + 如果需要的话,我会在更新日志中说明,并给出脚本。 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? - + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 - + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 + + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 + + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 diff --git a/common/blacklist/main.go b/common/blacklist/main.go new file mode 100644 index 00000000..f84ce6ae --- /dev/null +++ b/common/blacklist/main.go @@ -0,0 +1,29 @@ +package blacklist + +import ( + "fmt" + "sync" +) + +var blackList sync.Map + +func init() { + blackList = sync.Map{} +} + +func userId2Key(id int) string { + return fmt.Sprintf("userid_%d", id) +} + +func BanUser(id int) { + blackList.Store(userId2Key(id), true) +} + +func UnbanUser(id int) { + blackList.Delete(userId2Key(id)) +} + +func IsUserBanned(id int) bool { + _, ok := blackList.Load(userId2Key(id)) + return ok +} diff --git a/common/config/config.go b/common/config/config.go index dd0236b4..4d54f4e5 100644 --- 
a/common/config/config.go +++ b/common/config/config.go @@ -1,7 +1,7 @@ package config import ( - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/env" "os" "strconv" "sync" @@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{ } var DebugEnabled = os.Getenv("DEBUG") == "true" +var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" var LogConsumeEnabled = true @@ -65,21 +66,27 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var LarkClientId = "" +var LarkClientSecret = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" +var MessagePusherAddress = "" +var MessagePusherToken = "" + var TurnstileSiteKey = "" var TurnstileSecretKey = "" -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 +var QuotaForNewUser int64 = 0 +var QuotaForInviter int64 = 0 +var QuotaForInvitee int64 = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 +var QuotaRemindThreshold int64 = 1000 +var PreConsumedQuota int64 = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 @@ -90,28 +97,29 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second 
-var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") -var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var Theme = env.String("THEME", "default") var ValidThemes = map[string]bool{ "default": true, "berry": true, + "air": true, } // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -125,3 +133,13 @@ var ( ) var RateLimitKeyExpirationDuration = 20 * time.Minute + +var EnableMetric = env.Bool("ENABLE_METRIC", false) +var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) +var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) + +var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") + +var GeminiVersion = env.String("GEMINI_VERSION", "v1") diff --git a/common/config/key.go b/common/config/key.go new file mode 100644 index 00000000..4b503c2d --- /dev/null +++ b/common/config/key.go @@ -0,0 +1,9 @@ +package config + +const ( + KeyPrefix = "cfg_" + + KeyAPIVersion = KeyPrefix + "api_version" + KeyLibraryID = KeyPrefix + "library_id" + KeyPlugin = KeyPrefix + "plugin" +) diff --git a/common/constants.go b/common/constants.go index ccaa3560..87221b61 100644 --- a/common/constants.go +++ b/common/constants.go @@ -4,101 +4,3 @@ import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced 
automatically when building, no need to manually change - -const ( - RoleGuestUser = 0 - RoleCommonUser = 1 - RoleAdminUser = 10 - RoleRootUser = 100 -) - -const ( - UserStatusEnabled = 1 // don't use 0, 0 is the default value! - UserStatusDisabled = 2 // also don't use 0 -) - -const ( - TokenStatusEnabled = 1 // don't use 0, 0 is the default value! - TokenStatusDisabled = 2 // also don't use 0 - TokenStatusExpired = 3 - TokenStatusExhausted = 4 -) - -const ( - RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! - RedemptionCodeStatusDisabled = 2 // also don't use 0 - RedemptionCodeStatusUsed = 3 // also don't use 0 -) - -const ( - ChannelStatusUnknown = 0 - ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! - ChannelStatusManuallyDisabled = 2 // also don't use 0 - ChannelStatusAutoDisabled = 3 -) - -const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 - ChannelTypeAIProxyLibrary = 21 - ChannelTypeFastGPT = 22 - ChannelTypeTencent = 23 - ChannelTypeGemini = 24 - ChannelTypeMoonshot = 25 -) - -var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "https://generativelanguage.googleapis.com", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - 
"https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", // 23 - "https://generativelanguage.googleapis.com", // 24 - "https://api.moonshot.cn", // 25 -} - -const ( - ConfigKeyPrefix = "cfg_" - - ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" - ConfigKeyLibraryID = ConfigKeyPrefix + "library_id" - ConfigKeyPlugin = ConfigKeyPrefix + "plugin" -) diff --git a/common/conv/any.go b/common/conv/any.go new file mode 100644 index 00000000..467e8bb7 --- /dev/null +++ b/common/conv/any.go @@ -0,0 +1,6 @@ +package conv + +func AsString(v any) string { + str, _ := v.(string) + return str +} diff --git a/common/database.go b/common/database.go index 9b52a0d5..f2db759f 100644 --- a/common/database.go +++ b/common/database.go @@ -1,9 +1,12 @@ package common -import "github.com/songquanpeng/one-api/common/helper" +import ( + "github.com/songquanpeng/one-api/common/env" +) var UsingSQLite = false var UsingPostgreSQL = false +var UsingMySQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/env/helper.go b/common/env/helper.go new file mode 100644 index 00000000..fdb9f827 --- /dev/null +++ b/common/env/helper.go @@ -0,0 +1,42 @@ +package env + +import ( + "os" + "strconv" +) + +func Bool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} + +func Int(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + return defaultValue + } + return num +} + +func 
Float64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + return defaultValue + } + return num +} + +func String(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/helper/helper.go b/common/helper/helper.go index babe422b..cf2e1635 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -2,18 +2,14 @@ package helper import ( "fmt" - "github.com/google/uuid" - "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "html/template" "log" - "math/rand" "net" - "os" "os/exec" "runtime" "strconv" "strings" - "time" ) func OpenBrowser(url string) { @@ -81,31 +77,6 @@ func Bytes2Size(num int64) string { return numStr + " " + unit } -func Seconds2Time(num int) (time string) { - if num/31104000 > 0 { - time += strconv.Itoa(num/31104000) + " 年 " - num %= 31104000 - } - if num/2592000 > 0 { - time += strconv.Itoa(num/2592000) + " 个月 " - num %= 2592000 - } - if num/86400 > 0 { - time += strconv.Itoa(num/86400) + " 天 " - num %= 86400 - } - if num/3600 > 0 { - time += strconv.Itoa(num/3600) + " 小时 " - num %= 3600 - } - if num/60 > 0 { - time += strconv.Itoa(num/60) + " 分钟 " - num %= 60 - } - time += strconv.Itoa(num) + " 秒" - return -} - func Interface2String(inter interface{}) string { switch inter := inter.(type) { case string: @@ -130,61 +101,8 @@ func IntMax(a int, b int) int { } } -func GetUUID() string { - code := uuid.New().String() - code = strings.Replace(code, "-", "", -1) - return code -} - -const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -const keyNumbers = "0123456789" - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, 48) - for i := 0; i < 16; i++ { - 
key[i] = keyChars[rand.Intn(len(keyChars))] - } - uuid_ := GetUUID() - for i := 0; i < 32; i++ { - c := uuid_[i] - if i%2 == 0 && c >= 'a' && c <= 'z' { - c = c - 'a' + 'A' - } - key[i+16] = c - } - return string(key) -} - -func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - -func GetRandomNumberString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyNumbers[rand.Intn(len(keyNumbers))] - } - return string(key) -} - -func GetTimestamp() int64 { - return time.Now().Unix() -} - -func GetTimeString() string { - now := time.Now() - return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +func GenRequestID() string { + return GetTimeString() + random.GetRandomNumberString(8) } func Max(a int, b int) int { @@ -195,25 +113,6 @@ func Max(a int, b int) int { } } -func GetOrDefaultEnvInt(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - func AssignOrDefault(value string, defaultValue string) string { if len(value) != 0 { return value diff --git a/common/helper/time.go b/common/helper/time.go new file mode 100644 index 00000000..302746db --- /dev/null +++ b/common/helper/time.go @@ -0,0 +1,15 @@ +package helper + +import ( + "fmt" + "time" +) + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", 
now.Format("20060102150405"), now.UnixNano()%1e9) +} diff --git a/common/logger/logger.go b/common/logger/logger.go index f970ee61..957d8a11 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" "io" "log" "os" @@ -13,14 +15,12 @@ import ( ) const ( + loggerDEBUG = "DEBUG" loggerINFO = "INFO" loggerWarn = "WARN" loggerError = "ERR" ) -const maxLogCount = 1000000 - -var logCount int var setupLogLock sync.Mutex var setupLogWorking bool @@ -55,6 +55,12 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } +func Debug(ctx context.Context, msg string) { + if config.DebugEnabled { + logHelper(ctx, loggerDEBUG, msg) + } +} + func Info(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } @@ -67,6 +73,10 @@ func Error(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func Debugf(ctx context.Context, format string, a ...any) { + Debug(ctx, fmt.Sprintf(format, a...)) +} + func Infof(ctx context.Context, format string, a ...any) { Info(ctx, fmt.Sprintf(format, a...)) } @@ -85,11 +95,12 @@ func logHelper(ctx context.Context, level string, msg string) { writer = gin.DefaultWriter } id := ctx.Value(RequestIdKey) + if id == nil { + id = helper.GenRequestID() + } now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) - logCount++ // we don't need accurate count, so no lock here - if logCount > maxLogCount && !setupLogWorking { - logCount = 0 + if !setupLogWorking { setupLogWorking = true go func() { SetupLogger() diff --git a/common/email.go b/common/message/email.go similarity index 96% rename from common/email.go rename to common/message/email.go index 2689da6a..b06782db 100644 --- a/common/email.go +++ 
b/common/message/email.go @@ -1,4 +1,4 @@ -package common +package message import ( "crypto/rand" @@ -12,6 +12,9 @@ import ( ) func SendEmail(subject string, receiver string, content string) error { + if receiver == "" { + return fmt.Errorf("receiver is empty") + } if config.SMTPFrom == "" { // for compatibility config.SMTPFrom = config.SMTPAccount } diff --git a/common/message/main.go b/common/message/main.go new file mode 100644 index 00000000..5ce82a64 --- /dev/null +++ b/common/message/main.go @@ -0,0 +1,22 @@ +package message + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" +) + +const ( + ByAll = "all" + ByEmail = "email" + ByMessagePusher = "message_pusher" +) + +func Notify(by string, title string, description string, content string) error { + if by == ByEmail { + return SendEmail(title, config.RootUserEmail, content) + } + if by == ByMessagePusher { + return SendMessage(title, description, content) + } + return fmt.Errorf("unknown notify method: %s", by) +} diff --git a/common/message/message-pusher.go b/common/message/message-pusher.go new file mode 100644 index 00000000..69949b4b --- /dev/null +++ b/common/message/message-pusher.go @@ -0,0 +1,53 @@ +package message + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/songquanpeng/one-api/common/config" + "net/http" +) + +type request struct { + Title string `json:"title"` + Description string `json:"description"` + Content string `json:"content"` + URL string `json:"url"` + Channel string `json:"channel"` + Token string `json:"token"` +} + +type response struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +func SendMessage(title string, description string, content string) error { + if config.MessagePusherAddress == "" { + return errors.New("message pusher address is not set") + } + req := request{ + Title: title, + Description: description, + Content: content, + Token: config.MessagePusherToken, + } + data, err := json.Marshal(req) + if err != 
nil { + return err + } + resp, err := http.Post(config.MessagePusherAddress, + "application/json", bytes.NewBuffer(data)) + if err != nil { + return err + } + var res response + err = json.NewDecoder(resp.Body).Decode(&res) + if err != nil { + return err + } + if !res.Success { + return errors.New(res.Message) + } + return nil +} diff --git a/common/network/ip.go b/common/network/ip.go new file mode 100644 index 00000000..0fbe5e6f --- /dev/null +++ b/common/network/ip.go @@ -0,0 +1,52 @@ +package network + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "net" + "strings" +) + +func splitSubnets(subnets string) []string { + res := strings.Split(subnets, ",") + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + } + return res +} + +func isValidSubnet(subnet string) error { + _, _, err := net.ParseCIDR(subnet) + if err != nil { + return fmt.Errorf("failed to parse subnet: %w", err) + } + return nil +} + +func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { + _, ipNet, err := net.ParseCIDR(subnet) + if err != nil { + logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) + return false + } + return ipNet.Contains(net.ParseIP(ip)) +} + +func IsValidSubnets(subnets string) error { + for _, subnet := range splitSubnets(subnets) { + if err := isValidSubnet(subnet); err != nil { + return err + } + } + return nil +} + +func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { + for _, subnet := range splitSubnets(subnets) { + if isIpInSubnet(ctx, ip, subnet) { + return true + } + } + return false +} diff --git a/common/network/ip_test.go b/common/network/ip_test.go new file mode 100644 index 00000000..6c593458 --- /dev/null +++ b/common/network/ip_test.go @@ -0,0 +1,19 @@ +package network + +import ( + "context" + "testing" + + . 
"github.com/smartystreets/goconvey/convey" +) + +func TestIsIpInSubnet(t *testing.T) { + ctx := context.Background() + ip1 := "192.168.0.5" + ip2 := "125.216.250.89" + subnet := "192.168.0.0/24" + Convey("TestIsIpInSubnet", t, func() { + So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) + So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) + }) +} diff --git a/common/random/main.go b/common/random/main.go new file mode 100644 index 00000000..dbb772cd --- /dev/null +++ b/common/random/main.go @@ -0,0 +1,61 @@ +package random + +import ( + "github.com/google/uuid" + "math/rand" + "strings" + "time" +) + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const keyNumbers = "0123456789" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetRandomNumberString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyNumbers[rand.Intn(len(keyNumbers))] + } + return string(key) +} + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/common/utils.go b/common/utils.go index 24615225..ecee2c8e 100644 --- a/common/utils.go +++ b/common/utils.go @@ -5,7 +5,7 @@ import ( 
"github.com/songquanpeng/one-api/common/config" ) -func LogQuota(quota int) string { +func LogQuota(quota int64) string { if config.DisplayInCurrencyEnabled { return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { diff --git a/controller/github.go b/controller/auth/github.go similarity index 94% rename from controller/github.go rename to controller/auth/github.go index 7d7fa106..15542655 100644 --- a/controller/github.go +++ b/controller/auth/github.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "bytes" @@ -7,10 +7,10 @@ import ( "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { user.DisplayName = "GitHub User" } user.Email = githubUser.Email - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func GitHubBind(c *gin.Context) { @@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := helper.GetRandomString(12) + state := random.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/auth/lark.go b/controller/auth/lark.go new file 
mode 100644 index 00000000..eb06dde9 --- /dev/null +++ b/controller/auth/lark.go @@ -0,0 +1,200 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type LarkOAuthResponse struct { + AccessToken string `json:"access_token"` +} + +type LarkUser struct { + Name string `json:"name"` + OpenID string `json:"open_id"` +} + +func getLarkUserInfoByCode(code string) (*LarkUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.LarkClientId, + "client_secret": config.LarkClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse LarkOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, 
errors.New("无法连接至飞书服务器,请稍后重试!") + } + var larkUser LarkUser + err = json.NewDecoder(res2.Body).Decode(&larkUser) + if err != nil { + return nil, err + } + return &larkUser, nil +} + +func LarkOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + LarkBind(c) + return + } + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + err := user.FillUserByLarkId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) + if larkUser.Name != "" { + user.DisplayName = larkUser.Name + } else { + user.DisplayName = "Lark User" + } + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func LarkBind(c *gin.Context) { + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: 
larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该飞书账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.LarkId = larkUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/wechat.go b/controller/auth/wechat.go similarity index 93% rename from controller/wechat.go rename to controller/auth/wechat.go index 74be5604..a64746c9 100644 --- a/controller/wechat.go +++ b/controller/auth/wechat.go @@ -1,12 +1,12 @@ -package controller +package auth import ( "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -83,8 +83,8 @@ func WeChatAuth(c *gin.Context) { if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -102,14 +102,14 @@ func WeChatAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func WeChatBind(c *gin.Context) { diff --git 
a/controller/billing.go b/controller/billing.go index 7317913d..dd518678 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -8,8 +8,8 @@ import ( ) func GetSubscription(c *gin.Context) { - var remainQuota int - var usedQuota int + var remainQuota int64 + var usedQuota int64 var err error var token *model.Token var expiredTime int64 @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { } func GetUsage(c *gin.Context) { - var quota int + var quota int64 var err error var token *model.Token if config.DisplayTokenStatEnabled { diff --git a/controller/channel-billing.go b/controller/channel-billing.go index abeab26a..b7ac61fd 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,11 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/client" "io" "net/http" "strconv" @@ -95,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := util.HTTPClient.Do(req) + res, err := client.HTTPClient.Do(req) if err != nil { return nil, err } @@ -203,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { } func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { channel.BaseURL = &baseURL } switch channel.Type { - case common.ChannelTypeOpenAI: + case channeltype.OpenAI: if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - case common.ChannelTypeAzure: + case channeltype.Azure: 
return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: + case channeltype.Custom: baseURL = channel.GetBaseURL() - case common.ChannelTypeCloseAI: + case channeltype.CloseAI: return updateChannelCloseAIBalance(channel) - case common.ChannelTypeOpenAISB: + case channeltype.OpenAISB: return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: + case channeltype.AIProxy: return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: + case channeltype.API2GPT: return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: + case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) default: return 0, errors.New("尚未实现") @@ -295,16 +296,16 @@ func UpdateChannelBalance(c *gin.Context) { } func updateAllChannelsBalance() error { - channels, err := model.GetAllChannels(0, 0, true) + channels, err := model.GetAllChannels(0, 0, "all") if err != nil { return err } for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { continue } // TODO: support Azure - if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { continue } balance, err := updateChannelBalance(channel) @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { } else { // err is nil & balance <= 0 means quota is used up if balance <= 0 { - disableChannel(channel.Id, channel.Name, "余额不足") + monitor.DisableChannel(channel.Id, channel.Name, "余额不足") } } time.Sleep(config.RequestInterval) @@ -322,15 +323,14 @@ func updateAllChannelsBalance() error { } func UpdateAllChannelsBalance(c *gin.Context) { - // TODO: make it async - err := updateAllChannelsBalance() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + //err := updateAllChannelsBalance() + //if err != nil { + // c.JSON(http.StatusOK, gin.H{ 
+ // "success": false, + // "message": err.Error(), + // }) + // return + //} c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/channel-test.go b/controller/channel-test.go index b498f4f1..ddbe0b4a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,19 +5,24 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/monitor" + relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "net/http/httptest" "net/url" "strconv" + "strings" "sync" "time" @@ -26,7 +31,7 @@ import ( func buildTestRequest() *relaymodel.GeneralOpenAIRequest { testRequest := &relaymodel.GeneralOpenAIRequest{ - MaxTokens: 1, + MaxTokens: 2, Stream: false, Model: "gpt-3.5-turbo", } @@ -51,18 +56,25 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) - meta := util.GetRelayMeta(c) - apiType := constant.ChannelType2APIType(channel.Type) - adaptor := helper.GetAdaptor(apiType) + middleware.SetupContextForSelectedChannel(c, channel, "") + meta := meta.GetByContext(c) + apiType := channeltype.ToAPIType(channel.Type) + adaptor := relay.GetAdaptor(apiType) if adaptor == nil { return 
fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } adaptor.Init(meta) modelName := adaptor.GetModelList()[0] + if !strings.Contains(channel.Models, modelName) { + modelNames := strings.Split(channel.Models, ",") + if len(modelNames) > 0 { + modelName = modelNames[0] + } + } request := buildTestRequest() request.Model = modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName - convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil } @@ -77,7 +89,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error return err, nil } if resp.StatusCode != http.StatusOK { - err := util.RelayErrorHandler(resp) + err := controller.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } usage, respErr := adaptor.DoResponse(c, resp, meta) @@ -139,33 +151,7 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -func notifyRootUser(subject string, content string) { - if config.RootUserEmail == "" { - config.RootUserEmail = model.GetRootUserEmail() - } - err := common.SendEmail(subject, config.RootUserEmail, content) - if err != nil { - logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } -} - -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) -} - -// enable & notify -func enableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) - subject := 
fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) -} - -func testAllChannels(notify bool) error { +func testChannels(notify bool, scope string) error { if config.RootUserEmail == "" { config.RootUserEmail = model.GetRootUserEmail() } @@ -176,7 +162,7 @@ func testAllChannels(notify bool) error { } testAllChannelsRunning = true testAllChannelsLock.Unlock() - channels, err := model.GetAllChannels(0, 0, true) + channels, err := model.GetAllChannels(0, 0, scope) if err != nil { return err } @@ -186,20 +172,24 @@ func testAllChannels(notify bool) error { } go func() { for _, channel := range channels { - isChannelEnabled := channel.Status == common.ChannelStatusEnabled + isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() err, openaiErr := testChannel(channel) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - disableChannel(channel.Id, channel.Name, err.Error()) + if config.AutomaticDisableChannelEnabled { + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) + } else { + _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) + } } - if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { - disableChannel(channel.Id, channel.Name, err.Error()) + if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) { + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } - if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { - enableChannel(channel.Id, channel.Name) + if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) { + monitor.EnableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) 
time.Sleep(config.RequestInterval) @@ -208,7 +198,7 @@ func testAllChannels(notify bool) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常") if err != nil { logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } @@ -217,8 +207,12 @@ func testAllChannels(notify bool) error { return nil } -func TestAllChannels(c *gin.Context) { - err := testAllChannels(true) +func TestChannels(c *gin.Context) { + scope := c.Query("scope") + if scope == "" { + scope = "all" + } + err := testChannels(true, scope) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -237,7 +231,7 @@ func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) logger.SysLog("testing all channels") - _ = testAllChannels(false) + _ = testChannels(false, "all") logger.SysLog("channel test finished") } } diff --git a/controller/channel.go b/controller/channel.go index bdfa00d9..37bfb99d 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/group.go b/controller/group.go index 128a3527..6f02394f 100644 --- a/controller/group.go +++ b/controller/group.go @@ -2,13 +2,13 @@ package controller import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "net/http" ) func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName := range common.GroupRatio { + for groupName := range 
billingratio.GroupRatio { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ diff --git a/controller/misc.go b/controller/misc.go index 036bdbd1..2928b8fb 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -22,6 +23,7 @@ func GetStatus(c *gin.Context) { "email_verification": config.EmailVerificationEnabled, "github_oauth": config.GitHubOAuthEnabled, "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, "system_name": config.SystemName, "logo": config.Logo, "footer_html": config.Footer, @@ -110,7 +112,7 @@ func SendEmailVerification(c *gin.Context) { content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ "

您的验证码为: %s

"+ "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, code, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := message.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -149,7 +151,7 @@ func SendPasswordResetEmail(c *gin.Context) { "

点击 此处 进行密码重置。

"+ "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, link, link, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := message.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/model.go b/controller/model.go index f5760901..77e2e94e 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,11 +3,15 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/moonshot" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/model" + relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" + "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -37,8 +41,9 @@ type OpenAIModels struct { Parent *string `json:"parent"` } -var openAIModels []OpenAIModels -var openAIModelsMap map[string]OpenAIModels +var models []OpenAIModels +var modelsMap map[string]OpenAIModels +var channelId2Models map[int][]string func init() { var permission []OpenAIModelPermission @@ -57,15 +62,15 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < apitype.Dummy; i++ { + if i == apitype.AIProxyLibrary { continue } - adaptor := helper.GetAdaptor(i) + adaptor := relay.GetAdaptor(i) channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { - openAIModels = append(openAIModels, OpenAIModels{ + models = append(models, 
OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -76,44 +81,95 @@ func init() { }) } } - for _, modelName := range ai360.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "360", - Permission: permission, - Root: modelName, - Parent: nil, - }) + for _, channelType := range openai.CompatibleChannels { + if channelType == channeltype.Azure { + continue + } + channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) + for _, modelName := range channelModelList { + models = append(models, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, + Permission: permission, + Root: modelName, + Parent: nil, + }) + } } - for _, modelName := range moonshot.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "moonshot", - Permission: permission, - Root: modelName, - Parent: nil, - }) + modelsMap = make(map[string]OpenAIModels) + for _, model := range models { + modelsMap[model.Id] = model } - openAIModelsMap = make(map[string]OpenAIModels) - for _, model := range openAIModels { - openAIModelsMap[model.Id] = model + channelId2Models = make(map[int][]string) + for i := 1; i < channeltype.Dummy; i++ { + adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) + meta := &meta.Meta{ + ChannelType: i, + } + adaptor.Init(meta) + channelId2Models[i] = adaptor.GetModelList() } } -func ListModels(c *gin.Context) { +func DashboardListModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channelId2Models, + }) +} + +func ListAllModels(c *gin.Context) { c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": models, + }) +} + +func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString("available_models") != "" { + availableModels = 
strings.Split(c.GetString("available_models"), ",") + } else { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + availableOpenAIModels := make([]OpenAIModels, 0) + for _, model := range models { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } + c.JSON(200, gin.H{ + "object": "list", + "data": availableOpenAIModels, }) } func RetrieveModel(c *gin.Context) { modelId := c.Param("model") - if model, ok := openAIModelsMap[modelId]; ok { + if model, ok := modelsMap[modelId]; ok { c.JSON(200, model) } else { Error := relaymodel.Error{ @@ -127,3 +183,30 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetUserAvailableModels(c *gin.Context) { + ctx := c.Request.Context() + id := c.GetInt("id") + userGroup, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + models, err := model.CacheGetGroupModels(ctx, userGroup) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) + return +} diff --git a/controller/redemption.go b/controller/redemption.go index 31c9348d..8d2b3f38 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" 
"github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -106,7 +107,7 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := helper.GetUUID() + key := random.GetUUID() cleanRedemption := model.Redemption{ UserId: c.GetInt("id"), Name: redemption.Name, diff --git a/controller/relay.go b/controller/relay.go index 499e8ddc..56359a1c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -11,26 +11,26 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" ) // https://platform.openai.com/docs/api-reference/chat -func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { +func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) - case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: fallthrough - case constant.RelayModeAudioTranslation: + case relaymode.AudioTranslation: fallthrough - case constant.RelayModeAudioTranscription: + case relaymode.AudioTranscription: err = controller.RelayAudioHelper(c, relayMode) default: err = controller.RelayTextHelper(c) @@ -40,12 +40,17 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() - relayMode := constant.Path2RelayMode(c.Request.URL.Path) - bizErr := relay(c, relayMode) 
- if bizErr == nil { - return + relayMode := relaymode.GetByPath(c.Request.URL.Path) + if config.DebugEnabled { + requestBody, _ := common.GetRequestBody(c) + logger.Debugf(ctx, "request body: %s", string(requestBody)) } channelId := c.GetInt("channel_id") + bizErr := relayHelper(c, relayMode) + if bizErr == nil { + monitor.Emit(channelId, true) + return + } lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") @@ -58,7 +63,7 @@ func Relay(c *gin.Context) { retryTimes = 0 } for i := retryTimes; i > 0; i-- { - channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) + channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) break @@ -70,7 +75,7 @@ func Relay(c *gin.Context) { middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - bizErr = relay(c, relayMode) + bizErr = relayHelper(c, relayMode) if bizErr == nil { return } @@ -112,8 +117,10 @@ func shouldRetry(c *gin.Context, statusCode int) bool { func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors - if util.ShouldDisableChannel(&err.Error, err.StatusCode) { - disableChannel(channelId, channelName, err.Message) + if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { + monitor.DisableChannel(channelId, channelName, err.Message) + } else { + monitor.Emit(channelId, false) } } diff --git a/controller/token.go b/controller/token.go index de0e65eb..557b5ce1 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,10 +1,12 @@ package controller import ( + "fmt" "github.com/gin-gonic/gin" - 
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/network" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -16,7 +18,10 @@ func GetAllTokens(c *gin.Context) { if p < 0 { p = 0 } - tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) + + order := c.Query("order") + tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -101,6 +106,19 @@ func GetTokenStatus(c *gin.Context) { }) } +func validateToken(c *gin.Context, token model.Token) error { + if len(token.Name) > 30 { + return fmt.Errorf("令牌名称过长") + } + if token.Subnet != nil && *token.Subnet != "" { + err := network.IsValidSubnets(*token.Subnet) + if err != nil { + return fmt.Errorf("无效的网段:%s", err.Error()) + } + } + return nil +} + func AddToken(c *gin.Context) { token := model.Token{} err := c.ShouldBindJSON(&token) @@ -111,22 +129,26 @@ func AddToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } + cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, - Key: helper.GenerateKey(), + Key: random.GenerateKey(), CreatedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, + Subnet: token.Subnet, } err = cleanToken.Insert() if err != nil { @@ -139,6 +161,7 @@ func AddToken(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", + "data": cleanToken, }) return } @@ -173,10 +196,11 @@ func UpdateToken(c *gin.Context) { }) 
return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } @@ -188,15 +212,15 @@ func UpdateToken(c *gin.Context) { }) return } - if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if token.Status == model.TokenStatusEnabled { + if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", @@ -212,6 +236,8 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models + cleanToken.Subnet = token.Subnet } err = cleanToken.Update() if err != nil { diff --git a/controller/user.go b/controller/user.go index 7cb1a8aa..ba827feb 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -58,11 +58,11 @@ func Login(c *gin.Context) { }) return } - setupLogin(&user, c) + SetupLogin(&user, c) } // setup session & cookies and then return user info -func setupLogin(user 
*model.User, c *gin.Context) { +func SetupLogin(user *model.User, c *gin.Context) { session := sessions.Default(c) session.Set("id", user.Id) session.Set("username", user.Username) @@ -196,7 +196,10 @@ func GetAllUsers(c *gin.Context) { if p < 0 { p = 0 } - users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) + + order := c.DefaultQuery("order", "") + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -204,12 +207,12 @@ func GetAllUsers(c *gin.Context) { }) return } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": users, }) - return } func SearchUsers(c *gin.Context) { @@ -248,7 +251,7 @@ func GetUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权获取同级或更高等级用户的信息", @@ -296,7 +299,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = helper.GetUUID() + user.AccessToken = random.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -333,7 +336,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = helper.GetRandomString(4) + user.AffCode = random.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -397,14 +400,14 @@ func UpdateUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= originUser.Role && myRole != common.RoleRootUser { + if myRole <= originUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", }) return } - if myRole <= updatedUser.Role && myRole != common.RoleRootUser { + if myRole <= updatedUser.Role && myRole != model.RoleRootUser { 
c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", @@ -518,7 +521,7 @@ func DeleteSelf(c *gin.Context) { id := c.GetInt("id") user, _ := model.GetUserById(id, false) - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不能删除超级管理员账户", @@ -620,7 +623,7 @@ func ManageUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", @@ -629,8 +632,8 @@ func ManageUser(c *gin.Context) { } switch req.Action { case "disable": - user.Status = common.UserStatusDisabled - if user.Role == common.RoleRootUser { + user.Status = model.UserStatusDisabled + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法禁用超级管理员用户", @@ -638,9 +641,9 @@ func ManageUser(c *gin.Context) { return } case "enable": - user.Status = common.UserStatusEnabled + user.Status = model.UserStatusEnabled case "delete": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法删除超级管理员用户", @@ -655,37 +658,37 @@ func ManageUser(c *gin.Context) { return } case "promote": - if myRole != common.RoleRootUser { + if myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "普通管理员用户无法提升其他用户为管理员", }) return } - if user.Role >= common.RoleAdminUser { + if user.Role >= model.RoleAdminUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是管理员", }) return } - user.Role = common.RoleAdminUser + user.Role = model.RoleAdminUser case "demote": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法降级超级管理员用户", }) return } - if user.Role == common.RoleCommonUser { 
+ if user.Role == model.RoleCommonUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是普通用户", }) return } - user.Role = common.RoleCommonUser + user.Role = model.RoleCommonUser } if err := user.Update(false); err != nil { @@ -739,7 +742,7 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ @@ -779,3 +782,38 @@ func TopUp(c *gin.Context) { }) return } + +type adminTopUpRequest struct { + UserId int `json:"user_id"` + Quota int `json:"quota"` + Remark string `json:"remark"` +} + +func AdminTopUp(c *gin.Context) { + req := adminTopUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if req.Remark == "" { + req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) + } + model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docker-compose.yml b/docker-compose.yml index 30edb281..1325a818 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.4' services: one-api: - image: justsong/one-api:latest + image: "${REGISTRY:-docker.io}/justsong/one-api:latest" container_name: one-api restart: always command: --log-dir /app/logs @@ -29,12 +29,12 @@ services: retries: 3 redis: - image: redis:latest + image: "${REGISTRY:-docker.io}/redis:latest" container_name: redis restart: always db: - image: mysql:8.2.0 + image: "${REGISTRY:-docker.io}/mysql:8.2.0" restart: always container_name: mysql volumes: diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 00000000..0b7ddf5a --- /dev/null +++ 
b/docs/API.md @@ -0,0 +1,53 @@ +# 使用 API 操控 & 扩展 One API +> 欢迎提交 PR 在此放上你的拓展项目。 + +例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 + +又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 + +## 鉴权 +One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: + +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c) + +之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) + +## 请求格式与响应格式 +One API 使用 JSON 格式进行请求和响应。 + +对于响应体,一般格式如下: +```json +{ + "message": "请求信息", + "success": true, + "data": {} +} +``` + +## API 列表 +> 当前 API 列表不全,请自行通过浏览器抓取前端请求 + +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 + +### 获取当前登录用户信息 +**GET** `/api/user/self` + +### 为给定用户充值额度 +**POST** `/api/topup` +```json +{ + "user_id": 1, + "quota": 100000, + "remark": "充值 100000 额度" +} +``` + +## 其他 +### 充值链接上的附加参数 +One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: +`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` + +你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 + +注意,不是所有主题都支持该功能,欢迎 PR 补齐。 \ No newline at end of file diff --git a/go.mod b/go.mod index 4ab23003..6ace51f2 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.17.0 golang.org/x/image v0.14.0 @@ -37,15 +38,18 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect 
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.3.1 // indirect + github.com/jackc/pgx/v5 v5.5.4 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect @@ -54,12 +58,14 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 21bcddc6..3ead2711 100644 --- a/go.sum +++ b/go.sum @@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/gofuzz v1.0.0/go.mod 
h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -73,8 +75,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= -github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= +github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -83,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/ github.com/json-iterator/go v1.1.9/go.mod 
h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -125,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -157,6 +167,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod 
h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -173,12 +185,12 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 
v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/i18n/en.json b/i18n/en.json index 54728e2f..b7f1bd3e 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -8,12 +8,12 @@ "确认删除": "Confirm Delete", "确认绑定": "Confirm Binding", "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", - "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", - "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", + "\"渠道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", + "渠道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", "测试已在运行中": "Test is already running", "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", - "通道测试完成": "Channel test completed", - "通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", + "渠道测试完成": "Channel test completed", + "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", @@ -119,11 +119,11 @@ " 个月 ": " M ", " 年 ": " y ", "未测试": "Not tested", - "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", - "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", - "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", - "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", - "已更新完毕所有已启用通道余额!": "The balance of all enabled 
channels has been updated!", + "渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", + "已成功开始测试所有渠道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", + "已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", + "渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", + "已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!", "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", "名称": "Name", "分组": "Group", @@ -141,9 +141,9 @@ "启用": "Enable", "编辑": "Edit", "添加新的渠道": "Add a new channel", - "测试所有通道": "Test all channels", - "测试所有已启用通道": "Test all enabled channels", - "更新所有已启用通道余额": "Update the balance of all enabled channels", + "测试所有渠道": "Test all channels", + "测试所有已启用渠道": "Test all enabled channels", + "更新所有已启用渠道余额": "Update the balance of all enabled channels", "刷新": "Refresh", "处理中...": "Processing...", "绑定成功!": "Binding succeeded!", @@ -207,11 +207,11 @@ "监控设置": "Monitoring Settings", "最长响应时间": "Longest Response Time", "单位秒": "Unit in seconds", - "当运行通道全部测试时": "When all operating channels are tested", - "超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded", + "当运行渠道全部测试时": "When all operating channels are tested", + "超过此时间将自动禁用渠道": "Channels will be automatically disabled if this time is exceeded", "额度提醒阈值": "Quota reminder threshold", "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", - "失败时自动禁用通道": "Automatically disable the channel when it fails", + "失败时自动禁用渠道": "Automatically disable the channel when it fails", "保存监控设置": "Save Monitoring Settings", "额度设置": "Quota Settings", "新用户初始额度": "Initial quota for new users", @@ -405,7 +405,7 @@ "镜像": "Mirror", "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: 
https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", "模型": "Model", - "请选择该通道所支持的模型": "Please select the model supported by the channel", + "请选择该渠道所支持的模型": "Please select the model supported by the channel", "填入基础模型": "Fill in the basic model", "填入所有模型": "Fill in all models", "清除所有模型": "Clear all models", @@ -515,7 +515,7 @@ "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", "Homepage URL 填": "Fill in the Homepage URL", "Authorization callback URL 填": "Fill in the Authorization callback URL", - "请为通道命名": "Please name the channel", + "请为渠道命名": "Please name the channel", "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", "模型重定向": "Model redirection", "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", diff --git a/main.go b/main.go index 1f43a45f..a0621711 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/router" "os" "strconv" @@ -30,11 +30,25 @@ func main() { if config.DebugEnabled { logger.SysLog("running in debug mode") } + var err error // Initialize SQL Database - err := model.InitDB() + model.DB, err = model.InitDB("SQL_DSN") if err != nil { logger.FatalLog("failed to initialize database: " + err.Error()) } + if os.Getenv("LOG_SQL_DSN") != "" { + logger.SysLog("using secondary database for table logs") + model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + } + } else { + 
model.LOG_DB = model.DB + } + err = model.CreateRootAccountIfNeed() + if err != nil { + logger.FatalLog("database init error: " + err.Error()) + } defer func() { err := model.CloseDB() if err != nil { @@ -64,13 +78,6 @@ func main() { go model.SyncOptions(config.SyncFrequency) go model.SyncChannelCache(config.SyncFrequency) } - if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) - if err != nil { - logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyUpdateChannels(frequency) - } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { @@ -83,6 +90,9 @@ func main() { logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") model.InitBatchUpdater() } + if config.EnableMetric { + logger.SysLog("metric enabled, will disable channel if too much request failed") + } openai.InitTokenEncoders() // Initialize HTTP server diff --git a/middleware/auth.go b/middleware/auth.go index 9d25f395..64ce6608 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,9 +1,11 @@ package middleware import ( + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" + "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -42,11 +44,14 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled { + if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", }) + session := sessions.Default(c) + session.Clear() + _ = session.Save() c.Abort() return } @@ -66,24 +71,25 @@ func authHelper(c *gin.Context, minRole int) { func UserAuth() func(c 
*gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleCommonUser) + authHelper(c, model.RoleCommonUser) } } func AdminAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleAdminUser) + authHelper(c, model.RoleAdminUser) } } func RootAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleRootUser) + authHelper(c, model.RoleRootUser) } } func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { + ctx := c.Request.Context() key := c.Request.Header.Get("Authorization") key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(key, "sk-") @@ -94,15 +100,34 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } + if token.Subnet != nil && *token.Subnet != "" { + if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) + return + } + } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } - if !userEnabled { + if !userEnabled || blacklist.IsUserBanned(token.UserId) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } + requestModel, err := getRequestModel(c) + if err != nil && shouldCheckModel(c) { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set("request_model", requestModel) + if token.Models != nil && *token.Models != "" { + c.Set("available_models", *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) @@ -117,3 +142,19 @@ func TokenAuth() func(c *gin.Context) { c.Next() } } + +func shouldCheckModel(c *gin.Context) bool { + if 
strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + return true + } + return false +} diff --git a/middleware/distributor.go b/middleware/distributor.go index aeb2796a..6e0d2718 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,14 +2,13 @@ package middleware import ( "fmt" - "github.com/songquanpeng/one-api/common" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" "net/http" "strconv" - "strings" - - "github.com/gin-gonic/gin" ) type ModelRequest struct { @@ -35,42 +34,16 @@ func Distribute() func(c *gin.Context) { abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { - // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + requestModel := c.GetString("request_model") + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求") - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" - } - } 
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" - } - } - requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) if channel != nil { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" @@ -94,19 +67,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("base_url", channel.GetBaseURL()) // this is for backward compatibility switch channel.Type { - case common.ChannelTypeAzure: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeXunfei: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeGemini: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeAIProxyLibrary: - c.Set(common.ConfigKeyLibraryID, channel.Other) - case common.ChannelTypeAli: - c.Set(common.ConfigKeyPlugin, channel.Other) + case channeltype.Azure: + c.Set(config.KeyAPIVersion, channel.Other) + case channeltype.Xunfei: + c.Set(config.KeyAPIVersion, channel.Other) + case channeltype.Gemini: + c.Set(config.KeyAPIVersion, channel.Other) + case channeltype.AIProxyLibrary: + c.Set(config.KeyLibraryID, channel.Other) + case channeltype.Ali: + c.Set(config.KeyPlugin, channel.Other) } cfg, _ := channel.LoadConfig() for k, v := range cfg { - c.Set(common.ConfigKeyPrefix+k, v) + c.Set(config.KeyPrefix+k, v) } } diff --git a/middleware/recover.go b/middleware/recover.go index 02e3e3bb..cfc3f827 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -3,6 +3,7 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" 
"github.com/songquanpeng/one-api/common/logger" "net/http" "runtime/debug" @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) - logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + ctx := c.Request.Context() + logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) + logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) + body, _ := common.GetRequestBody(c) + logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ - "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), + "message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), "type": "one_api_panic", }, }) diff --git a/middleware/request-id.go b/middleware/request-id.go index 234a93d8..a4c49ddb 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -9,7 +9,7 @@ import ( func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := helper.GetTimeString() + helper.GetRandomNumberString(8) + id := helper.GenRequestID() c.Set(logger.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) diff --git a/middleware/utils.go b/middleware/utils.go index bc14c367..b65b018b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,9 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "strings" ) func abortWithMessage(c 
*gin.Context, statusCode int, message string) { @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() logger.Error(c.Request.Context(), message) } + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/ability.go b/model/ability.go index 7127abc3..2db72518 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,7 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" + "gorm.io/gorm" + "sort" "strings" ) @@ -13,7 +16,7 @@ type Ability struct { Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { ability := Ability{} groupCol := "`group`" trueVal := "1" @@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, 
model string) (*Channel, error) { } var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) - channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + var channelQuery *gorm.DB + if ignoreFirstPriority { + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + } else { + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + } if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { @@ -49,7 +57,7 @@ func (channel *Channel) AddAbilities() error { Group: group, Model: model, ChannelId: channel.Id, - Enabled: channel.Status == common.ChannelStatusEnabled, + Enabled: channel.Status == ChannelStatusEnabled, Priority: channel.Priority, } abilities = append(abilities, ability) @@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? 
and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/cache.go b/model/cache.go index 04a60348..cfb0f8a4 100644 --- a/model/cache.go +++ b/model/cache.go @@ -1,12 +1,14 @@ package model import ( + "context" "encoding/json" "errors" "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "math/rand" "sort" "strconv" @@ -20,6 +22,7 @@ var ( UserId2GroupCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -70,31 +73,42 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } -func CacheGetUserQuota(id int) (quota int, err error) { +func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + if err != nil { + logger.Error(ctx, "Redis set user quota error: "+err.Error()) + } + return +} + +func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { if !common.RedisEnabled { return GetUserQuota(id) } quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) if err != nil { - quota, err = GetUserQuota(id) - if err != nil { - return 0, err - } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - if err != nil { - logger.SysError("Redis set user quota error: " + err.Error()) - } - return quota, err + return fetchAndUpdateUserQuota(ctx, id) } - quota, err = 
strconv.Atoi(quotaString) - return quota, err + quota, err = strconv.ParseInt(quotaString, 10, 64) + if err != nil { + return 0, nil + } + if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db + logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) + return fetchAndUpdateUserQuota(ctx, id) + } + return quota, nil } -func CacheUpdateUserQuota(id int) error { +func CacheUpdateUserQuota(ctx context.Context, id int) error { if !common.RedisEnabled { return nil } - quota, err := CacheGetUserQuota(id) + quota, err := CacheGetUserQuota(ctx, id) if err != nil { return err } @@ -102,7 +116,7 @@ func CacheUpdateUserQuota(id int) error { return err } -func CacheDecreaseUserQuota(id int, quota int) error { +func CacheDecreaseUserQuota(id int, quota int64) error { if !common.RedisEnabled { return nil } @@ -134,13 +148,32 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex func InitChannelCache() { newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) for _, channel := range channels { 
newChannelId2channel[channel.Id] = channel } @@ -191,9 +224,9 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { if !config.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority) } channelSyncLock.RLock() defer channelSyncLock.RUnlock() @@ -213,5 +246,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error } } idx := rand.Intn(endIdx) + if ignoreFirstPriority { + if endIdx < len(channels) { // which means there are more than one priority + idx = random.RandRange(endIdx, len(channels)) + } + } return channels[idx], nil } diff --git a/model/channel.go b/model/channel.go index 19af2263..e667f7e7 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,17 +3,23 @@ package model import ( "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" ) +const ( + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! 
+ ChannelStatusManuallyDisabled = 2 // also don't use 0 + ChannelStatusAutoDisabled = 3 +) + type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` - Key string `json:"key" gorm:"not null;index"` + Key string `json:"key" gorm:"type:text"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` @@ -32,23 +38,22 @@ type Channel struct { Config string `json:"config"` } -func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { +func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { var channels []*Channel var err error - if selectAll { + switch scope { + case "all": err = DB.Order("id desc").Find(&channels).Error - } else { + case "disabled": + err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error + default: err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } return channels, err } func SearchChannels(keyword string) (channels []*Channel, err error) { - keyCol := "`key`" - if common.UsingPostgreSQL { - keyCol = `"key"` - } - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? 
or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error return channels, err } @@ -169,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) { } func UpdateChannelStatusById(id int, status int) { - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) if err != nil { logger.SysError("failed to update ability status: " + err.Error()) } @@ -179,7 +184,7 @@ func UpdateChannelStatusById(id int, status int) { } } -func UpdateChannelUsedQuota(id int, quota int) { +func UpdateChannelUsedQuota(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return @@ -187,7 +192,7 @@ func UpdateChannelUsedQuota(id int, quota int) { updateChannelUsedQuota(id, quota) } -func updateChannelUsedQuota(id int, quota int) { +func updateChannelUsedQuota(id int, quota int64) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { logger.SysError("failed to update channel used quota: " + err.Error()) @@ -200,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) { } func DeleteDisabledChannel() (int64, error) { - result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + result := DB.Where("status = ? 
or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } diff --git a/model/log.go b/model/log.go index 9615c237..6fba776a 100644 --- a/model/log.go +++ b/model/log.go @@ -7,7 +7,6 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "gorm.io/gorm" ) @@ -45,13 +44,28 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.SysError("failed to record log: " + err.Error()) } } -func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordTopupLog(userId int, content string, quota int) { + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: LogTypeTopup, + Content: content, + Quota: quota, + } + err := LOG_DB.Create(log).Error + if err != nil { + logger.SysError("failed to record log: " + err.Error()) + } +} + +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { return @@ -66,10 +80,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke CompletionTokens: completionTokens, TokenName: tokenName, ModelName: modelName, - Quota: quota, + Quota: int(quota), ChannelId: channelId, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != 
nil { logger.Error(ctx, "failed to record log: "+err.Error()) } @@ -78,9 +92,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB + tx = LOG_DB } else { - tx = DB.Where("type = ?", logType) + tx = LOG_DB.Where("type = ?", logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -107,9 +121,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB.Where("user_id = ?", userId) + tx = LOG_DB.Where("user_id = ?", userId) } else { - tx = DB.Where("user_id = ? and type = ?", userId, logType) + tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -128,17 +142,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error + err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error + err = LOG_DB.Where("user_id = ? 
and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { - tx := DB.Table("logs").Select("ifnull(sum(quota),0)") +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { + tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -162,7 +176,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -183,7 +197,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa } func DeleteOldLog(targetTimestamp int64) (int64, error) { - result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error } @@ -207,7 +221,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" } - err = DB.Raw(` + err = LOG_DB.Raw(` SELECT `+groupSelect+`, model_name, count(1) as request_count, sum(quota) as quota, diff --git a/model/main.go b/model/main.go index 18ed01d0..4b5323c4 100644 --- a/model/main.go +++ b/model/main.go @@ -4,8 +4,10 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" 
"github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -16,12 +18,13 @@ import ( ) var DB *gorm.DB +var LOG_DB *gorm.DB -func createRootAccountIfNeed() error { +func CreateRootAccountIfNeed() error { var user User //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -29,20 +32,36 @@ func createRootAccountIfNeed() error { rootUser := User{ Username: "root", Password: hashedPassword, - Role: common.RoleRootUser, - Status: common.UserStatusEnabled, + Role: RoleRootUser, + Status: UserStatusEnabled, DisplayName: "Root User", - AccessToken: helper.GetUUID(), - Quota: 100000000, + AccessToken: random.GetUUID(), + Quota: 500000000000000, } DB.Create(&rootUser) + if config.InitialRootToken != "" { + logger.SysLog("creating initial root token as requested") + token := Token{ + Id: 1, + UserId: rootUser.Id, + Key: config.InitialRootToken, + Status: TokenStatusEnabled, + Name: "Initial Root Token", + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), + ExpiredTime: -1, + RemainQuota: 500000000000000, + UnlimitedQuota: true, + } + DB.Create(&token) + } } return nil } -func chooseDB() (*gorm.DB, error) { - if os.Getenv("SQL_DSN") != "" { - dsn := os.Getenv("SQL_DSN") +func chooseDB(envName string) (*gorm.DB, error) { + if os.Getenv(envName) != "" { + dsn := os.Getenv(envName) if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL logger.SysLog("using PostgreSQL as 
database") @@ -56,6 +75,7 @@ func chooseDB() (*gorm.DB, error) { } // Use MySQL logger.SysLog("using MySQL as database") + common.UsingMySQL = true return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) @@ -69,67 +89,78 @@ func chooseDB() (*gorm.DB, error) { }) } -func InitDB() (err error) { - db, err := chooseDB() +func InitDB(envName string) (db *gorm.DB, err error) { + db, err = chooseDB(envName) if err == nil { - if config.DebugEnabled { + if config.DebugSQLEnabled { db = db.Debug() } - DB = db - sqlDB, err := DB.DB() + sqlDB, err := db.DB() if err != nil { - return err + return nil, err } - sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) if !config.IsMasterNode { - return nil + return db, err + } + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded } logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Token{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&User{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Option{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Redemption{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Ability{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Log{}) if err != nil { - return err + return nil, err } logger.SysLog("database migrated") - err = createRootAccountIfNeed() - 
return err + return db, err } else { logger.FatalLog(err) } - return err + return db, err } -func CloseDB() error { - sqlDB, err := DB.DB() +func closeDB(db *gorm.DB) error { + sqlDB, err := db.DB() if err != nil { return err } err = sqlDB.Close() return err } + +func CloseDB() error { + if LOG_DB != DB { + err := closeDB(LOG_DB) + if err != nil { + return err + } + } + return closeDB(DB) +} diff --git a/model/option.go b/model/option.go index 6002c795..bed8d4c3 100644 --- a/model/option.go +++ b/model/option.go @@ -1,9 +1,9 @@ package model import ( - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "strconv" "strings" "time" @@ -57,16 +57,18 @@ func InitOptionMap() { config.OptionMap["WeChatServerAddress"] = "" config.OptionMap["WeChatServerToken"] = "" config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["MessagePusherAddress"] = "" + config.OptionMap["MessagePusherToken"] = "" config.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSecretKey"] = "" - config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) - config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) - config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) - config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) - config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) - config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() + config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) + config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) + config.OptionMap["QuotaForInvitee"] = 
strconv.FormatInt(config.QuotaForInvitee, 10) + config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) + config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) + config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() + config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() config.OptionMap["TopUpLink"] = config.TopUpLink config.OptionMap["ChatLink"] = config.ChatLink config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) @@ -79,6 +81,9 @@ func InitOptionMap() { func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { + if option.Key == "ModelRatio" { + option.Value = billingratio.AddNewMissingRatio(option.Value) + } err := updateOptionMap(option.Key, option.Value) if err != nil { logger.SysError("failed to update option map: " + err.Error()) @@ -167,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { config.GitHubClientId = value case "GitHubClientSecret": config.GitHubClientSecret = value + case "LarkClientId": + config.LarkClientId = value + case "LarkClientSecret": + config.LarkClientSecret = value case "Footer": config.Footer = value case "SystemName": @@ -179,28 +188,32 @@ func updateOptionMap(key string, value string) (err error) { config.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": config.WeChatAccountQRCodeImageURL = value + case "MessagePusherAddress": + config.MessagePusherAddress = value + case "MessagePusherToken": + config.MessagePusherToken = value case "TurnstileSiteKey": config.TurnstileSiteKey = value case "TurnstileSecretKey": config.TurnstileSecretKey = value case "QuotaForNewUser": - config.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInviter": - config.QuotaForInviter, _ = 
strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInvitee": - config.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) case "QuotaRemindThreshold": - config.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) case "PreConsumedQuota": - config.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": - err = common.UpdateModelRatioByJSONString(value) + err = billingratio.UpdateModelRatioByJSONString(value) case "GroupRatio": - err = common.UpdateGroupRatioByJSONString(value) + err = billingratio.UpdateGroupRatioByJSONString(value) case "CompletionRatio": - err = common.UpdateCompletionRatioByJSONString(value) + err = billingratio.UpdateCompletionRatioByJSONString(value) case "TopUpLink": config.TopUpLink = value case "ChatLink": diff --git a/model/redemption.go b/model/redemption.go index 2c5a4141..45871a71 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -8,13 +8,19 @@ import ( "gorm.io/gorm" ) +const ( + RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! 
+ RedemptionCodeStatusDisabled = 2 // also don't use 0 + RedemptionCodeStatusUsed = 3 // also don't use 0 +) + type Redemption struct { Id int `json:"id"` UserId int `json:"user_id"` Key string `json:"key" gorm:"type:char(32);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` - Quota int `json:"quota" gorm:"default:100"` + Quota int64 `json:"quota" gorm:"bigint;default:100"` CreatedTime int64 `json:"created_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` Count int `json:"count" gorm:"-:all"` // only for api request @@ -42,7 +48,7 @@ func GetRedemptionById(id int) (*Redemption, error) { return &redemption, err } -func Redeem(key string, userId int) (quota int, err error) { +func Redeem(key string, userId int) (quota int64, err error) { if key == "" { return 0, errors.New("未提供兑换码") } @@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return errors.New("无效的兑换码") } - if redemption.Status != common.RedemptionCodeStatusEnabled { + if redemption.Status != RedemptionCodeStatusEnabled { return errors.New("该兑换码已被使用") } err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error @@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int, err error) { return err } redemption.RedeemedTime = helper.GetTimestamp() - redemption.Status = common.RedemptionCodeStatusUsed + redemption.Status = RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err }) diff --git a/model/token.go b/model/token.go index d0a0648a..96e6b491 100644 --- a/model/token.go +++ b/model/token.go @@ -7,27 +7,48 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" "gorm.io/gorm" ) +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! 
+ TokenStatusDisabled = 2 // also don't use 0 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + type Token struct { - Id int `json:"id"` - UserId int `json:"user_id"` - Key string `json:"key" gorm:"type:char(48);uniqueIndex"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index" ` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - AccessedTime int64 `json:"accessed_time" gorm:"bigint"` - ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int `json:"remain_quota" gorm:"default:0"` - UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"default:''"` // allowed models + Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } -func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { +func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { var tokens []*Token var err error - err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error + query := DB.Where("user_id = ?", userId) + + switch order { + case "remain_quota": + query = query.Order("unlimited_quota desc, remain_quota desc") + case "used_quota": + query = query.Order("used_quota desc") + 
default: + query = query.Order("id desc") + } + + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error return tokens, err } @@ -48,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) { } return nil, errors.New("令牌验证失败") } - if token.Status == common.TokenStatusExhausted { - return nil, errors.New("该令牌额度已用尽") - } else if token.Status == common.TokenStatusExpired { + if token.Status == TokenStatusExhausted { + return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) + } else if token.Status == TokenStatusExpired { return nil, errors.New("该令牌已过期") } - if token.Status != common.TokenStatusEnabled { + if token.Status != TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if !common.RedisEnabled { - token.Status = common.TokenStatusExpired + token.Status = TokenStatusExpired err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -69,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { // in this case, we can make sure the token is exhausted - token.Status = common.TokenStatusExhausted + token.Status = TokenStatusExhausted err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -109,7 +130,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error return err } @@ -137,7 +158,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func 
IncreaseTokenQuota(id int, quota int) (err error) { +func IncreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -148,7 +169,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { return increaseTokenQuota(id, quota) } -func increaseTokenQuota(id int, quota int) (err error) { +func increaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -159,7 +180,7 @@ func increaseTokenQuota(id int, quota int) (err error) { return err } -func DecreaseTokenQuota(id int, quota int) (err error) { +func DecreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -170,7 +191,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { return decreaseTokenQuota(id, quota) } -func decreaseTokenQuota(id int, quota int) (err error) { +func decreaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), @@ -181,7 +202,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { return err } -func PreConsumeTokenQuota(tokenId int, quota int) (err error) { +func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -213,7 +234,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { } if email != "" { topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) - err = common.SendEmail(prompt, email, + err = message.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { logger.SysError("failed to send email" + err.Error()) @@ -231,7 +252,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { return err } -func PostConsumeTokenQuota(tokenId int, quota int) (err error) { +func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) diff --git a/model/user.go b/model/user.go index 6979c70b..1dc633b1 100644 --- a/model/user.go +++ b/model/user.go @@ -4,13 +4,27 @@ import ( "errors" "fmt" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "gorm.io/gorm" "strings" ) +const ( + RoleGuestUser = 0 + RoleCommonUser = 1 + RoleAdminUser = 10 + RoleRootUser = 100 +) + +const ( + UserStatusEnabled = 1 // don't use 0, 0 is the default value! + UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 +) + // User if you add sensitive fields, don't forget to clean them in setupLogin function. // Otherwise, the sensitive information will be saved on local storage in plain text! type User struct { @@ -23,11 +37,12 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + LarkId string `json:"lark_id" gorm:"column:lark_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! 
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management - Quota int `json:"quota" gorm:"type:int;default:0"` - UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota - RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number + Quota int64 `json:"quota" gorm:"bigint;default:0"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota + RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number Group string `json:"group" gorm:"type:varchar(32);default:'default'"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` @@ -39,8 +54,21 @@ func GetMaxUserId() int { return user.Id } -func GetAllUsers(startIdx int, num int) (users []*User, err error) { - err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error +func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) + + switch order { + case "quota": + query = query.Order("quota desc") + case "used_quota": + query = query.Order("used_quota desc") + case "request_count": + query = query.Order("request_count desc") + default: + query = query.Order("id desc") + } + + err = query.Find(&users).Error return users, err } @@ -93,8 +121,8 @@ func (user *User) Insert(inviterId int) error { } } user.Quota = config.QuotaForNewUser - user.AccessToken = helper.GetUUID() - user.AffCode = helper.GetRandomString(4) + user.AccessToken = random.GetUUID() + user.AffCode = random.GetRandomString(4) result := DB.Create(user) if result.Error != nil { return result.Error @@ -123,6 +151,11 @@ func (user *User) Update(updatePassword bool) error { return err } } + if user.Status == 
UserStatusDisabled { + blacklist.BanUser(user.Id) + } else if user.Status == UserStatusEnabled { + blacklist.UnbanUser(user.Id) + } err = DB.Model(user).Updates(user).Error return err } @@ -131,7 +164,10 @@ func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") } - err := DB.Delete(user).Error + blacklist.BanUser(user.Id) + user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) + user.Status = UserStatusDeleted + err := DB.Model(user).Updates(user).Error return err } @@ -154,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) { } } okay := common.ValidatePasswordAndHash(password, user.Password) - if !okay || user.Status != common.UserStatusEnabled { + if !okay || user.Status != UserStatusEnabled { return errors.New("用户名或密码错误,或用户已被封禁") } return nil @@ -184,6 +220,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByLarkId() error { + if user.LarkId == "" { + return errors.New("lark id 为空!") + } + DB.Where(User{LarkId: user.LarkId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -212,6 +256,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsLarkIdAlreadyTaken(githubId string) bool { + return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } @@ -238,7 +286,7 @@ func IsAdmin(userId int) bool { logger.SysError("no such user " + err.Error()) return false } - return user.Role >= common.RoleAdminUser + return user.Role >= RoleAdminUser } func IsUserEnabled(userId int) (bool, error) { @@ -250,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) { if err != nil { return false, err } - return user.Status == common.UserStatusEnabled, nil + return user.Status == 
UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -265,12 +313,12 @@ func ValidateAccessToken(token string) (user *User) { return nil } -func GetUserQuota(id int) (quota int, err error) { +func GetUserQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error return quota, err } -func GetUserUsedQuota(id int) (quota int, err error) { +func GetUserUsedQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error return quota, err } @@ -290,7 +338,7 @@ func GetUserGroup(id int) (group string, err error) { return group, err } -func IncreaseUserQuota(id int, quota int) (err error) { +func IncreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -301,12 +349,12 @@ return increaseUserQuota(id, quota) } -func increaseUserQuota(id int, quota int) (err error) { +func increaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } -func DecreaseUserQuota(id int, quota int) (err error) { +func DecreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -317,17 +365,17 @@ return decreaseUserQuota(id, quota) } -func decreaseUserQuota(id int, quota int) (err error) { +func decreaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) return email } -func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { +func 
UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) @@ -336,7 +384,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { updateUserUsedQuotaAndRequestCount(id, quota, 1) } -func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { +func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), @@ -348,7 +396,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { } } -func updateUserUsedQuota(id int, quota int) { +func updateUserUsedQuota(id int, quota int64) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), diff --git a/model/utils.go b/model/utils.go index d481973a..a55eb4b6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -16,12 +16,12 @@ const ( BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock ) -var batchUpdateStores []map[int]int +var batchUpdateStores []map[int]int64 var batchUpdateLocks []sync.Mutex func init() { for i := 0; i < BatchUpdateTypeCount; i++ { - batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) } } @@ -35,7 +35,7 @@ func InitBatchUpdater() { }() } -func addNewRecord(type_ int, id int, value int) { +func addNewRecord(type_ int, id int, value int64) { batchUpdateLocks[type_].Lock() defer batchUpdateLocks[type_].Unlock() if _, ok := batchUpdateStores[type_][id]; !ok { @@ -50,7 +50,7 @@ func batchUpdate() { for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] - batchUpdateStores[i] = make(map[int]int) + batchUpdateStores[i] = 
make(map[int]int64) batchUpdateLocks[i].Unlock() // TODO: maybe we can combine updates with same key? for key, value := range store { @@ -68,7 +68,7 @@ func batchUpdate() { case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) case BatchUpdateTypeRequestCount: - updateUserRequestCount(key, value) + updateUserRequestCount(key, int(value)) case BatchUpdateTypeChannelUsedQuota: updateChannelUsedQuota(key, value) } diff --git a/monitor/channel.go b/monitor/channel.go new file mode 100644 index 00000000..7e5dc58a --- /dev/null +++ b/monitor/channel.go @@ -0,0 +1,54 @@ +package monitor + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/model" +) + +func notifyRootUser(subject string, content string) { + if config.MessagePusherAddress != "" { + err := message.SendMessage(subject, content, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) + } else { + return + } + } + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() + } + err := message.SendEmail(subject, config.RootUserEmail, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +} + +// DisableChannel disable & notify +func DisableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) + subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +func MetricDisableChannel(channelId int, successRate float64) { + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled 
due to low success rate: %.2f", channelId, successRate*100)) + subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) + content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", + channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) + notifyRootUser(subject, content) +} + +// EnableChannel enable & notify +func EnableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) + logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) + subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) + notifyRootUser(subject, content) +} diff --git a/monitor/manage.go b/monitor/manage.go new file mode 100644 index 00000000..946e78af --- /dev/null +++ b/monitor/manage.go @@ -0,0 +1,62 @@ +package monitor + +import ( + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" + "net/http" + "strings" +) + +func ShouldDisableChannel(err *model.Error, statusCode int) bool { + if !config.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { + return true + } + //if strings.Contains(err.Message, "quota") { + // return true + //} + if strings.Contains(err.Message, "credit") { + return true + } + if strings.Contains(err.Message, "balance") { + 
return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *model.Error) bool { + if !config.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} diff --git a/monitor/metric.go b/monitor/metric.go new file mode 100644 index 00000000..98bc546e --- /dev/null +++ b/monitor/metric.go @@ -0,0 +1,79 @@ +package monitor + +import ( + "github.com/songquanpeng/one-api/common/config" +) + +var store = make(map[int][]bool) +var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) +var metricFailChan = make(chan int, config.MetricFailChanSize) + +func consumeSuccess(channelId int) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], true) +} + +func consumeFail(channelId int) (bool, float64) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], false) + successCount := 0 + for _, success := range store[channelId] { + if success { + successCount++ + } + } + successRate := float64(successCount) / float64(len(store[channelId])) + if len(store[channelId]) < config.MetricQueueSize { + return false, successRate + } + if successRate < config.MetricSuccessRateThreshold { + store[channelId] = make([]bool, 0) + return true, successRate + } + return false, successRate +} + +func metricSuccessConsumer() { + for { + select { + case channelId := <-metricSuccessChan: + consumeSuccess(channelId) + } + } +} + +func metricFailConsumer() { + for { + select { + case channelId := <-metricFailChan: + disable, successRate := consumeFail(channelId) + if disable { + go MetricDisableChannel(channelId, successRate) + } + } + } +} + +func init() { + if config.EnableMetric { + go metricSuccessConsumer() + go metricFailConsumer() + } +} + +func Emit(channelId int, success bool) { + if 
!config.EnableMetric { + return + } + go func() { + if success { + metricSuccessChan <- channelId + } else { + metricFailChan <- channelId + } + }() +} diff --git a/pull_request_template.md b/pull_request_template.md index a313004f..c6301343 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,9 +1,10 @@ [//]: # (请按照以下格式关联 issue) -[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) +[//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢) [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) [//]: # (开发者交流群:910657413) [//]: # (请在提交 PR 之前删除上面的注释) close #issue_number -我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file +我已确认该 PR 已自测通过,相关截图如下: +(此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图) diff --git a/relay/adaptor.go b/relay/adaptor.go new file mode 100644 index 00000000..c90bd708 --- /dev/null +++ b/relay/adaptor.go @@ -0,0 +1,45 @@ +package relay + +import ( + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/aiproxy" + "github.com/songquanpeng/one-api/relay/adaptor/ali" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/baidu" + "github.com/songquanpeng/one-api/relay/adaptor/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/ollama" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/palm" + "github.com/songquanpeng/one-api/relay/adaptor/tencent" + "github.com/songquanpeng/one-api/relay/adaptor/xunfei" + "github.com/songquanpeng/one-api/relay/adaptor/zhipu" + "github.com/songquanpeng/one-api/relay/apitype" +) + +func GetAdaptor(apiType int) adaptor.Adaptor { + switch apiType { + case apitype.AIProxyLibrary: + return &aiproxy.Adaptor{} + case apitype.Ali: + return &ali.Adaptor{} + case apitype.Anthropic: + return &anthropic.Adaptor{} + case apitype.Baidu: + return &baidu.Adaptor{} + case apitype.Gemini: + return &gemini.Adaptor{} + case apitype.OpenAI: + return &openai.Adaptor{} + case 
apitype.PaLM: + return &palm.Adaptor{} + case apitype.Tencent: + return &tencent.Adaptor{} + case apitype.Xunfei: + return &xunfei.Adaptor{} + case apitype.Zhipu: + return &zhipu.Adaptor{} + case apitype.Ollama: + return &ollama.Adaptor{} + } + return nil +} diff --git a/relay/channel/ai360/constants.go b/relay/adaptor/ai360/constants.go similarity index 100% rename from relay/channel/ai360/constants.go rename to relay/adaptor/ai360/constants.go diff --git a/relay/channel/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go similarity index 53% rename from relay/channel/aiproxy/adaptor.go rename to relay/adaptor/aiproxy/adaptor.go index 2b4e3022..7ad6225a 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,16 +15,16 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } @@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, 
errors.New("request is nil") } aiProxyLibraryRequest := ConvertRequest(*request) - aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) + aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID) return aiProxyLibraryRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/aiproxy/constants.go b/relay/adaptor/aiproxy/constants.go similarity index 60% rename from relay/channel/aiproxy/constants.go rename to relay/adaptor/aiproxy/constants.go index c4df51c4..818d2709 100644 --- a/relay/channel/aiproxy/constants.go +++ b/relay/adaptor/aiproxy/constants.go @@ -1,6 +1,6 @@ package aiproxy -import "github.com/songquanpeng/one-api/relay/channel/openai" +import "github.com/songquanpeng/one-api/relay/adaptor/openai" var ModelList = []string{""} diff --git a/relay/channel/aiproxy/main.go b/relay/adaptor/aiproxy/main.go similarity index 95% rename from relay/channel/aiproxy/main.go rename to relay/adaptor/aiproxy/main.go index 96972407..01a568f6 100644 --- a/relay/channel/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -8,7 +8,8 @@ import ( "github.com/songquanpeng/one-api/common" 
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -53,7 +54,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon FinishReason: "stop", } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -66,7 +67,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.FinishReason = &constant.StopFinishReason return &openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "", @@ -78,7 +79,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = response.Content return &openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: response.Model, diff --git a/relay/channel/aiproxy/model.go b/relay/adaptor/aiproxy/model.go similarity index 100% rename from relay/channel/aiproxy/model.go rename to relay/adaptor/aiproxy/model.go diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go new file mode 100644 index 00000000..21b5e8b8 --- /dev/null +++ b/relay/adaptor/ali/adaptor.go @@ -0,0 +1,105 @@ +package ali + +import ( + 
"errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" +) + +// https://help.aliyun.com/zh/dashscope/developer-reference/api-details + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + fullRequestURL := "" + switch meta.Mode { + case relaymode.Embeddings: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) + case relaymode.ImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) + default: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) + } + + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + if meta.IsStream { + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-DashScope-SSE", "enable") + } + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + + if meta.Mode == relaymode.ImagesGenerations { + req.Header.Set("X-DashScope-Async", "enable") + } + if c.GetString(config.KeyPlugin) != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin)) + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + aliEmbeddingRequest := ConvertEmbeddingRequest(*request) + return aliEmbeddingRequest, nil + default: + aliRequest := ConvertRequest(*request) + return aliRequest, nil + } +} + +func (a *Adaptor) 
ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + aliRequest := ConvertImageRequest(*request) + return aliRequest, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + case relaymode.ImagesGenerations: + err, usage = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ali" +} diff --git a/relay/channel/ali/constants.go b/relay/adaptor/ali/constants.go similarity index 65% rename from relay/channel/ali/constants.go rename to relay/adaptor/ali/constants.go index 16bcfca4..3f24ce2e 100644 --- a/relay/channel/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -3,4 +3,5 @@ package ali var ModelList = []string{ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", "text-embedding-v1", + "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", } diff --git a/relay/adaptor/ali/image.go b/relay/adaptor/ali/image.go new file mode 100644 index 00000000..8261803d --- /dev/null +++ b/relay/adaptor/ali/image.go @@ -0,0 +1,192 @@ +package ali + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" + "time" +) + +func 
ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse TaskResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + logger.SysError("aliAsyncTask err: " + string(responseBody)) + return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} + +func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { + url := 
fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) + + var aliResponse TaskResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + logger.SysError("aliAsyncTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response TaskResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) { + waitSeconds := 2 + step := 0 + maxStep := 20 + + var taskResponse TaskResponse + var responseBody []byte + + for { + step++ + rsp, err, body := asyncTask(taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { + imageResponse := openai.ImageResponse{ + Created: helper.GetTimestamp(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + // 读取 data.Url 的图片数据并转存到 b64Json + imageData, err := getImageData(data.Url) + if err != nil { + // 处理获取图片数据失败的情况 + logger.SysError("getImageData Error getting image data: " + err.Error()) + 
continue + } + + // 将图片数据转为 Base64 编码的字符串 + b64Json = Base64Encode(imageData) + } else { + // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, openai.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func getImageData(url string) ([]byte, error) { + response, err := http.Get(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + imageData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return imageData, nil +} + +func Base64Encode(data []byte) string { + b64Json := base64.StdEncoding.EncodeToString(data) + return b64Json +} diff --git a/relay/channel/ali/main.go b/relay/adaptor/ali/main.go similarity index 87% rename from relay/channel/ali/main.go rename to relay/adaptor/ali/main.go index b9625584..0462c26b 100644 --- a/relay/channel/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -7,7 +7,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" "io" "net/http" @@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } + if request.TopP >= 1 { + request.TopP = 0.9999 + } return &ChatRequest{ Model: aliModel, Input: Input{ @@ -42,6 +45,12 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { EnableSearch: enableSearch, IncrementalOutput: request.Stream, Seed: uint64(request.Seed), + MaxTokens: request.MaxTokens, + Temperature: request.Temperature, + TopP: request.TopP, + TopK: request.TopK, + ResultFormat: "message", + Tools: request.Tools, }, } } @@ -57,6 +66,17 @@ func 
ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque } } +func ConvertImageRequest(request model.ImageRequest) *ImageRequest { + var imageRequest ImageRequest + imageRequest.Input.Prompt = request.Prompt + imageRequest.Model = request.Model + imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) + imageRequest.Parameters.N = request.N + imageRequest.ResponseFormat = request.ResponseFormat + + return &imageRequest +} + func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) @@ -111,19 +131,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR } func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: response.Output.Text, - }, - FinishReason: response.Output.FinishReason, - } fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, + Choices: response.Output.Choices, Usage: model.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, @@ -134,10 +146,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { } func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + if len(aliResponse.Output.Choices) == 0 { + return nil + } + aliChoice := aliResponse.Output.Choices[0] var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = aliResponse.Output.Text - if aliResponse.Output.FinishReason != "null" { - finishReason := aliResponse.Output.FinishReason + choice.Delta = aliChoice.Message + if aliChoice.FinishReason != "null" { + finishReason := aliChoice.FinishReason choice.FinishReason = &finishReason } 
response := openai.ChatCompletionsStreamResponse{ @@ -198,6 +214,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + return true + } //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) @@ -220,6 +239,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := c.Request.Context() var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -229,6 +249,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + logger.Debugf(ctx, "response body: %s\n", responseBody) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go new file mode 100644 index 00000000..450b5f52 --- /dev/null +++ b/relay/adaptor/ali/model.go @@ -0,0 +1,154 @@ +package ali + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type Input struct { + //Prompt string `json:"prompt"` + Messages []Message `json:"messages"` +} + +type Parameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + 
IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type ImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} + +type TaskResponse struct { + StatusCode int `json:"status_code,omitempty"` + RequestId string `json:"request_id,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Output struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Results []struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + } `json:"results,omitempty"` + TaskMetrics struct { + Total int `json:"TOTAL,omitempty"` + Succeeded int `json:"SUCCEEDED,omitempty"` + Failed int `json:"FAILED,omitempty"` + } `json:"task_metrics,omitempty"` + } `json:"output,omitempty"` + Usage Usage `json:"usage"` +} + +type Header struct { + Action string `json:"action,omitempty"` + Streaming string `json:"streaming,omitempty"` + TaskID string `json:"task_id,omitempty"` + Event string `json:"event,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string 
`json:"error_message,omitempty"` + Attributes any `json:"attributes,omitempty"` +} + +type Payload struct { + Model string `json:"model,omitempty"` + Task string `json:"task,omitempty"` + TaskGroup string `json:"task_group,omitempty"` + Function string `json:"function,omitempty"` + Parameters struct { + SampleRate int `json:"sample_rate,omitempty"` + Rate float64 `json:"rate,omitempty"` + Format string `json:"format,omitempty"` + } `json:"parameters,omitempty"` + Input struct { + Text string `json:"text,omitempty"` + } `json:"input,omitempty"` + Usage struct { + Characters int `json:"characters,omitempty"` + } `json:"usage,omitempty"` +} + +type WSSMessage struct { + Header Header `json:"header,omitempty"` + Payload Payload `json:"payload,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type EmbeddingResponse struct { + Output struct { + Embeddings []Embedding `json:"embeddings"` + } `json:"output"` + Usage Usage `json:"usage"` + Error +} + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Output struct { + //Text string `json:"text"` + //FinishReason string `json:"finish_reason"` + Choices []openai.TextResponseChoice `json:"choices"` +} + +type ChatResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + Error +} diff --git a/relay/channel/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go similarity index 53% rename from relay/channel/anthropic/adaptor.go rename to 
relay/adaptor/anthropic/adaptor.go index 4b873715..b1136e84 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -4,10 +4,9 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,22 +14,23 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-api-key", meta.APIKey) anthropicVersion := c.Request.Header.Get("anthropic-version") if anthropicVersion == "" { anthropicVersion = "2023-06-01" } req.Header.Set("anthropic-version", anthropicVersion) + req.Header.Set("anthropic-beta", "messages-2023-12-15") return nil } @@ -41,15 +41,20 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + 
return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { - var responseText string - err, responseText = StreamHandler(c, resp) - usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + err, usage = StreamHandler(c, resp) } else { err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } @@ -61,5 +66,5 @@ func (a *Adaptor) GetModelList() []string { } func (a *Adaptor) GetChannelName() string { - return "authropic" + return "anthropic" } diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go new file mode 100644 index 00000000..cadcedc8 --- /dev/null +++ b/relay/adaptor/anthropic/constants.go @@ -0,0 +1,8 @@ +package anthropic + +var ModelList = []string{ + "claude-instant-1.2", "claude-2.0", "claude-2.1", + "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", +} diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go new file mode 100644 index 00000000..6bb82d01 --- /dev/null +++ b/relay/adaptor/anthropic/main.go @@ -0,0 +1,273 @@ +package anthropic + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/image" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +func stopReasonClaude2OpenAI(reason 
*string) string { + if reason == nil { + return "" + } + switch *reason { + case "end_turn": + return "stop" + case "stop_sequence": + return "stop" + case "max_tokens": + return "length" + default: + return *reason + } +} + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + claudeRequest := Request{ + Model: textRequest.Model, + MaxTokens: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + TopK: textRequest.TopK, + Stream: textRequest.Stream, + } + if claudeRequest.MaxTokens == 0 { + claudeRequest.MaxTokens = 4096 + } + // legacy model name mapping + if claudeRequest.Model == "claude-instant-1" { + claudeRequest.Model = "claude-instant-1.1" + } else if claudeRequest.Model == "claude-2" { + claudeRequest.Model = "claude-2.1" + } + for _, message := range textRequest.Messages { + if message.Role == "system" && claudeRequest.System == "" { + claudeRequest.System = message.StringContent() + continue + } + claudeMessage := Message{ + Role: message.Role, + } + var content Content + if message.IsStringContent() { + content.Type = "text" + content.Text = message.StringContent() + claudeMessage.Content = append(claudeMessage.Content, content) + claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) + continue + } + var contents []Content + openaiContent := message.ParseContent() + for _, part := range openaiContent { + var content Content + if part.Type == model.ContentTypeText { + content.Type = "text" + content.Text = part.Text + } else if part.Type == model.ContentTypeImageURL { + content.Type = "image" + content.Source = &ImageSource{ + Type: "base64", + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + content.Source.MediaType = mimeType + content.Source.Data = data + } + contents = append(contents, content) + } + claudeMessage.Content = contents + claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) + } + return &claudeRequest +} + +// 
https://docs.anthropic.com/claude/reference/messages-streaming +func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var stopReason string + switch claudeResponse.Type { + case "message_start": + return nil, claudeResponse.Message + case "content_block_start": + if claudeResponse.ContentBlock != nil { + responseText = claudeResponse.ContentBlock.Text + } + case "content_block_delta": + if claudeResponse.Delta != nil { + responseText = claudeResponse.Delta.Text + } + case "message_delta": + if claudeResponse.Usage != nil { + response = &Response{ + Usage: *claudeResponse.Usage, + } + } + if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { + stopReason = *claudeResponse.Delta.StopReason + } + } + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = responseText + choice.Delta.Role = "assistant" + finishReason := stopReasonClaude2OpenAI(&stopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response +} + +func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), + Model: claudeResponse.Model, + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return 
&fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { + continue + } + if !strings.HasPrefix(data, "data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + var usage model.Usage + var modelName string + var id string + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var claudeResponse StreamResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, meta := streamResponseClaude2OpenAI(&claudeResponse) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } + if response == nil { + return true + } + response.Id = id + response.Model = modelName + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false 
+ } + }) + _ = resp.Body.Close() + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var claudeResponse Response + err = json.Unmarshal(responseBody, &claudeResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if claudeResponse.Error.Type != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: claudeResponse.Error.Message, + Type: claudeResponse.Error.Type, + Param: "", + Code: claudeResponse.Error.Type, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseClaude2OpenAI(&claudeResponse) + fullTextResponse.Model = modelName + usage := model.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go new file mode 100644 index 00000000..32b187cd --- /dev/null +++ b/relay/adaptor/anthropic/model.go @@ -0,0 +1,75 @@ +package anthropic + +// https://docs.anthropic.com/claude/reference/messages_post + +type Metadata struct { + UserId 
string `json:"user_id"` +} + +type ImageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type Content struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ImageSource `json:"source,omitempty"` +} + +type Message struct { + Role string `json:"role"` + Content []Content `json:"content"` +} + +type Request struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //Metadata `json:"metadata,omitempty"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type Error struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type Response struct { + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []Content `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage Usage `json:"usage"` + Error Error `json:"error"` +} + +type Delta struct { + Type string `json:"type"` + Text string `json:"text"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` +} + +type StreamResponse struct { + Type string `json:"type"` + Message *Response `json:"message"` + Index int `json:"index"` + ContentBlock *Content `json:"content_block"` + Delta *Delta `json:"delta"` + Usage *Usage `json:"usage"` +} diff --git a/relay/adaptor/azure/helper.go b/relay/adaptor/azure/helper.go new file mode 100644 index 00000000..dd207f37 --- /dev/null +++ b/relay/adaptor/azure/helper.go @@ -0,0 +1,15 @@ +package azure + +import ( + 
"github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" +) + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString(config.KeyAPIVersion) + } + return apiVersion +} diff --git a/relay/adaptor/baichuan/constants.go b/relay/adaptor/baichuan/constants.go new file mode 100644 index 00000000..cb20a1ff --- /dev/null +++ b/relay/adaptor/baichuan/constants.go @@ -0,0 +1,7 @@ +package baichuan + +var ModelList = []string{ + "Baichuan2-Turbo", + "Baichuan2-Turbo-192k", + "Baichuan-Text-Embedding", +} diff --git a/relay/adaptor/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go new file mode 100644 index 00000000..15306b95 --- /dev/null +++ b/relay/adaptor/baidu/adaptor.go @@ -0,0 +1,143 @@ +package baidu + +import ( + "errors" + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + suffix := "chat/" + if strings.HasPrefix(meta.ActualModelName, "Embedding") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "bge-large") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "tao-8k") { + suffix = "embeddings/" + } + switch meta.ActualModelName { + case "ERNIE-4.0": + suffix += "completions_pro" + case "ERNIE-Bot-4": + suffix += "completions_pro" + case "ERNIE-Bot": + suffix += "completions" + case "ERNIE-Bot-turbo": + suffix += "eb-instant" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-4.0-8K": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += 
"completions" + case "ERNIE-3.5-8K-0205": + suffix += "ernie-3.5-8k-0205" + case "ERNIE-3.5-8K-1222": + suffix += "ernie-3.5-8k-1222" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-3.5-4K-0205": + suffix += "ernie-3.5-4k-0205" + case "ERNIE-Speed-8K": + suffix += "ernie_speed" + case "ERNIE-Speed-128K": + suffix += "ernie-speed-128k" + case "ERNIE-Lite-8K-0922": + suffix += "eb-instant" + case "ERNIE-Lite-8K-0308": + suffix += "ernie-lite-8k" + case "ERNIE-Tiny-8K": + suffix += "ernie-tiny-8k" + case "BLOOMZ-7B": + suffix += "bloomz_7b1" + case "Embedding-V1": + suffix += "embedding-v1" + case "bge-large-zh": + suffix += "bge_large_zh" + case "bge-large-en": + suffix += "bge_large_en" + case "tao-8k": + suffix += "tao_8k" + default: + suffix += strings.ToLower(meta.ActualModelName) + } + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) + var accessToken string + var err error + if accessToken, err = GetAccessToken(meta.APIKey); err != nil { + return "", err + } + fullRequestURL += "?access_token=" + accessToken + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + baiduRequest := ConvertRequest(*request) + return baiduRequest, nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta 
*meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "baidu" +} diff --git a/relay/adaptor/baidu/constants.go b/relay/adaptor/baidu/constants.go new file mode 100644 index 00000000..f952adc6 --- /dev/null +++ b/relay/adaptor/baidu/constants.go @@ -0,0 +1,20 @@ +package baidu + +var ModelList = []string{ + "ERNIE-4.0-8K", + "ERNIE-3.5-8K", + "ERNIE-3.5-8K-0205", + "ERNIE-3.5-8K-1222", + "ERNIE-Bot-8K", + "ERNIE-3.5-4K-0205", + "ERNIE-Speed-8K", + "ERNIE-Speed-128K", + "ERNIE-Lite-8K-0922", + "ERNIE-Lite-8K-0308", + "ERNIE-Tiny-8K", + "BLOOMZ-7B", + "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", +} diff --git a/relay/channel/baidu/main.go b/relay/adaptor/baidu/main.go similarity index 87% rename from relay/channel/baidu/main.go rename to relay/adaptor/baidu/main.go index 4f2b13fc..6df5ce84 100644 --- a/relay/channel/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -8,10 +8,10 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/client" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -32,9 +32,16 @@ type Message struct { } type ChatRequest struct { - 
Messages []Message `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` + Messages []Message `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PenaltyScore float64 `json:"penalty_score,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + DisableSearch bool `json:"disable_search,omitempty"` + EnableCitation bool `json:"enable_citation,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + UserId string `json:"user_id,omitempty"` } type Error struct { @@ -45,28 +52,28 @@ type Error struct { var baiduTokenStore sync.Map func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { - messages := make([]Message, 0, len(request.Messages)) + baiduRequest := ChatRequest{ + Messages: make([]Message, 0, len(request.Messages)), + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + DisableSearch: false, + EnableCitation: false, + MaxOutputTokens: request.MaxTokens, + UserId: request.User, + } for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) + baiduRequest.System = message.StringContent() } else { - messages = append(messages, Message{ + baiduRequest.Messages = append(baiduRequest.Messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &ChatRequest{ - Messages: messages, - Stream: request.Stream, - } + return &baiduRequest } func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { @@ -298,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) { } req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - res, err := 
util.ImpatientHTTPClient.Do(req) + res, err := client.ImpatientHTTPClient.Do(req) if err != nil { return nil, err } diff --git a/relay/channel/baidu/model.go b/relay/adaptor/baidu/model.go similarity index 100% rename from relay/channel/baidu/model.go rename to relay/adaptor/baidu/model.go diff --git a/relay/channel/common.go b/relay/adaptor/common.go similarity index 80% rename from relay/channel/common.go rename to relay/adaptor/common.go index c6e1abf2..82a5160e 100644 --- a/relay/channel/common.go +++ b/relay/adaptor/common.go @@ -1,15 +1,16 @@ -package channel +package adaptor import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/relay/meta" "io" "net/http" ) -func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { +func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) if meta.IsStream && c.Request.Header.Get("Accept") == "" { @@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela } } -func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(meta) if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) @@ -38,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod } func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { - resp, err := util.HTTPClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return nil, err } diff --git a/relay/channel/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go similarity 
index 63% rename from relay/channel/gemini/adaptor.go rename to relay/adaptor/gemini/adaptor.go index f3305e5d..6a2867e4 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -4,11 +4,12 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" - channelhelper "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -16,12 +17,12 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - version := helper.AssignOrDefault(meta.APIVersion, "v1") +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + version := helper.AssignOrDefault(meta.APIVersion, config.GeminiVersion) action := "generateContent" if meta.IsStream { action = "streamGenerateContent" @@ -29,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channelhelper.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil @@ -42,11 +43,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody 
io.Reader) (*http.Response, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go new file mode 100644 index 00000000..32e7c240 --- /dev/null +++ b/relay/adaptor/gemini/constants.go @@ -0,0 +1,8 @@ +package gemini + +// https://ai.google.dev/models/gemini + +var ModelList = []string{ + "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", + "gemini-pro-vision", "gemini-1.0-pro-vision-001", +} diff --git a/relay/channel/gemini/main.go b/relay/adaptor/gemini/main.go similarity index 97% rename from relay/channel/gemini/main.go rename to relay/adaptor/gemini/main.go index c24694c8..6bf0c6d7 100644 --- a/relay/channel/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -9,7 +9,8 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -155,7 +156,7 @@ type ChatPromptFeedback struct { func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { 
fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), @@ -233,7 +234,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = dummy.Content response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "gemini-pro", diff --git a/relay/channel/gemini/model.go b/relay/adaptor/gemini/model.go similarity index 100% rename from relay/channel/gemini/model.go rename to relay/adaptor/gemini/model.go diff --git a/relay/adaptor/groq/constants.go b/relay/adaptor/groq/constants.go new file mode 100644 index 00000000..fc9a9ebd --- /dev/null +++ b/relay/adaptor/groq/constants.go @@ -0,0 +1,10 @@ +package groq + +// https://console.groq.com/docs/models + +var ModelList = []string{ + "gemma-7b-it", + "llama2-7b-2048", + "llama2-70b-4096", + "mixtral-8x7b-32768", +} diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go new file mode 100644 index 00000000..01b2e2cb --- /dev/null +++ b/relay/adaptor/interface.go @@ -0,0 +1,21 @@ +package adaptor + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor interface { + Init(meta *meta.Meta) + GetRequestURL(meta *meta.Meta) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(request *model.ImageRequest) (any, error) + DoRequest(c *gin.Context, meta *meta.Meta, 
requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) + GetModelList() []string + GetChannelName() string +} diff --git a/relay/adaptor/lingyiwanwu/constants.go b/relay/adaptor/lingyiwanwu/constants.go new file mode 100644 index 00000000..30000e9d --- /dev/null +++ b/relay/adaptor/lingyiwanwu/constants.go @@ -0,0 +1,9 @@ +package lingyiwanwu + +// https://platform.lingyiwanwu.com/docs + +var ModelList = []string{ + "yi-34b-chat-0205", + "yi-34b-chat-200k", + "yi-vl-plus", +} diff --git a/relay/adaptor/minimax/constants.go b/relay/adaptor/minimax/constants.go new file mode 100644 index 00000000..c3da5b2d --- /dev/null +++ b/relay/adaptor/minimax/constants.go @@ -0,0 +1,7 @@ +package minimax + +var ModelList = []string{ + "abab5.5s-chat", + "abab5.5-chat", + "abab6-chat", +} diff --git a/relay/adaptor/minimax/main.go b/relay/adaptor/minimax/main.go new file mode 100644 index 00000000..fc9b5d26 --- /dev/null +++ b/relay/adaptor/minimax/main.go @@ -0,0 +1,14 @@ +package minimax + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + if meta.Mode == relaymode.ChatCompletions { + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) +} diff --git a/relay/adaptor/mistral/constants.go b/relay/adaptor/mistral/constants.go new file mode 100644 index 00000000..cdb157f5 --- /dev/null +++ b/relay/adaptor/mistral/constants.go @@ -0,0 +1,10 @@ +package mistral + +var ModelList = []string{ + "open-mistral-7b", + "open-mixtral-8x7b", + "mistral-small-latest", + "mistral-medium-latest", + "mistral-large-latest", + "mistral-embed", +} diff --git a/relay/channel/moonshot/constants.go b/relay/adaptor/moonshot/constants.go similarity index 100% rename from 
relay/channel/moonshot/constants.go rename to relay/adaptor/moonshot/constants.go diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go new file mode 100644 index 00000000..66702c5d --- /dev/null +++ b/relay/adaptor/ollama/adaptor.go @@ -0,0 +1,82 @@ +package ollama + +import ( + "errors" + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + // https://github.com/ollama/ollama/blob/main/docs/api.md + fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + if meta.Mode == relaymode.Embeddings { + fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) + } + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) + return ollamaEmbeddingRequest, nil + default: + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c 
*gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ollama" +} diff --git a/relay/adaptor/ollama/constants.go b/relay/adaptor/ollama/constants.go new file mode 100644 index 00000000..32f82b2a --- /dev/null +++ b/relay/adaptor/ollama/constants.go @@ -0,0 +1,5 @@ +package ollama + +var ModelList = []string{ + "qwen:0.5b-chat", +} diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go new file mode 100644 index 00000000..a7e4c058 --- /dev/null +++ b/relay/adaptor/ollama/main.go @@ -0,0 +1,238 @@ +package ollama + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" +) + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + ollamaRequest := ChatRequest{ + Model: request.Model, + Options: &Options{ + Seed: int(request.Seed), + Temperature: request.Temperature, + TopP: request.TopP, + FrequencyPenalty: request.FrequencyPenalty, + PresencePenalty: request.PresencePenalty, + }, + Stream: request.Stream, + } + for _, message := range request.Messages { + ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) + } + return &ollamaRequest +} + +func 
responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: response.Message.Role, + Content: response.Message.Content, + }, + } + if response.Done { + choice.FinishReason = "stop" + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + Usage: model.Usage{ + PromptTokens: response.PromptEvalCount, + CompletionTokens: response.EvalCount, + TotalTokens: response.PromptEvalCount + response.EvalCount, + }, + } + return &fullTextResponse +} + +func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Role = ollamaResponse.Message.Role + choice.Delta.Content = ollamaResponse.Message.Content + if ollamaResponse.Done { + choice.FinishReason = &constant.StopFinishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: ollamaResponse.Model, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "}\n"); i >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + dataChan <- data + "}" + } + stopChan 
<- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + response := streamResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Model: request.Model, + Prompt: strings.Join(request.ParseInput(), " "), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var ollamaResponse EmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + 
StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, 1), + Model: "text-embedding-v1", + Usage: model.Usage{TotalTokens: 0}, + } + + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: 0, + Embedding: response.Embedding, + }) + return &openAIEmbeddingResponse +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := context.TODO() + var ollamaResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + logger.Debugf(ctx, "ollama response: %s", string(responseBody)) + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseOllama2OpenAI(&ollamaResponse) + 
jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go new file mode 100644 index 00000000..8baf56a0 --- /dev/null +++ b/relay/adaptor/ollama/model.go @@ -0,0 +1,47 @@ +package ollama + +type Options struct { + Seed int `json:"seed,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` +} + +type Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Stream bool `json:"stream"` + Options *Options `json:"options,omitempty"` +} + +type ChatResponse struct { + Model string `json:"model,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + Message Message `json:"message,omitempty"` + Response string `json:"response,omitempty"` // for stream response + Done bool `json:"done,omitempty"` + TotalDuration int `json:"total_duration,omitempty"` + LoadDuration int `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` + Error string `json:"error,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type EmbeddingResponse struct { + Error string `json:"error,omitempty"` + Embedding 
[]float64 `json:"embedding,omitempty"` +} diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go new file mode 100644 index 00000000..4bb2384e --- /dev/null +++ b/relay/adaptor/openai/adaptor.go @@ -0,0 +1,111 @@ +package openai + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + "strings" +) + +type Adaptor struct { + ChannelType int +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.ChannelType = meta.ChannelType +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.ChannelType { + case channeltype.Azure: + if meta.Mode == relaymode.ImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) + return fullRequestURL, nil + } + + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + requestURL := strings.Split(meta.RequestURLPath, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := meta.ActualModelName + model_ = strings.Replace(model_, ".", "", -1) + //https://github.com/songquanpeng/one-api/issues/1191 + // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} + requestURL = 
fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil + case channeltype.Minimax: + return minimax.GetRequestURL(meta) + default: + return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + if meta.ChannelType == channeltype.Azure { + req.Header.Set("api-key", meta.APIKey) + return nil + } + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + if meta.ChannelType == channeltype.OpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText, usage = StreamHandler(c, resp, meta.Mode) + if usage == nil { + usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } + } else { + switch meta.Mode { + case relaymode.ImagesGenerations: + err, _ = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + _, modelList := 
GetCompatibleChannelMeta(a.ChannelType) + return modelList +} + +func (a *Adaptor) GetChannelName() string { + channelName, _ := GetCompatibleChannelMeta(a.ChannelType) + return channelName +} diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go new file mode 100644 index 00000000..200eac44 --- /dev/null +++ b/relay/adaptor/openai/compatible.go @@ -0,0 +1,50 @@ +package openai + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/ai360" + "github.com/songquanpeng/one-api/relay/adaptor/baichuan" + "github.com/songquanpeng/one-api/relay/adaptor/groq" + "github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu" + "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/adaptor/mistral" + "github.com/songquanpeng/one-api/relay/adaptor/moonshot" + "github.com/songquanpeng/one-api/relay/adaptor/stepfun" + "github.com/songquanpeng/one-api/relay/channeltype" +) + +var CompatibleChannels = []int{ + channeltype.Azure, + channeltype.AI360, + channeltype.Moonshot, + channeltype.Baichuan, + channeltype.Minimax, + channeltype.Mistral, + channeltype.Groq, + channeltype.LingYiWanWu, + channeltype.StepFun, +} + +func GetCompatibleChannelMeta(channelType int) (string, []string) { + switch channelType { + case channeltype.Azure: + return "azure", ModelList + case channeltype.AI360: + return "360", ai360.ModelList + case channeltype.Moonshot: + return "moonshot", moonshot.ModelList + case channeltype.Baichuan: + return "baichuan", baichuan.ModelList + case channeltype.Minimax: + return "minimax", minimax.ModelList + case channeltype.Mistral: + return "mistralai", mistral.ModelList + case channeltype.Groq: + return "groq", groq.ModelList + case channeltype.LingYiWanWu: + return "lingyiwanwu", lingyiwanwu.ModelList + case channeltype.StepFun: + return "stepfun", stepfun.ModelList + default: + return "openai", ModelList + } +} diff --git a/relay/channel/openai/constants.go b/relay/adaptor/openai/constants.go 
similarity index 100% rename from relay/channel/openai/constants.go rename to relay/adaptor/openai/constants.go diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go new file mode 100644 index 00000000..7d73303b --- /dev/null +++ b/relay/adaptor/openai/helper.go @@ -0,0 +1,30 @@ +package openai + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/model" + "strings" +) + +func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { + usage := &model.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case channeltype.OpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case channeltype.Azure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go new file mode 100644 index 00000000..0f89618a --- /dev/null +++ b/relay/adaptor/openai/image.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var imageResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return 
ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &imageResponse) + if err != nil { + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/channel/openai/main.go b/relay/adaptor/openai/main.go similarity index 90% rename from relay/channel/openai/main.go rename to relay/adaptor/openai/main.go index fbe55cf9..68d8f48f 100644 --- a/relay/channel/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -6,15 +6,16 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" ) -func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -31,6 +32,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E }) dataChan := make(chan string) stopChan := make(chan bool) + var usage *model.Usage go func() { for 
scanner.Scan() { data := scanner.Text() @@ -44,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E data = data[6:] if !strings.HasPrefix(data, "[DONE]") { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: var streamResponse ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { @@ -52,9 +54,12 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E continue // just ignore the error } for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content + responseText += conv.AsString(choice.Delta.Content) } - case constant.RelayModeCompletions: + if streamResponse.Usage != nil { + usage = streamResponse.Usage + } + case relaymode.Completions: var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { @@ -86,9 +91,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E }) err := resp.Body.Close() if err != nil { - return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil } - return nil, responseText + return nil, responseText, usage } func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { diff --git a/relay/channel/openai/model.go b/relay/adaptor/openai/model.go similarity index 89% rename from relay/channel/openai/model.go rename to relay/adaptor/openai/model.go index b24485a8..ce252ff6 100644 --- a/relay/channel/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -110,20 +110,22 @@ type EmbeddingResponse struct { model.Usage `json:"usage"` } +type ImageData struct { + Url string `json:"url,omitempty"` + B64Json string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` +} 
+ type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } + Created int64 `json:"created"` + Data []ImageData `json:"data"` + //model.Usage `json:"usage"` } type ChatCompletionsStreamResponseChoice struct { - Index int `json:"index"` - Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index"` + Delta model.Message `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { @@ -132,6 +134,7 @@ type ChatCompletionsStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"` + Usage *model.Usage `json:"usage"` } type CompletionsStreamResponse struct { diff --git a/relay/channel/openai/token.go b/relay/adaptor/openai/token.go similarity index 98% rename from relay/channel/openai/token.go rename to relay/adaptor/openai/token.go index 0720425f..c95a7b5e 100644 --- a/relay/channel/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/pkoukk/tiktoken-go" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/model" "math" "strings" @@ -28,7 +28,7 @@ func InitTokenEncoders() { if err != nil { logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - for model := range common.ModelRatio { + for model := range billingratio.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { diff --git a/relay/channel/openai/util.go b/relay/adaptor/openai/util.go 
similarity index 100% rename from relay/channel/openai/util.go rename to relay/adaptor/openai/util.go diff --git a/relay/channel/palm/adaptor.go b/relay/adaptor/palm/adaptor.go similarity index 59% rename from relay/channel/palm/adaptor.go rename to relay/adaptor/palm/adaptor.go index efd0620c..98aa3e18 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,16 +15,16 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil } @@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, 
errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/palm/constants.go b/relay/adaptor/palm/constants.go similarity index 100% rename from relay/channel/palm/constants.go rename to relay/adaptor/palm/constants.go diff --git a/relay/channel/palm/model.go b/relay/adaptor/palm/model.go similarity index 100% rename from relay/channel/palm/model.go rename to relay/adaptor/palm/model.go diff --git a/relay/channel/palm/palm.go b/relay/adaptor/palm/palm.go similarity index 97% rename from relay/channel/palm/palm.go rename to relay/adaptor/palm/palm.go index 56738544..1e60e7cd 100644 --- a/relay/channel/palm/palm.go +++ b/relay/adaptor/palm/palm.go @@ -7,7 +7,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) createdTime := 
helper.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go new file mode 100644 index 00000000..a82e562b --- /dev/null +++ b/relay/adaptor/stepfun/constants.go @@ -0,0 +1,7 @@ +package stepfun + +var ModelList = []string{ + "step-1-32k", + "step-1v-32k", + "step-1-200k", +} diff --git a/relay/channel/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go similarity index 66% rename from relay/channel/tencent/adaptor.go rename to relay/adaptor/tencent/adaptor.go index f348674e..a97476d6 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -19,16 +19,16 @@ type Adaptor struct { Sign string } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", a.Sign) req.Header.Set("X-TC-Action", meta.ActualModelName) return nil @@ -52,11 +52,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G 
return tencentRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/tencent/constants.go b/relay/adaptor/tencent/constants.go similarity index 100% rename from relay/channel/tencent/constants.go rename to relay/adaptor/tencent/constants.go diff --git a/relay/channel/tencent/main.go b/relay/adaptor/tencent/main.go similarity index 94% rename from relay/channel/tencent/main.go rename to relay/adaptor/tencent/main.go index 05edac20..2ca5724e 100644 --- a/relay/channel/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -10,9 +10,11 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -28,17 +30,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) 
*ChatRequest { messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) - continue - } messages = append(messages, Message{ Content: message.StringContent(), Role: message.Role, @@ -51,7 +42,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { return &ChatRequest{ Timestamp: helper.GetTimestamp(), Expired: helper.GetTimestamp() + 24*60*60, - QueryID: helper.GetUUID(), + QueryID: random.GetUUID(), Temperature: request.Temperature, TopP: request.TopP, Stream: stream, @@ -81,6 +72,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "tencent-hunyuan", @@ -139,7 +131,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content + responseText += conv.AsString(response.Choices[0].Delta.Content) } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/channel/tencent/model.go b/relay/adaptor/tencent/model.go similarity index 100% rename from relay/channel/tencent/model.go rename to relay/adaptor/tencent/model.go diff --git a/relay/channel/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go similarity index 67% rename from relay/channel/xunfei/adaptor.go rename to relay/adaptor/xunfei/adaptor.go index 92d9d7d6..edcd719f 100644 --- a/relay/channel/xunfei/adaptor.go +++ 
b/relay/adaptor/xunfei/adaptor.go @@ -3,10 +3,10 @@ package xunfei import ( "errors" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -16,16 +16,16 @@ type Adaptor struct { request *model.GeneralOpenAIRequest } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return "", nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) // check DoResponse for auth part return nil } @@ -38,14 +38,21 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} dummyResp.StatusCode = http.StatusOK return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err 
*model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { splits := strings.Split(meta.APIKey, "|") if len(splits) != 3 { return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) diff --git a/relay/channel/xunfei/constants.go b/relay/adaptor/xunfei/constants.go similarity index 100% rename from relay/channel/xunfei/constants.go rename to relay/adaptor/xunfei/constants.go diff --git a/relay/channel/xunfei/main.go b/relay/adaptor/xunfei/main.go similarity index 80% rename from relay/channel/xunfei/main.go rename to relay/adaptor/xunfei/main.go index 620e808f..369e6227 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -9,9 +9,11 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -26,22 +28,15 @@ import ( func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) + var lastToolCalls []model.Tool for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) - } else { - messages = append(messages, Message{ - Role: message.Role, - Content: message.StringContent(), - }) + if message.ToolCalls != nil { + lastToolCalls = 
message.ToolCalls } + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) } xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId @@ -50,9 +45,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages + if len(lastToolCalls) != 0 { + for _, toolCall := range lastToolCalls { + xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) + } + } + return &xunfeiRequest } +func getToolCalls(response *ChatResponse) []model.Tool { + var toolCalls []model.Tool + if len(response.Payload.Choices.Text) == 0 { + return toolCalls + } + item := response.Payload.Choices.Text[0] + if item.FunctionCall == nil { + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", random.GetUUID()), + Type: "function", + Function: *item.FunctionCall, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []ChatResponseTextItem{ @@ -64,13 +83,14 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + ToolCalls: getToolCalls(response), }, FinishReason: constant.StopFinishReason, } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -89,11 +109,12 @@ func 
streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &constant.StopFinishReason } response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "SparkDesk", @@ -132,7 +153,7 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } common.SetEventStreamHeaders(c) var usage model.Usage @@ -162,7 +183,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } var usage model.Usage var content string @@ -182,11 +203,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin } } if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ - { - Content: "", - }, - } + return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil } 
xunfeiResponse.Payload.Choices.Text[0].Content = content @@ -213,15 +230,21 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, if err != nil { return nil, nil, err } + _, msg, err := conn.ReadMessage() + if err != nil { + return nil, nil, err + } dataChan := make(chan ChatResponse) stopChan := make(chan bool) go func() { for { - _, msg, err := conn.ReadMessage() - if err != nil { - logger.SysError("error reading stream response: " + err.Error()) - break + if msg == nil { + _, msg, err = conn.ReadMessage() + if err != nil { + logger.SysError("error reading stream response: " + err.Error()) + break + } } var response ChatResponse err = json.Unmarshal(msg, &response) @@ -229,6 +252,7 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, logger.SysError("error unmarshalling stream response: " + err.Error()) break } + msg = nil dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() @@ -256,7 +280,7 @@ func getAPIVersion(c *gin.Context, modelName string) string { return apiVersion } - apiVersion = c.GetString(common.ConfigKeyAPIVersion) + apiVersion = c.GetString(config.KeyAPIVersion) if apiVersion != "" { return apiVersion } diff --git a/relay/channel/xunfei/model.go b/relay/adaptor/xunfei/model.go similarity index 81% rename from relay/channel/xunfei/model.go rename to relay/adaptor/xunfei/model.go index 1266739d..97a43154 100644 --- a/relay/channel/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -26,13 +26,18 @@ type ChatRequest struct { Message struct { Text []Message `json:"text"` } `json:"message"` + Functions struct { + Text []model.Function `json:"text,omitempty"` + } `json:"functions,omitempty"` } `json:"payload"` } type ChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` + ContentType string `json:"content_type"` 
+ FunctionCall *model.Function `json:"function_call"` } type ChatResponse struct { diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go new file mode 100644 index 00000000..5ebafbb3 --- /dev/null +++ b/relay/adaptor/zhipu/adaptor.go @@ -0,0 +1,145 @@ +package zhipu + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "math" + "net/http" + "strings" +) + +type Adaptor struct { + APIVersion string +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) SetVersionByModeName(modelName string) { + if strings.HasPrefix(modelName, "glm-") { + a.APIVersion = "v4" + } else { + a.APIVersion = "v3" + } +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.Mode { + case relaymode.ImagesGenerations: + return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil + } + a.SetVersionByModeName(meta.ActualModelName) + if a.APIVersion == "v4" { + return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil + } + method := "invoke" + if meta.IsStream { + method = "sse-invoke" + } + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + token := GetToken(meta.APIKey) + req.Header.Set("Authorization", token) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { 
+ case relaymode.Embeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + // TopP (0.0, 1.0) + request.TopP = math.Min(0.99, request.TopP) + request.TopP = math.Max(0.01, request.TopP) + + // Temperature (0.0, 1.0) + request.Temperature = math.Min(0.99, request.Temperature) + request.Temperature = math.Max(0.01, request.Temperature) + a.SetVersionByModeName(request.Model) + if a.APIVersion == "v4" { + return request, nil + } + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + newRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + UserId: request.User, + } + return newRequest, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, _, usage = openai.StreamHandler(c, resp, meta.Mode) + } else { + err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingsHandler(c, resp) + return + case relaymode.ImagesGenerations: + err, usage = openai.ImageHandler(c, resp) + return + } + if a.APIVersion == "v4" { + return a.DoResponseV4(c, resp, meta) + } + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + if meta.Mode == relaymode.Embeddings { + err, usage = EmbeddingsHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + } + return +} + +func ConvertEmbeddingRequest(request 
model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Model: "embedding-2", + Input: request.Input.(string), + } +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "zhipu" +} diff --git a/relay/channel/zhipu/constants.go b/relay/adaptor/zhipu/constants.go similarity index 62% rename from relay/channel/zhipu/constants.go rename to relay/adaptor/zhipu/constants.go index f0367b82..e1192123 100644 --- a/relay/channel/zhipu/constants.go +++ b/relay/adaptor/zhipu/constants.go @@ -2,4 +2,6 @@ package zhipu var ModelList = []string{ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", + "glm-4", "glm-4v", "glm-3-turbo", "embedding-2", + "cogview-3", } diff --git a/relay/channel/zhipu/main.go b/relay/adaptor/zhipu/main.go similarity index 79% rename from relay/channel/zhipu/main.go rename to relay/adaptor/zhipu/main.go index 7c3e83f3..74a1a05e 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/adaptor/zhipu/main.go @@ -8,7 +8,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -76,21 +76,10 @@ func GetToken(apikey string) string { func ConvertRequest(request model.GeneralOpenAIRequest) *Request { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, Message{ - Role: "system", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "user", - Content: "Okay", - }) - } else { - messages = append(messages, Message{ - Role: message.Role, - Content: message.StringContent(), - }) - } + messages = append(messages, Message{ + 
Role: message.Role, + Content: message.StringContent(), + }) } return &Request{ Prompt: messages, @@ -265,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var zhipuResponse EmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: response.Model, + Usage: model.Usage{ + PromptTokens: response.PromptTokens, + CompletionTokens: response.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: item.Index, + Embedding: item.Embedding, + }) 
+ } + return &openAIEmbeddingResponse +} diff --git a/relay/channel/zhipu/model.go b/relay/adaptor/zhipu/model.go similarity index 65% rename from relay/channel/zhipu/model.go rename to relay/adaptor/zhipu/model.go index b63e1d6f..f91de1dc 100644 --- a/relay/channel/zhipu/model.go +++ b/relay/adaptor/zhipu/model.go @@ -44,3 +44,27 @@ type tokenData struct { Token string ExpiryTime time.Time } + +type EmbeddingRequest struct { + Model string `json:"model"` + Input string `json:"input"` +} + +type EmbeddingResponse struct { + Model string `json:"model"` + Object string `json:"object"` + Embeddings []EmbeddingData `json:"data"` + model.Usage `json:"usage"` +} + +type EmbeddingData struct { + Index int `json:"index"` + Object string `json:"object"` + Embedding []float64 `json:"embedding"` +} + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + UserId string `json:"user_id,omitempty"` +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go new file mode 100644 index 00000000..82d32a50 --- /dev/null +++ b/relay/apitype/define.go @@ -0,0 +1,17 @@ +package apitype + +const ( + OpenAI = iota + Anthropic + PaLM + Baidu + Zhipu + Ali + Xunfei + AIProxyLibrary + Tencent + Gemini + Ollama + + Dummy // this one is only for count, do not add any channel after this +) diff --git a/relay/billing/billing.go b/relay/billing/billing.go new file mode 100644 index 00000000..a99d37ee --- /dev/null +++ b/relay/billing/billing.go @@ -0,0 +1,42 @@ +package billing + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" +) + +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(ctx) + } +} + 
+func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, userId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} diff --git a/common/group-ratio.go b/relay/billing/ratio/group.go similarity index 97% rename from common/group-ratio.go rename to relay/billing/ratio/group.go index 2de6e810..8e9c5b73 100644 --- a/common/group-ratio.go +++ b/relay/billing/ratio/group.go @@ -1,4 +1,4 @@ -package common +package ratio import ( "encoding/json" diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go new file mode 100644 index 00000000..5a29cddc --- /dev/null +++ b/relay/billing/ratio/image.go @@ -0,0 +1,51 @@ +package ratio + +var ImageSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, + "ali-stable-diffusion-xl": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "ali-stable-diffusion-v1.5": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + 
"1024x576": 1, + }, + "wanx-v1": { + "1024x1024": 1, + "720x1280": 1, + "1280x720": 1, + }, +} + +var ImageGenerationAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. + "ali-stable-diffusion-xl": {1, 4}, // Ali + "ali-stable-diffusion-v1.5": {1, 4}, // Ali + "wanx-v1": {1, 4}, // Ali + "cogview-3": {1, 1}, +} + +var ImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, + "ali-stable-diffusion-xl": 4000, + "ali-stable-diffusion-v1.5": 4000, + "wanx-v1": 4000, + "cogview-3": 833, +} diff --git a/common/model-ratio.go b/relay/billing/ratio/model.go similarity index 52% rename from common/model-ratio.go rename to relay/billing/ratio/model.go index 2e7aae71..108924a1 100644 --- a/common/model-ratio.go +++ b/relay/billing/ratio/model.go @@ -1,35 +1,11 @@ -package common +package ratio import ( "encoding/json" "github.com/songquanpeng/one-api/common/logger" "strings" - "time" ) -var DalleSizeRatios = map[string]map[string]float64{ - "dall-e-2": { - "256x256": 1, - "512x512": 1.125, - "1024x1024": 1.25, - }, - "dall-e-3": { - "1024x1024": 1, - "1024x1792": 2, - "1792x1024": 2, - }, -} - -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. 
-} - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} - const ( USD2RMB = 7 USD = 500 // $0.002 = 1 -> $1 = 500 @@ -40,7 +16,6 @@ const ( // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://openai.com/pricing -// TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ @@ -55,7 +30,7 @@ var ModelRatio = map[string]float64{ "gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens "gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens @@ -87,30 +62,58 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image - "claude-instant-1": 0.815, // $1.63 / 1M tokens - "claude-2": 5.51, // $11.02 / 1M tokens - "claude-2.0": 5.51, // $11.02 / 1M tokens - "claude-2.1": 5.51, // $11.02 / 1M tokens + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image + // https://www.anthropic.com/api#pricing + "claude-instant-1.2": 0.8 / 1000 * USD, + "claude-2.0": 8.0 / 1000 * USD, + "claude-2.1": 8.0 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-sonnet-20240229": 3.0 / 1000 * USD, + "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8k": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k 
tokens + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Lite-8K-0308": 0.003 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, + // https://ai.google.dev/pricing "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro-vision-001": 1, + "gemini-1.0-pro-001": 1, + "gemini-1.5-pro": 1, + // https://open.bigmodel.cn/pricing + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, + "embedding-2": 0.0005 * RMB, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "cogview-3": 0.25 * RMB, + // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens "qwen-max": 1.4286, // ¥0.02 / 1k tokens "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + 
"ali-stable-diffusion-xl": 8, + "ali-stable-diffusion-v1.5": 8, + "wanx-v1": 8, "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens @@ -127,6 +130,70 @@ var ModelRatio = map[string]float64{ "moonshot-v1-8k": 0.012 * RMB, "moonshot-v1-32k": 0.024 * RMB, "moonshot-v1-128k": 0.06 * RMB, + // https://platform.baichuan-ai.com/price + "Baichuan2-Turbo": 0.008 * RMB, + "Baichuan2-Turbo-192k": 0.016 * RMB, + "Baichuan2-53B": 0.02 * RMB, + // https://api.minimax.chat/document/price + "abab6-chat": 0.1 * RMB, + "abab5.5-chat": 0.015 * RMB, + "abab5.5s-chat": 0.005 * RMB, + // https://docs.mistral.ai/platform/pricing/ + "open-mistral-7b": 0.25 / 1000 * USD, + "open-mixtral-8x7b": 0.7 / 1000 * USD, + "mistral-small-latest": 2.0 / 1000 * USD, + "mistral-medium-latest": 2.7 / 1000 * USD, + "mistral-large-latest": 8.0 / 1000 * USD, + "mistral-embed": 0.1 / 1000 * USD, + // https://wow.groq.com/ + "llama2-70b-4096": 0.7 / 1000 * USD, + "llama2-7b-2048": 0.1 / 1000 * USD, + "mixtral-8x7b-32768": 0.27 / 1000 * USD, + "gemma-7b-it": 0.1 / 1000 * USD, + // https://platform.lingyiwanwu.com/docs#-计费单元 + "yi-34b-chat-0205": 2.5 / 1000 * RMB, + "yi-34b-chat-200k": 12.0 / 1000 * RMB, + "yi-vl-plus": 6.0 / 1000 * RMB, + // stepfun todo + "step-1v-32k": 0.024 * RMB, + "step-1-32k": 0.024 * RMB, + "step-1-200k": 0.15 * RMB, +} + +var CompletionRatio = map[string]float64{} + +var DefaultModelRatio map[string]float64 +var DefaultCompletionRatio map[string]float64 + +func init() { + DefaultModelRatio = make(map[string]float64) + for k, v := range ModelRatio { + DefaultModelRatio[k] = v + } + DefaultCompletionRatio = make(map[string]float64) + for k, v := range CompletionRatio { + DefaultCompletionRatio[k] = v + } +} + +func AddNewMissingRatio(oldRatio string) string { + newRatio := make(map[string]float64) + err := json.Unmarshal([]byte(oldRatio), &newRatio) + if err != nil { + logger.SysError("error 
unmarshalling old ratio: " + err.Error()) + return oldRatio + } + for k, v := range DefaultModelRatio { + if _, ok := newRatio[k]; !ok { + newRatio[k] = v + } + } + jsonBytes, err := json.Marshal(newRatio) + if err != nil { + logger.SysError("error marshalling new ratio: " + err.Error()) + return oldRatio + } + return string(jsonBytes) } func ModelRatio2JSONString() string { @@ -147,6 +214,9 @@ func GetModelRatio(name string) float64 { name = strings.TrimSuffix(name, "-internet") } ratio, ok := ModelRatio[name] + if !ok { + ratio, ok = DefaultModelRatio[name] + } if !ok { logger.SysError("model ratio not found: " + name) return 30 @@ -154,8 +224,6 @@ func GetModelRatio(name string) float64 { return ratio } -var CompletionRatio = map[string]float64{} - func CompletionRatio2JSONString() string { jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { @@ -173,8 +241,11 @@ func GetCompletionRatio(name string) float64 { if ratio, ok := CompletionRatio[name]; ok { return ratio } + if ratio, ok := DefaultCompletionRatio[name]; ok { + return ratio + } if strings.HasPrefix(name, "gpt-3.5") { - if strings.HasSuffix(name, "0125") { + if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates // Updated GPT-3.5 Turbo model and lower pricing return 3 @@ -182,16 +253,7 @@ func GetCompletionRatio(name string) float64 { if strings.HasSuffix(name, "1106") { return 2 } - if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { - // TODO: clear this after 2023-12-11 - now := time.Now() - // https://platform.openai.com/docs/models/continuous-model-upgrades - // if after 2023-12-11, use 2 - if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { - return 2 - } - } - return 1.333333 + return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") { if strings.HasSuffix(name, "preview") { @@ -199,11 +261,21 @@ func GetCompletionRatio(name string) float64 { } return 2 } - if strings.HasPrefix(name, 
"claude-instant-1") { - return 3.38 + if strings.HasPrefix(name, "claude-3") { + return 5 } - if strings.HasPrefix(name, "claude-2") { - return 2.965517 + if strings.HasPrefix(name, "claude-") { + return 3 + } + if strings.HasPrefix(name, "mistral-") { + return 3 + } + if strings.HasPrefix(name, "gemini-") { + return 3 + } + switch name { + case "llama2-70b-4096": + return 0.8 / 0.7 } return 1 } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go deleted file mode 100644 index 6c6f433e..00000000 --- a/relay/channel/ali/adaptor.go +++ /dev/null @@ -1,83 +0,0 @@ -package ali - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -// https://help.aliyun.com/zh/dashscope/developer-reference/api-details - -type Adaptor struct { -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) - if meta.Mode == constant.RelayModeEmbeddings { - fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) - } - return fullRequestURL, nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.IsStream { - req.Header.Set("X-DashScope-SSE", "enable") - } - if c.GetString(common.ConfigKeyPlugin) != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) - } - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if 
request == nil { - return nil, errors.New("request is nil") - } - switch relayMode { - case constant.RelayModeEmbeddings: - baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) - return baiduEmbeddingRequest, nil - default: - baiduRequest := ConvertRequest(*request) - return baiduRequest, nil - } -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, resp) - } else { - switch meta.Mode { - case constant.RelayModeEmbeddings: - err, usage = EmbeddingHandler(c, resp) - default: - err, usage = Handler(c, resp) - } - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return "ali" -} diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go deleted file mode 100644 index 54f13041..00000000 --- a/relay/channel/ali/model.go +++ /dev/null @@ -1,71 +0,0 @@ -package ali - -type Message struct { - Content string `json:"content"` - Role string `json:"role"` -} - -type Input struct { - //Prompt string `json:"prompt"` - Messages []Message `json:"messages"` -} - -type Parameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` -} - -type ChatRequest struct { - Model string `json:"model"` - Input Input `json:"input"` - Parameters Parameters `json:"parameters,omitempty"` -} - -type EmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } 
`json:"parameters,omitempty"` -} - -type Embedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type EmbeddingResponse struct { - Output struct { - Embeddings []Embedding `json:"embeddings"` - } `json:"output"` - Usage Usage `json:"usage"` - Error -} - -type Error struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Output struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type ChatResponse struct { - Output Output `json:"output"` - Usage Usage `json:"usage"` - Error -} diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go deleted file mode 100644 index b98c15c2..00000000 --- a/relay/channel/anthropic/constants.go +++ /dev/null @@ -1,5 +0,0 @@ -package anthropic - -var ModelList = []string{ - "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", -} diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go deleted file mode 100644 index e2c575fa..00000000 --- a/relay/channel/anthropic/main.go +++ /dev/null @@ -1,199 +0,0 @@ -package anthropic - -import ( - "bufio" - "encoding/json" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" -) - -func stopReasonClaude2OpenAI(reason string) string { - switch reason { - case "stop_sequence": - return "stop" - case "max_tokens": - return "length" - default: - return reason - } -} - -func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { - claudeRequest := Request{ - Model: textRequest.Model, - 
Prompt: "", - MaxTokensToSample: textRequest.MaxTokens, - StopSequences: nil, - Temperature: textRequest.Temperature, - TopP: textRequest.TopP, - Stream: textRequest.Stream, - } - if claudeRequest.MaxTokensToSample == 0 { - claudeRequest.MaxTokensToSample = 1000000 - } - prompt := "" - for _, message := range textRequest.Messages { - if message.Role == "user" { - prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) - } else if message.Role == "assistant" { - prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) - } else if message.Role == "system" { - if prompt == "" { - prompt = message.StringContent() - } - } - } - prompt += "\n\nAssistant:" - claudeRequest.Prompt = prompt - return &claudeRequest -} - -func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = claudeResponse.Completion - finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) - if finishReason != "null" { - choice.FinishReason = &finishReason - } - var response openai.ChatCompletionsStreamResponse - response.Object = "chat.completion.chunk" - response.Model = claudeResponse.Model - response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} - return &response -} - -func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: strings.TrimPrefix(claudeResponse.Completion, " "), - Name: nil, - }, - FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), - } - fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), - Object: "chat.completion", - Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, - } - return &fullTextResponse -} - -func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { - responseText := "" - 
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) - createdTime := helper.GetTimestamp() - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { - return i + 4, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if !strings.HasPrefix(data, "event: completion") { - continue - } - data = strings.TrimPrefix(data, "event: completion\r\ndata: ") - dataChan <- data - } - stopChan <- true - }() - common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var claudeResponse Response - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - responseText += claudeResponse.Completion - response := streamResponseClaude2OpenAI(&claudeResponse) - response.Id = responseId - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } - return nil, responseText -} - -func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { - responseBody, err := 
io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var claudeResponse Response - err = json.Unmarshal(responseBody, &claudeResponse) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if claudeResponse.Error.Type != "" { - return &model.ErrorWithStatusCode{ - Error: model.Error{ - Message: claudeResponse.Error.Message, - Type: claudeResponse.Error.Type, - Param: "", - Code: claudeResponse.Error.Type, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseClaude2OpenAI(&claudeResponse) - fullTextResponse.Model = modelName - completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName) - usage := model.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage -} diff --git a/relay/channel/anthropic/model.go b/relay/channel/anthropic/model.go deleted file mode 100644 index 70fc9430..00000000 --- a/relay/channel/anthropic/model.go +++ /dev/null @@ -1,29 +0,0 @@ -package anthropic - -type Metadata struct { - UserId string `json:"user_id"` -} - -type Request struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokensToSample int `json:"max_tokens_to_sample"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature float64 
`json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - //Metadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type Error struct { - Type string `json:"type"` - Message string `json:"message"` -} - -type Response struct { - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error Error `json:"error"` -} diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go deleted file mode 100644 index d2d06ce0..00000000 --- a/relay/channel/baidu/adaptor.go +++ /dev/null @@ -1,93 +0,0 @@ -package baidu - -import ( - "errors" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -type Adaptor struct { -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t - var fullRequestURL string - switch meta.ActualModelName { - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "ERNIE-Bot-8K": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Speed": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = 
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - } - var accessToken string - var err error - if accessToken, err = GetAccessToken(meta.APIKey); err != nil { - return "", err - } - fullRequestURL += "?access_token=" + accessToken - return fullRequestURL, nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - req.Header.Set("Authorization", "Bearer "+meta.APIKey) - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - switch relayMode { - case constant.RelayModeEmbeddings: - baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) - return baiduEmbeddingRequest, nil - default: - baiduRequest := ConvertRequest(*request) - return baiduRequest, nil - } -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, resp) - } else { - switch meta.Mode { - case constant.RelayModeEmbeddings: - err, usage = EmbeddingHandler(c, resp) - default: - err, usage = Handler(c, resp) - } - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return "baidu" -} diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go deleted file mode 100644 index 0fa8f2d6..00000000 --- a/relay/channel/baidu/constants.go +++ /dev/null @@ -1,10 +0,0 @@ -package baidu - -var ModelList = []string{ - "ERNIE-Bot-4", - "ERNIE-Bot-8K", - "ERNIE-Bot", - "ERNIE-Speed", - "ERNIE-Bot-turbo", - 
"Embedding-V1", -} diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go deleted file mode 100644 index 5bb0c168..00000000 --- a/relay/channel/gemini/constants.go +++ /dev/null @@ -1,6 +0,0 @@ -package gemini - -var ModelList = []string{ - "gemini-pro", - "gemini-pro-vision", -} diff --git a/relay/channel/interface.go b/relay/channel/interface.go deleted file mode 100644 index e25db83f..00000000 --- a/relay/channel/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package channel - -import ( - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -type Adaptor interface { - Init(meta *util.RelayMeta) - GetRequestURL(meta *util.RelayMeta) (string, error) - SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error - ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) - DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) - GetModelList() []string - GetChannelName() string -} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go deleted file mode 100644 index 1313e317..00000000 --- a/relay/channel/openai/adaptor.go +++ /dev/null @@ -1,103 +0,0 @@ -package openai - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/moonshot" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" - "strings" -) - -type Adaptor struct { - ChannelType int -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - a.ChannelType = meta.ChannelType -} - -func (a 
*Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - if meta.ChannelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - requestURL := strings.Split(meta.RequestURLPath, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := meta.ActualModelName - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil - } - return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - if meta.ChannelType == common.ChannelTypeAzure { - req.Header.Set("api-key", meta.APIKey) - return nil - } - req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.ChannelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - return request, nil -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err 
*model.ErrorWithStatusCode) { - if meta.IsStream { - var responseText string - err, responseText = StreamHandler(c, resp, meta.Mode) - usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) - } else { - err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) - } - return -} - -func (a *Adaptor) GetModelList() []string { - switch a.ChannelType { - case common.ChannelType360: - return ai360.ModelList - case common.ChannelTypeMoonshot: - return moonshot.ModelList - default: - return ModelList - } -} - -func (a *Adaptor) GetChannelName() string { - switch a.ChannelType { - case common.ChannelTypeAzure: - return "azure" - case common.ChannelType360: - return "360" - case common.ChannelTypeMoonshot: - return "moonshot" - default: - return "openai" - } -} diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go deleted file mode 100644 index 9bca8cab..00000000 --- a/relay/channel/openai/helper.go +++ /dev/null @@ -1,11 +0,0 @@ -package openai - -import "github.com/songquanpeng/one-api/relay/model" - -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { - usage := &model.Usage{} - usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage -} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go deleted file mode 100644 index 7a822853..00000000 --- a/relay/channel/zhipu/adaptor.go +++ /dev/null @@ -1,62 +0,0 @@ -package zhipu - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -type Adaptor struct { -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - method := "invoke" - if 
meta.IsStream { - method = "sse-invoke" - } - return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - token := GetToken(meta.APIKey) - req.Header.Set("Authorization", token) - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - return ConvertRequest(*request), nil -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, resp) - } else { - err, usage = Handler(c, resp) - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return "zhipu" -} diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go new file mode 100644 index 00000000..80027a80 --- /dev/null +++ b/relay/channeltype/define.go @@ -0,0 +1,39 @@ +package channeltype + +const ( + Unknown = iota + OpenAI + API2D + Azure + CloseAI + OpenAISB + OpenAIMax + OhMyGPT + Custom + Ails + AIProxy + PaLM + API2GPT + AIGC2D + Anthropic + Baidu + Zhipu + Ali + Xunfei + AI360 + OpenRouter + AIProxyLibrary + FastGPT + Tencent + Gemini + Moonshot + Baichuan + Minimax + Mistral + Groq + Ollama + LingYiWanWu + StepFun + + Dummy +) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go new file mode 100644 index 00000000..01c2918c --- /dev/null +++ b/relay/channeltype/helper.go @@ -0,0 +1,30 @@ +package channeltype + +import 
"github.com/songquanpeng/one-api/relay/apitype" + +func ToAPIType(channelType int) int { + apiType := apitype.OpenAI + switch channelType { + case Anthropic: + apiType = apitype.Anthropic + case Baidu: + apiType = apitype.Baidu + case PaLM: + apiType = apitype.PaLM + case Zhipu: + apiType = apitype.Zhipu + case Ali: + apiType = apitype.Ali + case Xunfei: + apiType = apitype.Xunfei + case AIProxyLibrary: + apiType = apitype.AIProxyLibrary + case Tencent: + apiType = apitype.Tencent + case Gemini: + apiType = apitype.Gemini + case Ollama: + apiType = apitype.Ollama + } + return apiType +} diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go new file mode 100644 index 00000000..eec59116 --- /dev/null +++ b/relay/channeltype/url.go @@ -0,0 +1,43 @@ +package channeltype + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 + "https://api.moonshot.cn", // 25 + "https://api.baichuan-ai.com", // 26 + "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 + "https://api.stepfun.com", // 32 +} + +func init() { + if len(ChannelBaseURLs) 
!= Dummy { + panic("channel base urls length not match") + } +} diff --git a/relay/util/init.go b/relay/client/init.go similarity index 96% rename from relay/util/init.go rename to relay/client/init.go index 03dad31b..4b59cba7 100644 --- a/relay/util/init.go +++ b/relay/client/init.go @@ -1,4 +1,4 @@ -package util +package client import ( "github.com/songquanpeng/one-api/common/config" diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go deleted file mode 100644 index d2184dac..00000000 --- a/relay/constant/api_type.go +++ /dev/null @@ -1,45 +0,0 @@ -package constant - -import ( - "github.com/songquanpeng/one-api/common" -) - -const ( - APITypeOpenAI = iota - APITypeAnthropic - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini - - APITypeDummy // this one is only for count, do not add any channel after this -) - -func ChannelType2APIType(channelType int) int { - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeAnthropic - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - } - return apiType -} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go deleted file mode 100644 index 5e2fe574..00000000 --- a/relay/constant/relay_mode.go +++ /dev/null @@ -1,42 +0,0 @@ -package constant - -import "strings" - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - 
RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - -func Path2RelayMode(path string) int { - relayMode := RelayModeUnknown - if strings.HasPrefix(path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - return relayMode -} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index ee8771c9..9d8cfef5 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -12,16 +12,21 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor/azure" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/client" relaymodel "github.com/songquanpeng/one-api/relay/model" - 
"github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() audioModel := "whisper-1" tokenId := c.GetInt("token_id") @@ -32,7 +37,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus tokenName := c.GetString("token_name") var ttsRequest openai.TextToSpeechRequest - if relayMode == constant.RelayModeAudioSpeech { + if relayMode == relaymode.AudioSpeech { // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid @@ -46,19 +51,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := common.GetModelRatio(audioModel) - groupRatio := common.GetGroupRatio(group) + modelRatio := billingratio.GetModelRatio(audioModel) + groupRatio := billingratio.GetGroupRatio(group) ratio := modelRatio * groupRatio - var quota int - var preConsumedQuota int + var quota int64 + var preConsumedQuota int64 switch relayMode { - case constant.RelayModeAudioSpeech: - preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + case relaymode.AudioSpeech: + preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: - preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) + preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) } - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -82,6 +87,24 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } + succeed := false + defer func() { + if succeed { + return + } + if preConsumedQuota > 0 { + // we need to roll 
back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + }() // map model name modelMapping := c.GetString("model_mapping") @@ -96,17 +119,22 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - baseURL := common.ChannelBaseURLs[channelType] + baseURL := channeltype.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := util.GetAzureAPIVersion(c) - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) + if channelType == channeltype.Azure { + apiVersion := azure.GetAPIVersion(c) + if relayMode == relaymode.AudioTranscription { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + } else if relayMode == relaymode.AudioSpeech { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) + } } requestBody := &bytes.Buffer{} @@ -122,7 +150,7 @@ func RelayAudioHelper(c 
*gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -134,7 +162,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := util.HTTPClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } @@ -148,7 +176,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != constant.RelayModeAudioSpeech { + if relayMode != relaymode.AudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -183,27 +211,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - quota = openai.CountTokenText(text, audioModel) + quota = int64(openai.CountTokenText(text, audioModel)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { - if preConsumedQuota > 0 { - // we need to roll back the pre-consumed quota - defer func(ctx context.Context) { - go func() { - // negative means add quota back 
for token & user - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) - } - }() - }(c.Request.Context()) - } - return util.RelayErrorHandler(resp) + return RelayErrorHandler(resp) } + succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { - go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) }(c.Request.Context()) for k, v := range resp.Header { diff --git a/relay/controller/error.go b/relay/controller/error.go new file mode 100644 index 00000000..69ece3ec --- /dev/null +++ b/relay/controller/error.go @@ -0,0 +1,91 @@ +package controller + +import ( + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strconv" +) + +type GeneralErrorResponse struct { + Error model.Error `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + +func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode 
*model.ErrorWithStatusCode) { + ErrorWithStatusCode = &model.ErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: model.Error{ + Message: "", + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + if config.DebugEnabled { + logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + ErrorWithStatusCode.Error = errResponse.Error + } else { + ErrorWithStatusCode.Error.Message = errResponse.ToMessage() + } + if ErrorWithStatusCode.Error.Message == "" { + ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go index a06b2768..f1b40bef 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -9,10 +9,13 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller/validator" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "math" "net/http" ) @@ -23,43 +26,115 @@ func 
getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener if err != nil { return nil, err } - if relayMode == constant.RelayModeModerations && textRequest.Model == "" { + if relayMode == relaymode.Moderations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } - if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { + if relayMode == relaymode.Embeddings && textRequest.Model == "" { textRequest.Model = c.Param("model") } - err = util.ValidateTextRequest(textRequest, relayMode) + err = validator.ValidateTextRequest(textRequest, relayMode) if err != nil { return nil, err } return textRequest, nil } +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func getImageSizeRatio(model string, size string) float64 { + ratio, ok := billingratio.ImageSizeRatios[model][size] + if !ok { + return 1 + } + return ratio +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { + // model validation + hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) + if !hasValidSize { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + if 
len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + // channel not azure + if meta.ChannelType != channeltype.Azure { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + } + return nil +} + +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil +} + func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) - case constant.RelayModeCompletions: + case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) - case constant.RelayModeModerations: + case relaymode.Moderations: return openai.CountTokenInput(textRequest.Input, textRequest.Model) } return 0 } -func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { +func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { preConsumedTokens := config.PreConsumedQuota if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens + preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) } - return int(float64(preConsumedTokens) * ratio) + return 
int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) if err != nil { return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -85,16 +160,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) + var quota int64 + completionRatio := billingratio.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } @@ -109,14 +184,23 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R if err != nil { 
logger.Error(ctx, "error consuming token remain quota: "+err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) - model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) - model.UpdateChannelUsedQuota(meta.ChannelId, quota) - } + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + model.UpdateChannelUsedQuota(meta.ChannelId, quota) +} + +func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { + if mapping == nil { + return modelName, false + } + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName, true + } + return modelName, false } diff --git a/relay/controller/image.go b/relay/controller/image.go index 6ec368f5..80769845 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,221 +6,133 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" 
relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" - "strings" - - "github.com/gin-gonic/gin" ) func isWithinRange(element string, value int) bool { - if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] - max := common.DalleGenerationImageAmounts[element][1] - + min := billingratio.ImageGenerationAmounts[element][0] + max := billingratio.ImageGenerationAmounts[element][1] return value >= min && value <= max } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { - imageModel := "dall-e-2" - imageSize := "1024x1024" - - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - - var imageRequest openai.ImageRequest - err := common.UnmarshalBodyReusable(c, &imageRequest) + ctx := c.Request.Context() + meta := meta.GetByContext(c) + imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { - return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } - - // Size validation - if imageRequest.Size != "" { - imageSize = imageRequest.Size - } - - // Model validation - if imageRequest.Model != "" { - imageModel = imageRequest.Model - } - - imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] - - // Check if model is supported - if hasValidSize { - if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { - if imageSize == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - } else { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - - // Prompt validation - if imageRequest.Prompt == "" { - 
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - - // Check prompt length - if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - - // Number of generated images validation - if !isWithinRange(imageModel, imageRequest.N) { - // channel not azure - if channelType != common.ChannelTypeAzure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } + logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } + var isModelMapped bool + meta.OriginModelName = imageRequest.Model + imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) + meta.ActualModelName = imageRequest.Model + + // model validation + bizErr := validateImageRequest(imageRequest, meta) + if bizErr != nil { + return bizErr } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := util.GetAzureAPIVersion(c) - // 
https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) + + imageCostRatio, err := getImageCostRatio(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) } var requestBody io.Reader - if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body + if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(userId) + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } - quota := int(ratio*imageCostRatio*1000) * imageRequest.N + switch meta.ChannelType { + case channeltype.Ali: + fallthrough + case channeltype.Baidu: + fallthrough + case channeltype.Zhipu: + finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } + + modelRatio := billingratio.GetModelRatio(imageRequest.Model) 
+ groupRatio := billingratio.GetGroupRatio(meta.Group) + ratio := modelRatio * groupRatio + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) + + quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - token := c.Request.Header.Get("Authorization") - if channelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := util.HTTPClient.Do(req) + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - var textResponse openai.ImageResponse - defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { return } - err := model.PostConsumeTokenQuota(tokenId, quota) + err := model.PostConsumeTokenQuota(meta.TokenId, quota) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + 
err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } return nil } diff --git a/relay/controller/text.go b/relay/controller/text.go index cc460511..0332a23f 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -5,13 +5,15 @@ import ( "encoding/json" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" 
"github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -19,7 +21,7 @@ import ( func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { @@ -31,11 +33,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // map model name var isModelMapped bool meta.OriginModelName = textRequest.Model - textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) + textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(meta.Group) + modelRatio := billingratio.GetModelRatio(textRequest.Model) + groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota promptTokens := getPromptTokens(textRequest, meta.Mode) @@ -46,16 +48,17 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return bizErr } - adaptor := helper.GetAdaptor(meta.APIType) + adaptor := relay.GetAdaptor(meta.APIType) if adaptor == nil { return 
openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } // get request body var requestBody io.Reader - if meta.APIType == constant.APITypeOpenAI { + if meta.APIType == apitype.OpenAI { // no need to convert request for openai - if isModelMapped { + shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan + if shouldResetRequestBody { jsonStr, err := json.Marshal(textRequest) if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) @@ -73,6 +76,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } + logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) requestBody = bytes.NewBuffer(jsonData) } @@ -82,17 +86,18 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) - return util.RelayErrorHandler(resp) + errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") + if errorHappened { + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return RelayErrorHandler(resp) } + meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") // do response usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { logger.Errorf(ctx, "respErr is not nil: %+v", respErr) - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + 
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } // post-consume quota diff --git a/relay/util/validation.go b/relay/controller/validator/validation.go similarity index 76% rename from relay/util/validation.go rename to relay/controller/validator/validation.go index ef8d840c..8ff520b8 100644 --- a/relay/util/validation.go +++ b/relay/controller/validator/validation.go @@ -1,9 +1,9 @@ -package util +package validator import ( "errors" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "math" ) @@ -15,20 +15,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) return errors.New("model is required") } switch relayMode { - case constant.RelayModeCompletions: + case relaymode.Completions: if textRequest.Prompt == "" { return errors.New("field prompt is required") } - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: if textRequest.Messages == nil || len(textRequest.Messages) == 0 { return errors.New("field messages is required") } - case constant.RelayModeEmbeddings: - case constant.RelayModeModerations: + case relaymode.Embeddings: + case relaymode.Moderations: if textRequest.Input == "" { return errors.New("field input is required") } - case constant.RelayModeEdits: + case relaymode.Edits: if textRequest.Instruction == "" { return errors.New("field instruction is required") } diff --git a/relay/helper/main.go b/relay/helper/main.go deleted file mode 100644 index c2b6e6af..00000000 --- a/relay/helper/main.go +++ /dev/null @@ -1,42 +0,0 @@ -package helper - -import ( - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/aiproxy" - "github.com/songquanpeng/one-api/relay/channel/ali" - "github.com/songquanpeng/one-api/relay/channel/anthropic" - "github.com/songquanpeng/one-api/relay/channel/baidu" - 
"github.com/songquanpeng/one-api/relay/channel/gemini" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/channel/palm" - "github.com/songquanpeng/one-api/relay/channel/tencent" - "github.com/songquanpeng/one-api/relay/channel/xunfei" - "github.com/songquanpeng/one-api/relay/channel/zhipu" - "github.com/songquanpeng/one-api/relay/constant" -) - -func GetAdaptor(apiType int) channel.Adaptor { - switch apiType { - case constant.APITypeAIProxyLibrary: - return &aiproxy.Adaptor{} - case constant.APITypeAli: - return &ali.Adaptor{} - case constant.APITypeAnthropic: - return &anthropic.Adaptor{} - case constant.APITypeBaidu: - return &baidu.Adaptor{} - case constant.APITypeGemini: - return &gemini.Adaptor{} - case constant.APITypeOpenAI: - return &openai.Adaptor{} - case constant.APITypePaLM: - return &palm.Adaptor{} - case constant.APITypeTencent: - return &tencent.Adaptor{} - case constant.APITypeXunfei: - return &xunfei.Adaptor{} - case constant.APITypeZhipu: - return &zhipu.Adaptor{} - } - return nil -} diff --git a/relay/util/relay_meta.go b/relay/meta/relay_meta.go similarity index 63% rename from relay/util/relay_meta.go rename to relay/meta/relay_meta.go index 31b9d2b4..22ef1567 100644 --- a/relay/util/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,13 +1,15 @@ -package util +package meta import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/adaptor/azure" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/relaymode" "strings" ) -type RelayMeta struct { +type Meta struct { Mode int ChannelType int ChannelId int @@ -28,9 +30,9 @@ type RelayMeta struct { PromptTokens int // only for DoResponse } -func GetRelayMeta(c *gin.Context) *RelayMeta { - meta := RelayMeta{ - Mode: 
constant.Path2RelayMode(c.Request.URL.Path), +func GetByContext(c *gin.Context) *Meta { + meta := Meta{ + Mode: relaymode.GetByPath(c.Request.URL.Path), ChannelType: c.GetInt("channel"), ChannelId: c.GetInt("channel_id"), TokenId: c.GetInt("token_id"), @@ -39,17 +41,17 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { Group: c.GetString("group"), ModelMapping: c.GetStringMapString("model_mapping"), BaseURL: c.GetString("base_url"), - APIVersion: c.GetString(common.ConfigKeyAPIVersion), + APIVersion: c.GetString(config.KeyAPIVersion), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Config: nil, RequestURLPath: c.Request.URL.String(), } - if meta.ChannelType == common.ChannelTypeAzure { - meta.APIVersion = GetAzureAPIVersion(c) + if meta.ChannelType == channeltype.Azure { + meta.APIVersion = azure.GetAPIVersion(c) } if meta.BaseURL == "" { - meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] + meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } - meta.APIType = constant.ChannelType2APIType(meta.ChannelType) + meta.APIType = channeltype.ToAPIType(meta.ChannelType) return &meta } diff --git a/relay/model/general.go b/relay/model/general.go index fbcc04e8..30772894 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -5,25 +5,29 @@ type ResponseFormat struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + MaxTokens int 
`json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` User string `json:"user,omitempty"` + Prompt any `json:"prompt,omitempty"` + Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/model/image.go b/relay/model/image.go new file mode 100644 index 00000000..bab84256 --- /dev/null +++ b/relay/model/image.go @@ -0,0 +1,12 @@ +package model + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +} diff --git a/relay/model/message.go b/relay/model/message.go index c6c8a271..32a1055b 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,9 +1,10 @@ package model type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` } func (m 
Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go new file mode 100644 index 00000000..253dca35 --- /dev/null +++ b/relay/model/tool.go @@ -0,0 +1,14 @@ +package model + +type Tool struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Function Function `json:"function"` +} + +type Function struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` // request + Arguments any `json:"arguments,omitempty"` // response +} diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go new file mode 100644 index 00000000..96d09438 --- /dev/null +++ b/relay/relaymode/define.go @@ -0,0 +1,14 @@ +package relaymode + +const ( + Unknown = iota + ChatCompletions + Completions + Embeddings + Moderations + ImagesGenerations + Edits + AudioSpeech + AudioTranscription + AudioTranslation +) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go new file mode 100644 index 00000000..926dd42e --- /dev/null +++ b/relay/relaymode/helper.go @@ -0,0 +1,29 @@ +package relaymode + +import "strings" + +func GetByPath(path string) int { + relayMode := Unknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = ChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = Completions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = Embeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = Embeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = Moderations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = ImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = Edits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = AudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = AudioTranscription + } else if strings.HasPrefix(path, 
"/v1/audio/translations") { + relayMode = AudioTranslation + } + return relayMode +} diff --git a/relay/util/billing.go b/relay/util/billing.go deleted file mode 100644 index 1e2b09ea..00000000 --- a/relay/util/billing.go +++ /dev/null @@ -1,19 +0,0 @@ -package util - -import ( - "context" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" -) - -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(ctx) - } -} diff --git a/relay/util/common.go b/relay/util/common.go deleted file mode 100644 index 6d993378..00000000 --- a/relay/util/common.go +++ /dev/null @@ -1,168 +0,0 @@ -package util - -import ( - "context" - "encoding/json" - "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" - relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" - - "github.com/gin-gonic/gin" -) - -func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { - if !config.AutomaticDisableChannelEnabled { - return false - } - if err == nil { - return false - } - if statusCode == http.StatusUnauthorized { - return true - } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { - return true - } - return false -} - -func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool { - if !config.AutomaticEnableChannelEnabled { - return false - } - if err != nil { - return false - } - if openAIErr != nil { - return false - } - return true -} - -type GeneralErrorResponse struct { - Error relaymodel.Error `json:"error"` - 
Message string `json:"message"` - Msg string `json:"msg"` - Err string `json:"err"` - ErrorMsg string `json:"error_msg"` - Header struct { - Message string `json:"message"` - } `json:"header"` - Response struct { - Error struct { - Message string `json:"message"` - } `json:"error"` - } `json:"response"` -} - -func (e GeneralErrorResponse) ToMessage() string { - if e.Error.Message != "" { - return e.Error.Message - } - if e.Message != "" { - return e.Message - } - if e.Msg != "" { - return e.Msg - } - if e.Err != "" { - return e.Err - } - if e.ErrorMsg != "" { - return e.ErrorMsg - } - if e.Header.Message != "" { - return e.Header.Message - } - if e.Response.Error.Message != "" { - return e.Response.Error.Message - } - return "" -} - -func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) { - ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: relaymodel.Error{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var errResponse GeneralErrorResponse - err = json.Unmarshal(responseBody, &errResponse) - if err != nil { - return - } - if errResponse.Error.Message != "" { - // OpenAI format error, so we override the default one - ErrorWithStatusCode.Error = errResponse.Error - } else { - ErrorWithStatusCode.Error.Message = errResponse.ToMessage() - } - if ErrorWithStatusCode.Error.Message == "" { - ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) - } - return -} - -func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - switch channelType { - case common.ChannelTypeOpenAI: - 
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) - } - } - return fullRequestURL -} - -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - // quotaDelta is remaining quota to be consumed - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - logger.SysError("error update user quota cache: " + err.Error()) - } - // totalQuota is total quota consumed - if totalQuota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) - model.UpdateChannelUsedQuota(channelId, totalQuota) - } - if totalQuota <= 0 { - logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) - } -} - -func GetAzureAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString(common.ConfigKeyAPIVersion) - } - return apiVersion -} diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go deleted file mode 100644 index 39e062a1..00000000 --- a/relay/util/model_mapping.go +++ /dev/null @@ -1,12 +0,0 @@ -package util - -func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) { - if mapping == nil { - return modelName, false - } - mappedModelName := mapping[modelName] - if mappedModelName != "" { - return mappedModelName, true - } - return modelName, false -} diff --git 
a/router/api-router.go b/router/api.go similarity index 89% rename from router/api-router.go rename to router/api.go index 6d143da7..d2ada4eb 100644 --- a/router/api-router.go +++ b/router/api.go @@ -2,6 +2,7 @@ package router import ( "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/controller/auth" "github.com/songquanpeng/one-api/middleware" "github.com/gin-contrib/gzip" @@ -14,17 +15,20 @@ func SetApiRouter(router *gin.Engine) { apiRouter.Use(middleware.GlobalAPIRateLimit()) { apiRouter.GET("/status", controller.GetStatus) + apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) - apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) - apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) - apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) - apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) + apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) + apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) 
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) userRoute := apiRouter.Group("/user") { @@ -42,6 +46,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) } adminRoute := userRoute.Group("/") @@ -67,9 +72,9 @@ func SetApiRouter(router *gin.Engine) { { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) - channelRoute.GET("/models", controller.ListModels) + channelRoute.GET("/models", controller.ListAllModels) channelRoute.GET("/:id", controller.GetChannel) - channelRoute.GET("/test", controller.TestAllChannels) + channelRoute.GET("/test", controller.TestChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) diff --git a/router/dashboard.go b/router/dashboard.go index 0b539d44..5952d698 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -9,6 +9,7 @@ import ( func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") + apiRouter.Use(middleware.CORS()) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.TokenAuth()) diff --git a/router/relay-router.go b/router/relay.go similarity index 100% rename from router/relay-router.go rename to router/relay.go diff --git a/router/web-router.go b/router/web.go similarity index 100% rename from router/web-router.go rename to router/web.go diff --git a/web/README.md b/web/README.md index 86486085..829271e2 100644 --- a/web/README.md +++ b/web/README.md @@ -2,6 +2,9 @@ > 每个文件夹代表一个主题,欢迎提交你的主题 
+> [!WARNING] +> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR + ## 提交新的主题 > 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 @@ -9,7 +12,7 @@ 1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。 2. 把你的主题文件放到这个文件夹下。 3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。 -4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。 +4. 修改 `common/config/config.go` 中的 `ValidThemes`,把你的主题名称注册进去。 5. 修改 `web/THEMES` 文件,这里也需要同步修改。 ## 主题列表 @@ -33,6 +36,12 @@ |![image](https://github.com/songquanpeng/one-api/assets/42402987/fb2b1c64-ef24-4027-9b80-0cd9d945a47f)|![image](https://github.com/songquanpeng/one-api/assets/42402987/b6b649ec-2888-4324-8b2d-d5e11554eed6)| |![image](https://github.com/songquanpeng/one-api/assets/42402987/6d3b22e0-436b-4e26-8911-bcc993c6a2bd)|![image](https://github.com/songquanpeng/one-api/assets/42402987/eef1e224-7245-44d7-804e-9d1c8fa3f29c)| +### 主题:air +由 [Calon](https://github.com/Calcium-Ion) 开发。 +|![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1ddb274b-a715-4e81-858b-857d520b6ff4)|![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/163b0b8e-1f73-49cb-b632-3dcb986b56d5)| +|:---:|:---:| + + #### 开发说明 请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md) diff --git a/web/THEMES b/web/THEMES index 6b0157cb..149e8698 100644 --- a/web/THEMES +++ b/web/THEMES @@ -1,2 +1,3 @@ default berry +air diff --git a/web/air/.gitignore b/web/air/.gitignore new file mode 100644 index 00000000..2b5bba76 --- /dev/null +++ b/web/air/.gitignore @@ -0,0 +1,26 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
+ +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# production +/build + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.idea +package-lock.json +yarn.lock \ No newline at end of file diff --git a/web/air/README.md b/web/air/README.md new file mode 100644 index 00000000..1b1031a3 --- /dev/null +++ b/web/air/README.md @@ -0,0 +1,21 @@ +# React Template + +## Basic Usages + +```shell +# Runs the app in the development mode +npm start + +# Builds the app for production to the `build` folder +npm run build +``` + +If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build, +for example: `REACT_APP_SERVER=http://your.domain.com`. + +Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled. + +## Reference + +1. https://github.com/OIerDb-ng/OIerDb +2. https://github.com/cornflourblue/react-hooks-redux-registration-login-example \ No newline at end of file diff --git a/web/air/package.json b/web/air/package.json new file mode 100644 index 00000000..3bdf3952 --- /dev/null +++ b/web/air/package.json @@ -0,0 +1,60 @@ +{ + "name": "react-template", + "version": "0.1.0", + "private": true, + "dependencies": { + "@douyinfe/semi-icons": "^2.46.1", + "@douyinfe/semi-ui": "^2.46.1", + "@visactor/react-vchart": "~1.8.8", + "@visactor/vchart": "~1.8.8", + "@visactor/vchart-semi-theme": "~1.8.8", + "axios": "^0.27.2", + "history": "^5.3.0", + "marked": "^4.1.1", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-dropzone": "^14.2.3", + "react-fireworks": "^1.0.4", + "react-router-dom": "^6.3.0", + "react-scripts": "5.0.1", + "react-telegram-login": "^1.1.2", + "react-toastify": "^9.0.8", + "react-turnstile": "^1.0.5", + "semantic-ui-css": "^2.5.0", + "semantic-ui-react": "^2.1.3", + "usehooks-ts": "^2.9.1" + }, + "scripts": { + "start": 
"react-scripts start", + "build": "react-scripts build && mv -f build ../build/air", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "devDependencies": { + "prettier": "2.8.8", + "typescript": "4.4.2" + }, + "prettier": { + "singleQuote": true, + "jsxSingleQuote": true + }, + "proxy": "http://localhost:3000" +} diff --git a/web/air/public/favicon.ico b/web/air/public/favicon.ico new file mode 100644 index 00000000..c2c8de0c Binary files /dev/null and b/web/air/public/favicon.ico differ diff --git a/web/air/public/index.html b/web/air/public/index.html new file mode 100644 index 00000000..36365c7e --- /dev/null +++ b/web/air/public/index.html @@ -0,0 +1,18 @@ + + + + + + + + + One API + + + +
+ + diff --git a/web/air/public/logo.png b/web/air/public/logo.png new file mode 100644 index 00000000..0f237a22 Binary files /dev/null and b/web/air/public/logo.png differ diff --git a/web/air/public/robots.txt b/web/air/public/robots.txt new file mode 100644 index 00000000..e9e57dc4 --- /dev/null +++ b/web/air/public/robots.txt @@ -0,0 +1,3 @@ +# https://www.robotstxt.org/robotstxt.html +User-agent: * +Disallow: diff --git a/web/air/src/App.js b/web/air/src/App.js new file mode 100644 index 00000000..5a673187 --- /dev/null +++ b/web/air/src/App.js @@ -0,0 +1,242 @@ +import React, { lazy, Suspense, useContext, useEffect } from 'react'; +import { Route, Routes } from 'react-router-dom'; +import Loading from './components/Loading'; +import User from './pages/User'; +import { PrivateRoute } from './components/PrivateRoute'; +import RegisterForm from './components/RegisterForm'; +import LoginForm from './components/LoginForm'; +import NotFound from './pages/NotFound'; +import Setting from './pages/Setting'; +import EditUser from './pages/User/EditUser'; +import { getLogo, getSystemName } from './helpers'; +import PasswordResetForm from './components/PasswordResetForm'; +import GitHubOAuth from './components/GitHubOAuth'; +import PasswordResetConfirm from './components/PasswordResetConfirm'; +import { UserContext } from './context/User'; +import Channel from './pages/Channel'; +import Token from './pages/Token'; +import EditChannel from './pages/Channel/EditChannel'; +import Redemption from './pages/Redemption'; +import TopUp from './pages/TopUp'; +import Log from './pages/Log'; +import Chat from './pages/Chat'; +import { Layout } from '@douyinfe/semi-ui'; +import Midjourney from './pages/Midjourney'; +import Detail from './pages/Detail'; + +const Home = lazy(() => import('./pages/Home')); +const About = lazy(() => import('./pages/About')); + +function App() { + const [userState, userDispatch] = useContext(UserContext); + // const [statusState, statusDispatch] = 
useContext(StatusContext); + + const loadUser = () => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + userDispatch({ type: 'login', payload: data }); + } + }; + + useEffect(() => { + loadUser(); + let systemName = getSystemName(); + if (systemName) { + document.title = systemName; + } + let logo = getLogo(); + if (logo) { + let linkElement = document.querySelector('link[rel~=\'icon\']'); + if (linkElement) { + linkElement.href = logo; + } + } + }, []); + + return ( + + + + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + }> + + + + } + /> + + }> + + + + } + /> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + } /> + + + + ); +} + +export default App; diff --git a/web/air/src/components/ChannelsTable.js b/web/air/src/components/ChannelsTable.js new file mode 100644 index 00000000..c384d50c --- /dev/null +++ b/web/air/src/components/ChannelsTable.js @@ -0,0 +1,738 @@ +import React, { useEffect, useState } from 'react'; +import { API, isMobile, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; + +import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumberWithPoint, renderQuota } from '../helpers/render'; +import { + Button, + Dropdown, + Form, + InputNumber, + Popconfirm, + Space, + SplitButtonGroup, + Switch, + Table, + Tag, + Tooltip, + Typography +} from '@douyinfe/semi-ui'; +import EditChannel from '../pages/Channel/EditChannel'; +import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +let type2label = undefined; + +function renderType(type) { + if (!type2label) { + type2label = new Map(); + for (let i = 0; i < 
CHANNEL_OPTIONS.length; i++) { + type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; + } + type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; + } + return {type2label[type]?.text}; +} + +const ChannelsTable = () => { + const columns = [ + // { + // title: '', + // dataIndex: 'checkbox', + // className: 'checkbox', + // }, + { + title: 'ID', + dataIndex: 'id' + }, + { + title: '名称', + dataIndex: 'name' + }, + // { + // title: '分组', + // dataIndex: 'group', + // render: (text, record, index) => { + // return ( + //
+ // + // { + // text.split(',').map((item, index) => { + // return (renderGroup(item)); + // }) + // } + // + //
+ // ); + // } + // }, + { + title: '类型', + dataIndex: 'type', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + } + }, + { + title: '状态', + dataIndex: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '响应时间', + dataIndex: 'response_time', + render: (text, record, index) => { + return ( +
+ {renderResponseTime(text)} +
+ ); + } + }, + { + title: '已用/剩余', + dataIndex: 'expired_time', + render: (text, record, index) => { + return ( +
+ + + {renderQuota(record.used_quota)} + + + { + updateChannelBalance(record); + }}>${renderNumberWithPoint(record.balance)} + + +
+ ); + } + }, + { + title: '优先级', + dataIndex: 'priority', + render: (text, record, index) => { + return ( +
+ { + manageChannel(record.id, 'priority', record, e.target.value); + }} + keepFocus={true} + innerButtons + defaultValue={record.priority} + min={-999} + /> +
+ ); + } + }, + // { + // title: '权重', + // dataIndex: 'weight', + // render: (text, record, index) => { + // return ( + //
+ // { + // manageChannel(record.id, 'weight', record, e.target.value); + // }} + // keepFocus={true} + // innerButtons + // defaultValue={record.weight} + // min={0} + // /> + //
+ // ); + // } + // }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ {/* + + + + + */} + + { + manageChannel(record.id, 'delete', record).then( + () => { + removeRecord(record.id); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [channels, setChannels] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [idSort, setIdSort] = useState(false); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searchGroup, setSearchGroup] = useState(''); + const [searchModel, setSearchModel] = useState(''); + const [searching, setSearching] = useState(false); + const [updatingBalance, setUpdatingBalance] = useState(false); + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showPrompt, setShowPrompt] = useState(shouldShowPrompt('channel-test')); + const [channelCount, setChannelCount] = useState(pageSize); + const [groupOptions, setGroupOptions] = useState([]); + const [showEdit, setShowEdit] = useState(false); + const [enableBatchDelete, setEnableBatchDelete] = useState(false); + const [editingChannel, setEditingChannel] = useState({ + id: undefined + }); + const [selectedChannels, setSelectedChannels] = useState([]); + + const removeRecord = id => { + let newDataSource = [...channels]; + if (id != null) { + let idx = newDataSource.findIndex(data => data.id === id); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setChannels(newDataSource); + } + } + }; + + const setChannelFormat = (channels) => { + for (let i = 0; i < channels.length; i++) { + channels[i].key = '' + channels[i].id; + let test_models = []; + channels[i].models.split(',').forEach((item, index) => { + test_models.push({ + node: 'item', + name: item, + onClick: () => { + testChannel(channels[i], item); + } + }); + }); + channels[i].test_models = test_models; + } + // data.key = '' + data.id + setChannels(channels); + if (channels.length >= pageSize) { + setChannelCount(channels.length + pageSize); + } else { + setChannelCount(channels.length); + } + }; + + const loadChannels = async (startIdx, pageSize, idSort) => { + setLoading(true); + const res = await 
API.get(`/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setChannelFormat(data); + } else { + let newChannels = [...channels]; + newChannels.splice(startIdx * pageSize, data.length, ...data); + setChannelFormat(newChannels); + } + } else { + showError(message); + } + setLoading(false); + }; + + const refresh = async () => { + await loadChannels(activePage - 1, pageSize, idSort); + }; + + useEffect(() => { + // console.log('default effect') + const localIdSort = localStorage.getItem('id-sort') === 'true'; + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + setIdSort(localIdSort); + setPageSize(localPageSize); + loadChannels(0, localPageSize, localIdSort) + .then() + .catch((reason) => { + showError(reason); + }); + fetchGroups().then(); + }, []); + + const manageChannel = async (id, action, record, value) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/channel/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/channel/', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/channel/', data); + break; + case 'priority': + if (value === '') { + return; + } + data.priority = parseInt(value); + res = await API.put('/api/channel/', data); + break; + case 'weight': + if (value === '') { + return; + } + data.weight = parseInt(value); + if (data.weight < 0) { + data.weight = 0; + } + res = await API.put('/api/channel/', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let channel = res.data.data; + let newChannels = [...channels]; + if (action === 'delete') { + + } else { + record.status = channel.status; + } + setChannels(newChannels); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return 
已启用; + case 2: + return ( + + 已禁用 + + ); + case 3: + return ( + + 自动禁用 + + ); + default: + return ( + + 未知状态 + + ); + } + }; + + const renderResponseTime = (responseTime) => { + let time = responseTime / 1000; + time = time.toFixed(2) + ' 秒'; + if (responseTime === 0) { + return 未测试; + } else if (responseTime <= 1000) { + return {time}; + } else if (responseTime <= 3000) { + return {time}; + } else if (responseTime <= 5000) { + return {time}; + } else { + return {time}; + } + }; + + const searchChannels = async (searchKeyword, searchGroup, searchModel) => { + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + // if keyword is blank, load files instead. + await loadChannels(0, pageSize, idSort); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}`); + const { success, message, data } = res.data; + if (success) { + setChannels(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const testChannel = async (record, model) => { + const res = await API.get(`/api/channel/test/${record.id}?model=${model}`); + const { success, message, time } = res.data; + if (success) { + record.response_time = time * 1000; + record.test_time = Date.now() / 1000; + showInfo(`渠道 ${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + } else { + showError(message); + } + }; + + const testChannels = async (scope) => { + const res = await API.get(`/api/channel/test?scope=${scope}`); + const { success, message } = res.data; + if (success) { + showInfo('已成功开始测试渠道,请刷新页面查看结果。'); + } else { + showError(message); + } + }; + + const deleteAllDisabledChannels = async () => { + const res = await API.delete(`/api/channel/disabled`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); + await refresh(); + } else { + showError(message); + } + }; + + const updateChannelBalance 
= async (record) => { + const res = await API.get(`/api/channel/update_balance/${record.id}/`); + const { success, message, balance } = res.data; + if (success) { + record.balance = balance; + record.balance_updated_time = Date.now() / 1000; + showInfo(`渠道 ${record.name} 余额更新成功!`); + } else { + showError(message); + } + }; + + const updateAllChannelsBalance = async () => { + setUpdatingBalance(true); + const res = await API.get(`/api/channel/update_balance`); + const { success, message } = res.data; + if (success) { + showInfo('已更新完毕所有已启用渠道余额!'); + } else { + showError(message); + } + setUpdatingBalance(false); + }; + + const batchDeleteChannels = async () => { + if (selectedChannels.length === 0) { + showError('请先选择要删除的渠道!'); + return; + } + setLoading(true); + let ids = []; + selectedChannels.forEach((channel) => { + ids.push(channel.id); + }); + const res = await API.post(`/api/channel/batch`, { ids: ids }); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除 ${data} 个渠道!`); + await refresh(); + } else { + showError(message); + } + setLoading(false); + }; + + const fixChannelsAbilities = async () => { + const res = await API.post(`/api/channel/fix`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已修复 ${data} 个渠道!`); + await refresh(); + } else { + showError(message); + } + }; + + let pageData = channels.slice((activePage - 1) * pageSize, activePage * pageSize); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(channels.length / pageSize) + 1) { + // In this case we have to load more data and then append them. 
+ loadChannels(page - 1, pageSize, idSort).then(r => { + }); + } + }; + + const handlePageSizeChange = async (size) => { + localStorage.setItem('page-size', size + ''); + setPageSize(size); + setActivePage(1); + loadChannels(0, size, idSort) + .then() + .catch((reason) => { + showError(reason); + }); + }; + + const fetchGroups = async () => { + try { + let res = await API.get(`/api/group/`); + // add 'all' option + // res.data.data.unshift('all'); + setGroupOptions(res.data.data.map((group) => ({ + label: group, + value: group + }))); + } catch (error) { + showError(error.message); + } + }; + + const closeEdit = () => { + setShowEdit(false); + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + + return ( + <> + +
+
{ + searchChannels(searchKeyword, searchGroup, searchModel); + }} labelPosition="left"> +
+ + { + setSearchKeyword(v.trim()); + }} + /> + {/* { + setSearchModel(v.trim()); + }} + /> + { + setSearchGroup(v); + searchChannels(searchKeyword, v, searchModel); + }} /> */} + + +
+
+
+ + + { testChannels("all") }} + position={isMobile() ? 'top' : 'left'} + > + + + { testChannels("disabled") }} + position={isMobile() ? 'top' : 'left'} + > + + + {/* + + */} + + + + + + + {/*
*/} + + {/*
*/} +
+ {/*
+ + 开启批量删除 + { + setEnableBatchDelete(v); + }}> + + + + + + + +
+
+ + + 使用ID排序 + { + localStorage.setItem('id-sort', v + ''); + setIdSort(v); + loadChannels(0, pageSize, v) + .then() + .catch((reason) => { + showError(reason); + }); + }}> + + +
*/} +
+ '', + onPageSizeChange: (size) => { + handlePageSizeChange(size).then(); + }, + onPageChange: handlePageChange + }} loading={loading} onRow={handleRow} rowSelection={ + enableBatchDelete ? + { + onChange: (selectedRowKeys, selectedRows) => { + // console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows); + setSelectedChannels(selectedRows); + } + } : null + } /> + + ); +}; + +export default ChannelsTable; diff --git a/web/air/src/components/Footer.js b/web/air/src/components/Footer.js new file mode 100644 index 00000000..6fd0fa54 --- /dev/null +++ b/web/air/src/components/Footer.js @@ -0,0 +1,64 @@ +import React, { useEffect, useState } from 'react'; + +import { Container, Segment } from 'semantic-ui-react'; +import { getFooterHTML, getSystemName } from '../helpers'; + +const Footer = () => { + const systemName = getSystemName(); + const [footer, setFooter] = useState(getFooterHTML()); + let remainCheckTimes = 5; + + const loadFooter = () => { + let footer_html = localStorage.getItem('footer_html'); + if (footer_html) { + setFooter(footer_html); + } + }; + + useEffect(() => { + const timer = setInterval(() => { + if (remainCheckTimes <= 0) { + clearInterval(timer); + return; + } + remainCheckTimes--; + loadFooter(); + }, 200); + return () => clearTimeout(timer); + }, []); + + return ( + + + {footer ? ( +
+ ) : ( +
+ + {systemName} {process.env.REACT_APP_VERSION}{' '} + + 由{' '} + + JustSong + {' '} + 构建,主题 air 来自{' '} + + Calon + {' '},源代码遵循{' '} + + MIT 协议 + +
+ )} +
+
+ ); +}; + +export default Footer; diff --git a/web/air/src/components/GitHubOAuth.js b/web/air/src/components/GitHubOAuth.js new file mode 100644 index 00000000..4e3b93ba --- /dev/null +++ b/web/air/src/components/GitHubOAuth.js @@ -0,0 +1,58 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; +import { useNavigate, useSearchParams } from 'react-router-dom'; +import { API, showError, showSuccess } from '../helpers'; +import { UserContext } from '../context/User'; + +const GitHubOAuth = () => { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default GitHubOAuth; diff --git a/web/air/src/components/HeaderBar.js b/web/air/src/components/HeaderBar.js new file mode 100644 index 00000000..eaf36c48 --- /dev/null +++ 
b/web/air/src/components/HeaderBar.js @@ -0,0 +1,161 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; + +import { API, getLogo, getSystemName, showSuccess } from '../helpers'; +import '../index.css'; + +import fireworks from 'react-fireworks'; + +import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons'; +import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui'; +import { stringToColor } from '../helpers/render'; + +// HeaderBar Buttons +let headerButtons = [ + { + text: '关于', + itemKey: 'about', + to: '/about', + icon: + } +]; + +if (localStorage.getItem('chat_link')) { + headerButtons.splice(1, 0, { + name: '聊天', + to: '/chat', + icon: 'comments' + }); +} + +const HeaderBar = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [showSidebar, setShowSidebar] = useState(false); + const [dark, setDark] = useState(false); + const systemName = getSystemName(); + const logo = getLogo(); + var themeMode = localStorage.getItem('theme-mode'); + const currentDate = new Date(); + // enable fireworks on new year(1.1 and 2.9-2.24) + const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24); + + async function logout() { + setShowSidebar(false); + await API.get('/api/user/logout'); + showSuccess('注销成功!'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } + + const handleNewYearClick = () => { + fireworks.init('root', {}); + fireworks.start(); + setTimeout(() => { + fireworks.stop(); + setTimeout(() => { + window.location.reload(); + }, 10000); + }, 3000); + }; + + useEffect(() => { + if (themeMode === 'dark') { + switchMode(true); + } + if (isNewYear) { + console.log('Happy New Year!'); + } + }, []); + + 
const switchMode = (model) => { + const body = document.body; + if (!model) { + body.removeAttribute('theme-mode'); + localStorage.setItem('theme-mode', 'light'); + } else { + body.setAttribute('theme-mode', 'dark'); + localStorage.setItem('theme-mode', 'dark'); + } + setDark(model); + }; + return ( + <> + +
+ +
+
+ + ); +}; + +export default HeaderBar; diff --git a/web/air/src/components/Loading.js b/web/air/src/components/Loading.js new file mode 100644 index 00000000..bacb53b3 --- /dev/null +++ b/web/air/src/components/Loading.js @@ -0,0 +1,14 @@ +import React from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; + +const Loading = ({ prompt: name = 'page' }) => { + return ( + + + 加载{name}中... + + + ); +}; + +export default Loading; diff --git a/web/air/src/components/LoginForm.js b/web/air/src/components/LoginForm.js new file mode 100644 index 00000000..3cbeb52c --- /dev/null +++ b/web/air/src/components/LoginForm.js @@ -0,0 +1,254 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import { onGitHubOAuthClicked } from './utils'; +import Turnstile from 'react-turnstile'; +import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi-ui'; +import Title from '@douyinfe/semi-ui/lib/es/typography/title'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; +import TelegramLoginButton from 'react-telegram-login'; + +import { IconGithubLogo } from '@douyinfe/semi-icons'; +import WeChatIcon from './WeChatIcon'; + +const LoginForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + wechat_verification_code: '' + }); + const [searchParams, setSearchParams] = useSearchParams(); + const [submitted, setSubmitted] = useState(false); + const { username, password } = inputs; + const [userState, userDispatch] = useContext(UserContext); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + let navigate = useNavigate(); + const [status, setStatus] = 
useState({}); + const logo = getLogo(); + + useEffect(() => { + if (searchParams.get('expired')) { + showError('未登录或登录已过期,请重新登录!'); + } + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }, []); + + const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); + + const onWeChatLoginClicked = () => { + setShowWeChatLoginModal(true); + }; + + const onSubmitWeChatVerificationCode = async () => { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + const res = await API.get( + `/api/oauth/wechat?code=${inputs.wechat_verification_code}` + ); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + navigate('/'); + showSuccess('登录成功!'); + setShowWeChatLoginModal(false); + } else { + showError(message); + } + }; + + function handleChange(name, value) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setSubmitted(true); + if (username && password) { + const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, { + username, + password + }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + if (username === 'root' && password === '123456') { + Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true }); + } + navigate('/token'); + } else { + showError(message); + } + } else { + showError('请输入用户名和密码!'); + } + } + + // 添加Telegram登录处理函数 + const onTelegramLoginClicked = async (response) => 
{ + const fields = ['id', 'first_name', 'last_name', 'username', 'photo_url', 'auth_date', 'hash', 'lang']; + const params = {}; + fields.forEach((field) => { + if (response[field]) { + params[field] = response[field]; + } + }); + const res = await API.get(`/api/oauth/telegram/login`, { params }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } else { + showError(message); + } + }; + + return ( +
+ + + + +
+
+ + + 用户登录 + +
+ handleChange('username', value)} + /> + handleChange('password', value)} + /> + + + +
+ + 没有账号请先 注册账号 + + + 忘记密码 点击重置 + +
+ {status.github_oauth || status.wechat_login || status.telegram_oauth ? ( + <> + + 第三方登录 + +
+ {status.github_oauth ? ( +
+ + ) : ( + <> + )} + setShowWeChatLoginModal(false)} + okText={'登录'} + size={'small'} + centered={true} + > +
+ +
+
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+
+ handleChange('wechat_verification_code', value)} + /> + +
+
+ {turnstileEnabled ? ( +
+ { + setTurnstileToken(token); + }} + /> +
+ ) : ( + <> + )} +
+
+ +
+
+
+ ); +}; + +export default LoginForm; diff --git a/web/air/src/components/LogsTable.js b/web/air/src/components/LogsTable.js new file mode 100644 index 00000000..004188c3 --- /dev/null +++ b/web/air/src/components/LogsTable.js @@ -0,0 +1,401 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { Avatar, Button, Form, Layout, Modal, Select, Space, Spin, Table, Tag } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; +import { renderNumber, renderQuota, stringToColor } from '../helpers/render'; +import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; + +const { Header } = Layout; + +function renderTimestamp(timestamp) { + return (<> + {timestamp2string(timestamp)} + ); +} + +const MODE_OPTIONS = [{ key: 'all', text: '全部用户', value: 'all' }, { key: 'self', text: '当前用户', value: 'self' }]; + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', 'light-blue', 'lime', 'orange', 'pink', 'purple', 'red', 'teal', 'violet', 'yellow']; + +function renderType(type) { + switch (type) { + case 1: + return 充值 ; + case 2: + return 消费 ; + case 3: + return 管理 ; + case 4: + return 系统 ; + default: + return 未知 ; + } +} + +function renderIsStream(bool) { + if (bool) { + return ; + } else { + return 非流; + } +} + +function renderUseTime(type) { + const time = parseInt(type); + if (time < 101) { + return {time} s ; + } else if (time < 300) { + return {time} s ; + } else { + return {time} s ; + } +} + +const LogsTable = () => { + const columns = [{ + title: '时间', dataIndex: 'timestamp2string' + }, { + title: '渠道', + dataIndex: 'channel', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return (isAdminUser ? record.type === 0 || record.type === 2 ?
+ { {text} } +
: <> : <>); + } + }, { + title: '用户', + dataIndex: 'username', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return (isAdminUser ?
+ showUserInfo(record.user_id)}> + {typeof text === 'string' && text.slice(0, 1)} + + {text} +
: <>); + } + }, { + title: '令牌', dataIndex: 'token_name', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { + copyText(text); + }}> {text} +
: <>); + } + }, { + title: '类型', dataIndex: 'type', render: (text, record, index) => { + return (
+ {renderType(text)} +
); + } + }, { + title: '模型', dataIndex: 'model_name', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { + copyText(text); + }}> {text} +
: <>); + } + }, + // { + // title: '用时', dataIndex: 'use_time', render: (text, record, index) => { + // return (
+ // + // {renderUseTime(text)} + // {renderIsStream(record.is_stream)} + // + //
); + // } + // }, + { + title: '提示', dataIndex: 'prompt_tokens', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { {text} } +
: <>); + } + }, { + title: '补全', dataIndex: 'completion_tokens', render: (text, record, index) => { + return (parseInt(text) > 0 && (record.type === 0 || record.type === 2) ?
+ { {text} } +
: <>); + } + }, { + title: '花费', dataIndex: 'quota', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ {renderQuota(text, 6)} +
: <>); + } + }, { + title: '详情', dataIndex: 'content', render: (text, record, index) => { + return + {text} + ; + } + }]; + + const [logs, setLogs] = useState([]); + const [showStat, setShowStat] = useState(false); + const [loading, setLoading] = useState(false); + const [loadingStat, setLoadingStat] = useState(false); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + let now = new Date(); + // 初始化start_timestamp为前一天 + const [inputs, setInputs] = useState({ + username: '', + token_name: '', + model_name: '', + start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), + channel: '' + }); + const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; + + const [stat, setStat] = useState({ + quota: 0, token: 0 + }); + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const getLogSelfStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const getLogStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await 
API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const handleEyeClick = async () => { + setLoadingStat(true); + if (isAdminUser) { + await getLogStat(); + } else { + await getLogSelfStat(); + } + setShowStat(true); + setLoadingStat(false); + }; + + const showUserInfo = async (userId) => { + if (!isAdminUser) { + return; + } + const res = await API.get(`/api/user/${userId}`); + const { success, message, data } = res.data; + if (success) { + Modal.info({ + title: '用户信息', content:
+

用户名: {data.username}

+

余额: {renderQuota(data.quota)}

+

已用额度:{renderQuota(data.used_quota)}

+

请求次数:{renderNumber(data.request_count)}

+
, centered: true + }); + } else { + showError(message); + } + }; + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + }; + + const loadLogs = async (startIdx, pageSize, logType = 0) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + if (isAdminUser) { + url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; + } else { + url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * pageSize, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * pageSize, activePage * pageSize); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / pageSize) + 1) { + // In this case we have to load more data and then append them. 
+ loadLogs(page - 1, pageSize).then(r => { + }); + } + }; + + const handlePageSizeChange = async (size) => { + localStorage.setItem('page-size', size + ''); + setPageSize(size); + setActivePage(1); + loadLogs(0, size) + .then() + .catch((reason) => { + showError(reason); + }); + }; + + const refresh = async (localLogType) => { + // setLoading(true); + setActivePage(1); + await loadLogs(0, pageSize, localLogType); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + // console.log('default effect') + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + setPageSize(localPageSize); + loadLogs(0, localPageSize) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const searchLogs = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadLogs(0, pageSize); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setLogs(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + return (<> + +
+ +

使用明细(总消耗额度: + {showStat ? renderQuota(stat.quota) : '点击查看'} + ) +

+
+
+
+ <> + handleInputChange(value, 'token_name')} /> + handleInputChange(value, 'model_name')} /> + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + {isAdminUser && <> + handleInputChange(value, 'channel')} /> + handleInputChange(value, 'username')} /> + } + + + + + +
{ + handlePageSizeChange(size).then(); + }, + onPageChange: handlePageChange + }} /> + + + ); +}; + +export default LogsTable; diff --git a/web/air/src/components/MjLogsTable.js b/web/air/src/components/MjLogsTable.js new file mode 100644 index 00000000..6a6fbd95 --- /dev/null +++ b/web/air/src/components/MjLogsTable.js @@ -0,0 +1,454 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { Banner, Button, Form, ImagePreview, Layout, Modal, Progress, Table, Tag, Typography } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; + + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +]; + +function renderType(type) { + switch (type) { + case 'IMAGINE': + return 绘图; + case 'UPSCALE': + return 放大; + case 'VARIATION': + return 变换; + case 'HIGH_VARIATION': + return 强变换; + case 'LOW_VARIATION': + return 弱变换; + case 'PAN': + return 平移; + case 'DESCRIBE': + return 图生文; + case 'BLEND': + return 图混合; + case 'SHORTEN': + return 缩词; + case 'REROLL': + return 重绘; + case 'INPAINT': + return 局部重绘-提交; + case 'ZOOM': + return 变焦; + case 'CUSTOM_ZOOM': + return 自定义变焦-提交; + case 'MODAL': + return 窗口处理; + case 'SWAP_FACE': + return 换脸; + default: + return 未知; + } +} + + +function renderCode(code) { + switch (code) { + case 1: + return 已提交; + case 21: + return 等待中; + case 22: + return 重复提交; + case 0: + return 未提交; + default: + return 未知; + } +} + + +function renderStatus(type) { + // Ensure all cases are string literals by adding quotes. 
+ switch (type) { + case 'SUCCESS': + return 成功; + case 'NOT_START': + return 未启动; + case 'SUBMITTED': + return 队列中; + case 'IN_PROGRESS': + return 执行中; + case 'FAILURE': + return 失败; + case 'MODAL': + return 窗口等待; + default: + return 未知; + } +} + +const renderTimestamp = (timestampInSeconds) => { + const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒 + + const year = date.getFullYear(); // 获取年份 + const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数 + const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数 + const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数 + const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数 + const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数 + + return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出 +}; + + +const LogsTable = () => { + const [isModalOpen, setIsModalOpen] = useState(false); + const [modalContent, setModalContent] = useState(''); + const columns = [ + { + title: '提交时间', + dataIndex: 'submit_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text / 1000)} +
+ ); + } + }, + { + title: '渠道', + dataIndex: 'channel_id', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( + +
+ { + copyText(text); // 假设copyText是用于文本复制的函数 + }}> {text} +
+ + ); + } + }, + { + title: '类型', + dataIndex: 'action', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + } + }, + { + title: '任务ID', + dataIndex: 'mj_id', + render: (text, record, index) => { + return ( +
+ {text} +
+ ); + } + }, + { + title: '提交结果', + dataIndex: 'code', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ {renderCode(text)} +
+ ); + } + }, + { + title: '任务状态', + dataIndex: 'status', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '进度', + dataIndex: 'progress', + render: (text, record, index) => { + return ( +
+ { + // 转换例如100%为数字100,如果text未定义,返回0 + + } +
+ ); + } + }, + { + title: '结果图片', + dataIndex: 'image_url', + render: (text, record, index) => { + if (!text) { + return '无'; + } + return ( + + ); + } + }, + { + title: 'Prompt', + dataIndex: 'prompt', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + }, + { + title: 'PromptEn', + dataIndex: 'prompt_en', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + }, + { + title: '失败原因', + dataIndex: 'fail_reason', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + } + + ]; + + const [logs, setLogs] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + const [isModalOpenurl, setIsModalOpenurl] = useState(false); + const [showBanner, setShowBanner] = useState(false); + + // 定义模态框图片URL的状态和更新函数 + const [modalImageUrl, setModalImageUrl] = useState(''); + let now = new Date(); + // 初始化start_timestamp为前一天 + const [inputs, setInputs] = useState({ + channel_id: '', + mj_id: '', + start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + }); + const { channel_id, mj_id, start_timestamp, end_timestamp } = inputs; + + const [stat, setStat] = useState({ + quota: 0, + token: 0 + }); + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + 
logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + }; + + const loadLogs = async (startIdx) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp); + let localEndTimestamp = Date.parse(end_timestamp); + if (isAdminUser) { + url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadLogs(page - 1).then(r => { + }); + } + }; + + const refresh = async () => { + // setLoading(true); + setActivePage(1); + await loadLogs(0); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + refresh().then(); + }, [logType]); + + useEffect(() => { + const mjNotifyEnabled = localStorage.getItem('mj_notify_enabled'); + if (mjNotifyEnabled !== 'true') { + setShowBanner(true); + } + }, []); + + return ( + <> + + + {isAdminUser && showBanner ? : <> + } +
+ <> + handleInputChange(value, 'channel_id')} /> + handleInputChange(value, 'mj_id')} /> + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + + + + + + +
+ setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式 + width={800} // 设置模态框宽度 + > +

{modalContent}

+
+ setIsModalOpenurl(visible)} + /> + + + + ); +}; + +export default LogsTable; diff --git a/web/air/src/components/OperationSetting.js b/web/air/src/components/OperationSetting.js new file mode 100644 index 00000000..6356ac66 --- /dev/null +++ b/web/air/src/components/OperationSetting.js @@ -0,0 +1,389 @@ +import React, { useEffect, useState } from 'react'; +import { Divider, Form, Grid, Header } from 'semantic-ui-react'; +import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; + +const OperationSetting = () => { + let now = new Date(); + let [inputs, setInputs] = useState({ + QuotaForNewUser: 0, + QuotaForInviter: 0, + QuotaForInvitee: 0, + QuotaRemindThreshold: 0, + PreConsumedQuota: 0, + ModelRatio: '', + CompletionRatio: '', + GroupRatio: '', + TopUpLink: '', + ChatLink: '', + QuotaPerUnit: 0, + AutomaticDisableChannelEnabled: '', + AutomaticEnableChannelEnabled: '', + ChannelDisableThreshold: 0, + LogConsumeEnabled: '', + DisplayInCurrencyEnabled: '', + DisplayTokenStatEnabled: '', + ApproximateTokenEnabled: '', + RetryTimes: 0 + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'CompletionRatio') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + if (item.value === '{}') { + item.value = ''; + } + newInputs[item.key] = item.value; + }); + setInputs(newInputs); + setOriginInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + if 
(key.endsWith('Enabled')) { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name.endsWith('Enabled')) { + await updateOption(name, value); + } else { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + }; + + const submitConfig = async (group) => { + switch (group) { + case 'monitor': + if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { + await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); + } + if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { + await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); + } + break; + case 'ratio': + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); + } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } + if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) { + if (!verifyJSON(inputs.CompletionRatio)) { + showError('补全倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('CompletionRatio', inputs.CompletionRatio); + } + break; + case 'quota': + if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { + await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); + } + if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { + await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); + } + if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { + await 
updateOption('QuotaForInviter', inputs.QuotaForInviter); + } + if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { + await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); + } + break; + case 'general': + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { + await updateOption('TopUpLink', inputs.TopUpLink); + } + if (originInputs['ChatLink'] !== inputs.ChatLink) { + await updateOption('ChatLink', inputs.ChatLink); + } + if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { + await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); + } + if (originInputs['RetryTimes'] !== inputs.RetryTimes) { + await updateOption('RetryTimes', inputs.RetryTimes); + } + break; + } + }; + + const deleteHistoryLogs = async () => { + console.log(inputs); + const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`${data} 条日志已清理!`); + return; + } + showError('日志清理失败:' + message); + }; + + return ( + + +
+
+ 通用设置 +
+ + + + + + + + + + + + { + submitConfig('general').then(); + }}>保存通用设置 + +
+ 日志设置 +
+ + + + + { + setHistoryTimestamp(value); + }} /> + + { + deleteHistoryLogs().then(); + }}>清理历史日志 + +
+ 监控设置 +
+ + + + + + + + + { + submitConfig('monitor').then(); + }}>保存监控设置 + +
+ 额度设置 +
+ + + + + + + { + submitConfig('quota').then(); + }}>保存额度设置 + +
+ 倍率设置 +
+ + + + + + + + + + { + submitConfig('ratio').then(); + }}>保存倍率设置 + +
+
+ ); +}; + +export default OperationSetting; diff --git a/web/air/src/components/OtherSetting.js b/web/air/src/components/OtherSetting.js new file mode 100644 index 00000000..ae924d9f --- /dev/null +++ b/web/air/src/components/OtherSetting.js @@ -0,0 +1,225 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react'; +import { API, showError, showSuccess } from '../helpers'; +import { marked } from 'marked'; +import { Link } from 'react-router-dom'; + +const OtherSetting = () => { + let [inputs, setInputs] = useState({ + Footer: '', + Notice: '', + About: '', + SystemName: '', + Logo: '', + HomePageContent: '', + Theme: '' + }); + let [loading, setLoading] = useState(false); + const [showUpdateModal, setShowUpdateModal] = useState(false); + const [updateData, setUpdateData] = useState({ + tag_name: '', + content: '' + }); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key in inputs) { + newInputs[item.key] = item.value; + } + }); + setInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const submitNotice = async () => { + await updateOption('Notice', inputs.Notice); + }; + + const submitFooter = async () => { + await updateOption('Footer', inputs.Footer); + }; + + const submitSystemName = async () => { + await 
updateOption('SystemName', inputs.SystemName); + }; + + const submitTheme = async () => { + await updateOption('Theme', inputs.Theme); + }; + + const submitLogo = async () => { + await updateOption('Logo', inputs.Logo); + }; + + const submitAbout = async () => { + await updateOption('About', inputs.About); + }; + + const submitOption = async (key) => { + await updateOption(key, inputs[key]); + }; + + const openGitHubRelease = () => { + window.location = + 'https://github.com/songquanpeng/one-api/releases/latest'; + }; + + const checkUpdate = async () => { + const res = await API.get( + 'https://api.github.com/repos/songquanpeng/one-api/releases/latest' + ); + const { tag_name, body } = res.data; + if (tag_name === process.env.REACT_APP_VERSION) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body) + }); + setShowUpdateModal(true); + } + }; + + return ( + + +
+
通用设置
+ 检查更新 + + + + 保存公告 + +
个性化设置
+ + + + 设置系统名称 + + 主题名称(当前可用主题)} + placeholder='请输入主题名称' + value={inputs.Theme} + name='Theme' + onChange={handleInputChange} + /> + + 设置主题(重启生效) + + + + 设置 Logo + + + + submitOption('HomePageContent')}>保存首页内容 + + + + 保存关于 + 移除 One API + 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。 + + + + 设置页脚 + +
+ setShowUpdateModal(false)} + onOpen={() => setShowUpdateModal(true)} + open={showUpdateModal} + > + 新版本:{updateData.tag_name} + + +
+
+
+ + + + + + +
+ ); +}; + +export default PasswordResetConfirm; diff --git a/web/air/src/components/PasswordResetForm.js b/web/air/src/components/PasswordResetForm.js new file mode 100644 index 00000000..ff3eaadb --- /dev/null +++ b/web/air/src/components/PasswordResetForm.js @@ -0,0 +1,102 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react'; +import { API, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const PasswordResetForm = () => { + const [inputs, setInputs] = useState({ + email: '' + }); + const { email } = inputs; + + const [loading, setLoading] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); + }, [disableButton, countdown]); + + function handleChange(e) { + const { name, value } = e.target; + setInputs(inputs => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + setDisableButton(true); + if (!email) return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/reset_password?email=${email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('重置邮件发送成功,请检查邮箱!'); + setInputs({ ...inputs, email: '' }); + } else { + showError(message); + } + setLoading(false); + } + + return ( + + +
+ 密码重置 +
+
+ + + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + + +
+
+ ); +}; + +export default PasswordResetForm; diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js new file mode 100644 index 00000000..45a5b776 --- /dev/null +++ b/web/air/src/components/PersonalSetting.js @@ -0,0 +1,653 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { API, copy, isRoot, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; +import { UserContext } from '../context/User'; +import { onGitHubOAuthClicked } from './utils'; +import { + Avatar, + Banner, + Button, + Card, + Descriptions, + Image, + Input, + InputNumber, + Layout, + Modal, + Space, + Tag, + Typography +} from '@douyinfe/semi-ui'; +import { getQuotaPerUnit, renderQuota, renderQuotaWithPrompt, stringToColor } from '../helpers/render'; +import TelegramLoginButton from 'react-telegram-login'; + +const PersonalSetting = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [inputs, setInputs] = useState({ + wechat_verification_code: '', + email_verification_code: '', + email: '', + self_account_deletion_confirmation: '', + set_new_password: '', + set_new_password_confirmation: '' + }); + const [status, setStatus] = useState({}); + const [showChangePasswordModal, setShowChangePasswordModal] = useState(false); + const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); + const [showEmailBindModal, setShowEmailBindModal] = useState(false); + const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + 
const [affLink, setAffLink] = useState(''); + const [systemToken, setSystemToken] = useState(''); + // const [models, setModels] = useState([]); + const [openTransfer, setOpenTransfer] = useState(false); + const [transferAmount, setTransferAmount] = useState(0); + + useEffect(() => { + // let user = localStorage.getItem('user'); + // if (user) { + // userDispatch({ type: 'login', payload: user }); + // } + // console.log(localStorage.getItem('user')) + + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + getUserData().then( + (res) => { + console.log(userState); + } + ); + // loadModels().then(); + getAffLink().then(); + setTransferAmount(getQuotaPerUnit()); + }, []); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); // Clean up on unmount + }, [disableButton, countdown]); + + const handleInputChange = (name, value) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const generateAccessToken = async () => { + const res = await API.get('/api/user/token'); + const { success, message, data } = res.data; + if (success) { + setSystemToken(data); + await copy(data); + showSuccess(`令牌已重置并已复制到剪贴板`); + } else { + showError(message); + } + }; + + const getAffLink = async () => { + const res = await API.get('/api/user/aff'); + const { success, message, data } = res.data; + if (success) { + let link = `${window.location.origin}/register?aff=${data}`; + setAffLink(link); + } else { + showError(message); + } + }; + + const getUserData = async () => { + let res = await API.get(`/api/user/self`); + const { success, message, 
data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + } else { + showError(message); + } + }; + + // const loadModels = async () => { + // let res = await API.get(`/api/user/models`); + // const { success, message, data } = res.data; + // if (success) { + // setModels(data); + // console.log(data); + // } else { + // showError(message); + // } + // }; + + const handleAffLinkClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`邀请链接已复制到剪切板`); + }; + + const handleSystemTokenClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`系统令牌已复制到剪切板`); + }; + + const deleteAccount = async () => { + if (inputs.self_account_deletion_confirmation !== userState.user.username) { + showError('请输入你的账户名以确认删除!'); + return; + } + + const res = await API.delete('/api/user/self'); + const { success, message } = res.data; + + if (success) { + showSuccess('账户已删除!'); + await API.get('/api/user/logout'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } else { + showError(message); + } + }; + + const bindWeChat = async () => { + if (inputs.wechat_verification_code === '') return; + const res = await API.get( + `/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('微信账户绑定成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + }; + + const changePassword = async () => { + if (inputs.set_new_password !== inputs.set_new_password_confirmation) { + showError('两次输入的密码不一致!'); + return; + } + const res = await API.put( + `/api/user/self`, + { + password: inputs.set_new_password + } + ); + const { success, message } = res.data; + if (success) { + showSuccess('密码修改成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + setShowChangePasswordModal(false); + }; + + const transfer = async () => { + if (transferAmount < 
getQuotaPerUnit()) { + showError('划转金额最低为' + renderQuota(getQuotaPerUnit())); + return; + } + const res = await API.post( + `/api/user/aff_transfer`, + { + quota: transferAmount + } + ); + const { success, message } = res.data; + if (success) { + showSuccess(message); + setOpenTransfer(false); + getUserData().then(); + } else { + showError(message); + } + }; + + const sendVerificationCode = async () => { + if (inputs.email === '') { + showError('请输入邮箱!'); + return; + } + setDisableButton(true); + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + const bindEmail = async () => { + if (inputs.email_verification_code === '') { + showError('请输入邮箱验证码!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('邮箱账户绑定成功!'); + setShowEmailBindModal(false); + userState.user.email = inputs.email; + } else { + showError(message); + } + setLoading(false); + }; + + const getUsername = () => { + if (userState.user) { + return userState.user.username; + } else { + return 'null'; + } + }; + + const handleCancel = () => { + setOpenTransfer(false); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + return ( +
+ + + +
+ {`可用额度${renderQuotaWithPrompt(userState?.user?.aff_quota)}`} + +
+
+ {`划转额度${renderQuotaWithPrompt(transferAmount)} 最低` + renderQuota(getQuotaPerUnit())} +
+ setTransferAmount(value)} disabled={false}> +
+
+
+
+ + {typeof getUsername() === 'string' && getUsername().slice(0, 1)} + } + title={{getUsername()}} + description={isRoot() ? 管理员 : 普通用户} + > + } + headerExtraContent={ + <> + + {'ID: ' + userState?.user?.id} + {userState?.user?.group} + + + } + footer={ + + {renderQuota(userState?.user?.quota)} + {renderQuota(userState?.user?.used_quota)} + {userState.user?.request_count} + + } + > + 调用信息 + {/* 可用模型 +
+ + {models.map((model) => ( + { + copyText(model); + }}> + {model} + + ))} + +
*/} +
+ {/* + 邀请链接 + +
+ } + > + 邀请信息 +
+ + + + { + renderQuota(userState?.user?.aff_quota) + } + + + + {renderQuota(userState?.user?.aff_history_quota)} + {userState?.user?.aff_count} + +
+ */} + + 邀请链接 + + + + 个人信息 +
+ 邮箱 +
+
+ +
+
+ +
+
+
+
+ 微信 +
+
+ +
+
+ +
+
+
+
+ GitHub +
+
+ +
+
+ +
+
+
+ + {/*
+ Telegram +
+
+ +
+
+ {status.telegram_oauth ? + userState.user.telegram_id !== '' ? + : + : + } +
+
+
*/} + +
+ + + + + + + {systemToken && ( + + )} + { + status.wechat_login && ( + + ) + } + setShowWeChatBindModal(false)} + // onOpen={() => setShowWeChatBindModal(true)} + visible={showWeChatBindModal} + size={'mini'} + > + +
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+ handleInputChange('wechat_verification_code', v)} + /> + +
+
+
+ setShowEmailBindModal(false)} + // onOpen={() => setShowEmailBindModal(true)} + onOk={bindEmail} + visible={showEmailBindModal} + size={'small'} + centered={true} + maskClosable={false} + > + 绑定邮箱地址 +
+ handleInputChange('email', value)} + name="email" + type="email" + /> + +
+
+ handleInputChange('email_verification_code', value)} + /> +
+ {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+ setShowAccountDeleteModal(false)} + visible={showAccountDeleteModal} + size={'small'} + centered={true} + onOk={deleteAccount} + > +
+ +
+
+ handleInputChange('self_account_deletion_confirmation', value)} + /> + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+
+ setShowChangePasswordModal(false)} + visible={showChangePasswordModal} + size={'small'} + centered={true} + onOk={changePassword} + > +
+ handleInputChange('set_new_password', value)} + /> + handleInputChange('set_new_password_confirmation', value)} + /> + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+
+
+ + + + + ); +}; + +export default PersonalSetting; diff --git a/web/air/src/components/PrivateRoute.js b/web/air/src/components/PrivateRoute.js new file mode 100644 index 00000000..9ef826c1 --- /dev/null +++ b/web/air/src/components/PrivateRoute.js @@ -0,0 +1,13 @@ +import { Navigate } from 'react-router-dom'; + +import { history } from '../helpers'; + + +function PrivateRoute({ children }) { + if (!localStorage.getItem('user')) { + return ; + } + return children; +} + +export { PrivateRoute }; \ No newline at end of file diff --git a/web/air/src/components/RedemptionsTable.js b/web/air/src/components/RedemptionsTable.js new file mode 100644 index 00000000..89e4ce20 --- /dev/null +++ b/web/air/src/components/RedemptionsTable.js @@ -0,0 +1,406 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; +import { Button, Form, Modal, Popconfirm, Popover, Table, Tag } from '@douyinfe/semi-ui'; +import EditRedemption from '../pages/Redemption/EditRedemption'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status) { + switch (status) { + case 1: + return 未使用; + case 2: + return 已禁用 ; + case 3: + return 已使用 ; + default: + return 未知状态 ; + } +} + +const RedemptionsTable = () => { + const columns = [ + { + title: 'ID', + dataIndex: 'id' + }, + { + title: '名称', + dataIndex: 'name' + }, + { + title: '状态', + dataIndex: 'status', + key: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '额度', + dataIndex: 'quota', + render: (text, record, index) => { + return ( +
+ {renderQuota(parseInt(text))} +
+ ); + } + }, + { + title: '创建时间', + dataIndex: 'created_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text)} +
+ ); + } + }, + // { + // title: '兑换人ID', + // dataIndex: 'used_user_id', + // render: (text, record, index) => { + // return ( + //
+ // {text === 0 ? '无' : text} + //
+ // ); + // } + // }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ + + + + { + manageRedemption(record.id, 'delete', record).then( + () => { + removeRecord(record.key); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [redemptions, setRedemptions] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [tokenCount, setTokenCount] = useState(ITEMS_PER_PAGE); + const [selectedKeys, setSelectedKeys] = useState([]); + const [editingRedemption, setEditingRedemption] = useState({ + id: undefined + }); + const [showEdit, setShowEdit] = useState(false); + + const closeEdit = () => { + setShowEdit(false); + }; + + // const setCount = (data) => { + // if (data.length >= (activePage) * ITEMS_PER_PAGE) { + // setTokenCount(data.length + 1); + // } else { + // setTokenCount(data.length); + // } + // } + + const setRedemptionFormat = (redeptions) => { + // for (let i = 0; i < redeptions.length; i++) { + // redeptions[i].key = '' + redeptions[i].id; + // } + // data.key = '' + data.id + setRedemptions(redeptions); + if (redeptions.length >= (activePage) * ITEMS_PER_PAGE) { + setTokenCount(redeptions.length + 1); + } else { + setTokenCount(redeptions.length); + } + }; + + const loadRedemptions = async (startIdx) => { + const res = await API.get(`/api/redemption/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setRedemptionFormat(data); + } else { + let newRedemptions = redemptions; + newRedemptions.push(...data); + setRedemptionFormat(newRedemptions); + } + } else { + showError(message); + } + setLoading(false); + }; + + const removeRecord = key => { + let newDataSource = [...redemptions]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.key === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setRedemptions(newDataSource); + } + } + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制到剪贴板!'); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', 
content: text }); + } + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadRedemptions(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadRedemptions(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const refresh = async () => { + await loadRedemptions(activePage - 1); + }; + + const manageRedemption = async (id, action, record) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/redemption/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/redemption/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/redemption/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let redemption = res.data.data; + let newRedemptions = [...redemptions]; + // let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + + } else { + record.status = redemption.status; + } + setRedemptions(newRedemptions); + } else { + showError(message); + } + }; + + const searchRedemptions = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. 
+ await loadRedemptions(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/redemption/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setRedemptions(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const sortRedemption = (key) => { + if (redemptions.length === 0) return; + setLoading(true); + let sortedRedemptions = [...redemptions]; + sortedRedemptions.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedRedemptions[0].id === redemptions[0].id) { + sortedRedemptions.reverse(); + } + setRedemptions(sortedRedemptions); + setLoading(false); + }; + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadRedemptions(page - 1).then(r => { + }); + } + }; + + let pageData = redemptions.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + const rowSelection = { + onSelect: (record, selected) => { + }, + onSelectAll: (selected, selectedRows) => { + }, + onChange: (selectedRowKeys, selectedRows) => { + setSelectedKeys(selectedRows); + } + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + return ( + <> + +
+ + + +
`第 ${page.currentStart} - ${page.currentEnd} 条,共 ${redemptions.length} 条`, + // onPageSizeChange: (size) => { + // setPageSize(size); + // setActivePage(1); + // }, + onPageChange: handlePageChange + }} loading={loading} rowSelection={rowSelection} onRow={handleRow}> +
+ + + + ); +}; + +export default RedemptionsTable; diff --git a/web/air/src/components/RegisterForm.js b/web/air/src/components/RegisterForm.js new file mode 100644 index 00000000..1f26b63f --- /dev/null +++ b/web/air/src/components/RegisterForm.js @@ -0,0 +1,194 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; +import { Link, useNavigate } from 'react-router-dom'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const RegisterForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + password2: '', + email: '', + verification_code: '' + }); + const { username, password, password2 } = inputs; + const [showEmailVerification, setShowEmailVerification] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const logo = getLogo(); + let affCode = new URLSearchParams(window.location.search).get('aff'); + if (affCode) { + localStorage.setItem('aff', affCode); + } + + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setShowEmailVerification(status.email_verification); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }); + + let navigate = useNavigate(); + + function handleChange(e) { + const { name, value } = e.target; + console.log(name, value); + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (password.length < 8) { + showInfo('密码长度不得小于 8 位!'); + return; + } + if (password !== password2) { + showInfo('两次输入的密码不一致'); + return; + } + if (username && password) { + if (turnstileEnabled 
&& turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + if (!affCode) { + affCode = localStorage.getItem('aff'); + } + inputs.aff_code = affCode; + const res = await API.post( + `/api/user/register?turnstile=${turnstileToken}`, + inputs + ); + const { success, message } = res.data; + if (success) { + navigate('/login'); + showSuccess('注册成功!'); + } else { + showError(message); + } + setLoading(false); + } + } + + const sendVerificationCode = async () => { + if (inputs.email === '') return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查你的邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + return ( + + +
+ 新用户注册 +
+
+ + + + + {showEmailVerification ? ( + <> + + 获取验证码 + + } + /> + + + ) : ( + <> + )} + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+ + 已有账户? + + 点击登录 + + +
+
+ ); +}; + +export default RegisterForm; diff --git a/web/air/src/components/SiderBar.js b/web/air/src/components/SiderBar.js new file mode 100644 index 00000000..b3da272f --- /dev/null +++ b/web/air/src/components/SiderBar.js @@ -0,0 +1,214 @@ +import React, { useContext, useEffect, useMemo, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { StatusContext } from '../context/Status'; + +import { API, getLogo, getSystemName, isAdmin, isMobile, showError } from '../helpers'; +import '../index.css'; + +import { + IconCalendarClock, + IconComment, + IconCreditCard, + IconGift, + IconHistogram, + IconHome, + IconImage, + IconKey, + IconLayers, + IconSetting, + IconUser +} from '@douyinfe/semi-icons'; +import { Layout, Nav } from '@douyinfe/semi-ui'; + +// HeaderBar Buttons + +const SiderBar = () => { + const [userState, userDispatch] = useContext(UserContext); + const [statusState, statusDispatch] = useContext(StatusContext); + const defaultIsCollapsed = isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'; + + let navigate = useNavigate(); + const [selectedKeys, setSelectedKeys] = useState(['home']); + const systemName = getSystemName(); + const logo = getLogo(); + const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed); + + const headerButtons = useMemo(() => [ + { + text: '首页', + itemKey: 'home', + to: '/', + icon: + }, + { + text: '渠道', + itemKey: 'channel', + to: '/channel', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '聊天', + itemKey: 'chat', + to: '/chat', + icon: , + className: localStorage.getItem('chat_link') ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '令牌', + itemKey: 'token', + to: '/token', + icon: + }, + { + text: '兑换', + itemKey: 'redemption', + to: '/redemption', + icon: , + className: isAdmin() ? 
'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '充值', + itemKey: 'topup', + to: '/topup', + icon: + }, + { + text: '用户', + itemKey: 'user', + to: '/user', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '日志', + itemKey: 'log', + to: '/log', + icon: + }, + { + text: '数据看板', + itemKey: 'detail', + to: '/detail', + icon: , + className: localStorage.getItem('enable_data_export') === 'true' ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '绘图', + itemKey: 'midjourney', + to: '/midjourney', + icon: , + className: localStorage.getItem('enable_drawing') === 'true' ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '设置', + itemKey: 'setting', + to: '/setting', + icon: + } + // { + // text: '关于', + // itemKey: 'about', + // to: '/about', + // icon: + // } + ], [localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), localStorage.getItem('chat_link'), isAdmin()]); + + const loadStatus = async () => { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + statusDispatch({ type: 'set', payload: data }); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_data_export', data.enable_data_export); + localStorage.setItem('data_export_default_time', data.data_export_default_time); + localStorage.setItem('default_collapse_sidebar', data.default_collapse_sidebar); + localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + 
localStorage.removeItem('chat_link'); + } + if (data.chat_link2) { + localStorage.setItem('chat_link2', data.chat_link2); + } else { + localStorage.removeItem('chat_link2'); + } + } else { + showError('无法正常连接至服务器!'); + } + }; + + useEffect(() => { + loadStatus().then(() => { + setIsCollapsed(isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'); + }); + }, []); + + return ( + <> + +
+ +
+
+ + ); +}; + +export default SiderBar; diff --git a/web/air/src/components/SystemSetting.js b/web/air/src/components/SystemSetting.js new file mode 100644 index 00000000..09b98665 --- /dev/null +++ b/web/air/src/components/SystemSetting.js @@ -0,0 +1,590 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Modal, Message } from 'semantic-ui-react'; +import { API, removeTrailingSlash, showError } from '../helpers'; + +const SystemSetting = () => { + let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', + PasswordRegisterEnabled: '', + EmailVerificationEnabled: '', + GitHubOAuthEnabled: '', + GitHubClientId: '', + GitHubClientSecret: '', + Notice: '', + SMTPServer: '', + SMTPPort: '', + SMTPAccount: '', + SMTPFrom: '', + SMTPToken: '', + ServerAddress: '', + Footer: '', + WeChatAuthEnabled: '', + WeChatServerAddress: '', + WeChatServerToken: '', + WeChatAccountQRCodeImageURL: '', + MessagePusherAddress: '', + MessagePusherToken: '', + TurnstileCheckEnabled: '', + TurnstileSiteKey: '', + TurnstileSecretKey: '', + RegisterEnabled: '', + EmailDomainRestrictionEnabled: '', + EmailDomainWhitelist: '' + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); + const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); + const [showPasswordWarningModal, setShowPasswordWarningModal] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + newInputs[item.key] = item.value; + }); + setInputs({ + ...newInputs, + EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') + }); + setOriginInputs(newInputs); + + setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { + return { key: item, text: item, 
value: item }; + })); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + switch (key) { + case 'PasswordLoginEnabled': + case 'PasswordRegisterEnabled': + case 'EmailVerificationEnabled': + case 'GitHubOAuthEnabled': + case 'WeChatAuthEnabled': + case 'TurnstileCheckEnabled': + case 'EmailDomainRestrictionEnabled': + case 'RegisterEnabled': + value = inputs[key] === 'true' ? 'false' : 'true'; + break; + default: + break; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + if (key === 'EmailDomainWhitelist') { + value = value.split(','); + } + setInputs((inputs) => ({ + ...inputs, [key]: value + })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') { + // block disabling password login + setShowPasswordWarningModal(true); + return; + } + if ( + name === 'Notice' || + name.startsWith('SMTP') || + name === 'ServerAddress' || + name === 'GitHubClientId' || + name === 'GitHubClientSecret' || + name === 'WeChatServerAddress' || + name === 'WeChatServerToken' || + name === 'WeChatAccountQRCodeImageURL' || + name === 'TurnstileSiteKey' || + name === 'TurnstileSecretKey' || + name === 'EmailDomainWhitelist' + ) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } else { + await updateOption(name, value); + } + }; + + const submitServerAddress = async () => { + let ServerAddress = removeTrailingSlash(inputs.ServerAddress); + await updateOption('ServerAddress', ServerAddress); + }; + + const submitSMTP = async () => { + if (originInputs['SMTPServer'] !== inputs.SMTPServer) { + await updateOption('SMTPServer', inputs.SMTPServer); + } + if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) { + await updateOption('SMTPAccount', 
inputs.SMTPAccount); + } + if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) { + await updateOption('SMTPFrom', inputs.SMTPFrom); + } + if ( + originInputs['SMTPPort'] !== inputs.SMTPPort && + inputs.SMTPPort !== '' + ) { + await updateOption('SMTPPort', inputs.SMTPPort); + } + if ( + originInputs['SMTPToken'] !== inputs.SMTPToken && + inputs.SMTPToken !== '' + ) { + await updateOption('SMTPToken', inputs.SMTPToken); + } + }; + + + const submitEmailDomainWhitelist = async () => { + if ( + originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && + inputs.SMTPToken !== '' + ) { + await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); + } + }; + + const submitWeChat = async () => { + if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { + await updateOption( + 'WeChatServerAddress', + removeTrailingSlash(inputs.WeChatServerAddress) + ); + } + if ( + originInputs['WeChatAccountQRCodeImageURL'] !== + inputs.WeChatAccountQRCodeImageURL + ) { + await updateOption( + 'WeChatAccountQRCodeImageURL', + inputs.WeChatAccountQRCodeImageURL + ); + } + if ( + originInputs['WeChatServerToken'] !== inputs.WeChatServerToken && + inputs.WeChatServerToken !== '' + ) { + await updateOption('WeChatServerToken', inputs.WeChatServerToken); + } + }; + + const submitMessagePusher = async () => { + if (originInputs['MessagePusherAddress'] !== inputs.MessagePusherAddress) { + await updateOption( + 'MessagePusherAddress', + removeTrailingSlash(inputs.MessagePusherAddress) + ); + } + if ( + originInputs['MessagePusherToken'] !== inputs.MessagePusherToken && + inputs.MessagePusherToken !== '' + ) { + await updateOption('MessagePusherToken', inputs.MessagePusherToken); + } + }; + + const submitGitHubOAuth = async () => { + if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { + await updateOption('GitHubClientId', inputs.GitHubClientId); + } + if ( + originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret 
&& + inputs.GitHubClientSecret !== '' + ) { + await updateOption('GitHubClientSecret', inputs.GitHubClientSecret); + } + }; + + const submitTurnstile = async () => { + if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { + await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); + } + if ( + originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey && + inputs.TurnstileSecretKey !== '' + ) { + await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey); + } + }; + + const submitNewRestrictedDomain = () => { + const localDomainList = inputs.EmailDomainWhitelist; + if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { + setRestrictedDomainInput(''); + setInputs({ + ...inputs, + EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], + }); + setEmailDomainWhitelist([...EmailDomainWhitelist, { + key: restrictedDomainInput, + text: restrictedDomainInput, + value: restrictedDomainInput, + }]); + } + } + + return ( + + +
+
通用设置
+ + + + + 更新服务器地址 + + +
配置登录注册
+ + + { + showPasswordWarningModal && + setShowPasswordWarningModal(false)} + size={'tiny'} + style={{ maxWidth: '450px' }} + > + 警告 + +

取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?

+
+ + + + +
+ } + + + + +
+ + + + + +
+ 配置邮箱域名白名单 + 用以防止恶意用户利用临时邮箱批量注册 +
+ + + + + + { + submitNewRestrictedDomain(); + }}>填入 + } + onKeyDown={(e) => { + if (e.key === 'Enter') { + submitNewRestrictedDomain(); + } + }} + autoComplete='new-password' + placeholder='输入新的允许的邮箱域名' + value={restrictedDomainInput} + onChange={(e, { value }) => { + setRestrictedDomainInput(value); + }} + /> + + 保存邮箱域名白名单设置 + +
+ 配置 SMTP + 用以支持系统的邮件发送 +
+ + + + + + + + + + 保存 SMTP 设置 + +
+ 配置 GitHub OAuth App + + 用以支持通过 GitHub 进行登录注册, + + 点击此处 + + 管理你的 GitHub OAuth App + +
+ + Homepage URL 填 {inputs.ServerAddress} + ,Authorization callback URL 填{' '} + {`${inputs.ServerAddress}/oauth/github`} + + + + + + + 保存 GitHub OAuth 设置 + + +
+ 配置 WeChat Server + + 用以支持通过微信进行登录注册, + + 点击此处 + + 了解 WeChat Server + +
+ + + + + + + 保存 WeChat Server 设置 + + +
+ 配置 Message Pusher + + 用以推送报警信息, + + 点击此处 + + 了解 Message Pusher + +
+ + + + + + 保存 Message Pusher 设置 + + +
+ 配置 Turnstile + + 用以支持用户校验, + + 点击此处 + + 管理你的 Turnstile Sites,推荐选择 Invisible Widget Type + +
+ + + + + + 保存 Turnstile 设置 + + +
+
+ ); +}; + +export default SystemSetting; diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js new file mode 100644 index 00000000..0853ddfb --- /dev/null +++ b/web/air/src/components/TokensTable.js @@ -0,0 +1,621 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; +import { Button, Dropdown, Form, Modal, Popconfirm, Popover, SplitButtonGroup, Table, Tag } from '@douyinfe/semi-ui'; + +import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; +import EditToken from '../pages/Token/EditToken'; + +const COPY_OPTIONS = [ + { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, + { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' } +]; + +const OPEN_LINK_OPTIONS = [ + { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' } +]; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status, model_limits_enabled = false) { + switch (status) { + case 1: + if (model_limits_enabled) { + return 已启用:限制模型; + } else { + return 已启用; + } + case 2: + return 已禁用 ; + case 3: + return 已过期 ; + case 4: + return 已耗尽 ; + default: + return 未知状态 ; + } +} + +const TokensTable = () => { + + const link_menu = [ + { + node: 'item', key: 'next', name: 'ChatGPT Next Web', onClick: () => { + onOpenLink('next'); + } + }, + { node: 'item', key: 'ama', name: 'AMA 问天', value: 'ama' }, + { + node: 'item', key: 'next-mj', name: 'ChatGPT Web & Midjourney', value: 'next-mj', onClick: () => { + onOpenLink('next-mj'); + } + }, + { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' } + ]; + + const columns = [ + { + title: '名称', + dataIndex: 'name' + }, + { + title: '状态', + 
dataIndex: 'status', + key: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text, record.model_limits_enabled)} +
+ ); + } + }, + { + title: '已用额度', + dataIndex: 'used_quota', + render: (text, record, index) => { + return ( +
+ {renderQuota(parseInt(text))} +
+ ); + } + }, + { + title: '剩余额度', + dataIndex: 'remain_quota', + render: (text, record, index) => { + return ( +
+ {record.unlimited_quota ? 无限制 : + {renderQuota(parseInt(text))}} +
+ ); + } + }, + { + title: '创建时间', + dataIndex: 'created_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text)} +
+ ); + } + }, + { + title: '过期时间', + dataIndex: 'expired_time', + render: (text, record, index) => { + return ( +
+ {record.expired_time === -1 ? '永不过期' : renderTimestamp(text)} +
+ ); + } + }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ + + + + + + { + onOpenLink('next', record.key); + } + }, + { + node: 'item', + key: 'next-mj', + disabled: !localStorage.getItem('chat_link2'), + name: 'ChatGPT Web & Midjourney', + onClick: () => { + onOpenLink('next-mj', record.key); + } + }, + { + node: 'item', key: 'ama', name: 'AMA 问天(BotGem)', onClick: () => { + onOpenLink('ama', record.key); + } + }, + { + node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { + onOpenLink('opencat', record.key); + } + } + ] + } + > + + + + { + manageToken(record.id, 'delete', record).then( + () => { + removeRecord(record.key); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showEdit, setShowEdit] = useState(false); + const [tokens, setTokens] = useState([]); + const [selectedKeys, setSelectedKeys] = useState([]); + const [tokenCount, setTokenCount] = useState(pageSize); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searchToken, setSearchToken] = useState(''); + const [searching, setSearching] = useState(false); + const [showTopUpModal, setShowTopUpModal] = useState(false); + const [targetTokenIdx, setTargetTokenIdx] = useState(0); + const [editingToken, setEditingToken] = useState({ + id: undefined + }); + const [orderBy, setOrderBy] = useState(''); + const [dropdownVisible, setDropdownVisible] = useState(false); + + const closeEdit = () => { + setShowEdit(false); + setTimeout(() => { + setEditingToken({ + id: undefined + }); + }, 500); + }; + + const setTokensFormat = (tokens) => { + setTokens(tokens); + if (tokens.length >= pageSize) { + setTokenCount(tokens.length + pageSize); + } else { + setTokenCount(tokens.length); + } + }; + + let pageData = tokens.slice((activePage - 1) * pageSize, activePage * pageSize); + const loadTokens = async (startIdx) => { + setLoading(true); + const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}&order=${orderBy}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setTokensFormat(data); + } else { + let newTokens = [...tokens]; + newTokens.splice(startIdx * pageSize, data.length, ...data); + setTokensFormat(newTokens); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(tokens.length / pageSize) + 1) { + // In this case we have to load more data and then append them. 
+ await loadTokens(activePage - 1, orderBy); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + await loadTokens(activePage - 1); + }; + + const onCopy = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const nextLink = localStorage.getItem('chat_link'); + const mjLink = localStorage.getItem('chat_link2'); + let nextUrl; + + if (nextLink) { + nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } else { + nextUrl = `https://app.nextchat.dev/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + + let url; + switch (type) { + case 'ama': + url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next': + url = nextUrl; + break; + default: + url = `sk-${key}`; + } + // if (await copy(url)) { + // showSuccess('已复制到剪贴板!'); + // } else { + // showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); + // setSearchKeyword(url); + // } + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制到剪贴板!'); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + const onOpenLink = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const chatLink = localStorage.getItem('chat_link'); + const mjLink = localStorage.getItem('chat_link2'); + let defaultUrl; 
+ + if (chatLink) { + defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + let url; + switch (type) { + case 'ama': + url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next-mj': + url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + break; + default: + if (!chatLink) { + showError('管理员未设置聊天链接'); + return; + } + url = defaultUrl; + } + + window.open(url, '_blank'); + }; + + useEffect(() => { + loadTokens(0, orderBy) + .then() + .catch((reason) => { + showError(reason); + }); + }, [pageSize, orderBy]); + + const removeRecord = key => { + let newDataSource = [...tokens]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.key === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setTokensFormat(newDataSource); + } + } + }; + + const manageToken = async (id, action, record) => { + setLoading(true); + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/token/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/token/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/token/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let token = res.data.data; + let newTokens = [...tokens]; + // let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + + } else { + record.status = token.status; + // newTokens[realIdx].status = token.status; + } + setTokensFormat(newTokens); + } else { + showError(message); + } + setLoading(false); + }; + + const searchTokens = async () => { + if (searchKeyword === '' && searchToken === '') { + // if keyword is blank, load files instead. 
+ await loadTokens(0); + setActivePage(1); + setOrderBy(''); + return; + } + setSearching(true); + const res = await API.get(`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`); + const { success, message, data } = res.data; + if (success) { + setTokensFormat(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const handleSearchTokenChange = async (value) => { + setSearchToken(value.trim()); + }; + + const sortToken = (key) => { + if (tokens.length === 0) return; + setLoading(true); + let sortedTokens = [...tokens]; + sortedTokens.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedTokens[0].id === tokens[0].id) { + sortedTokens.reverse(); + } + setTokens(sortedTokens); + setLoading(false); + }; + + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(tokens.length / pageSize) + 1) { + // In this case we have to load more data and then append them. + loadTokens(page - 1).then(r => { + }); + } + }; + + const rowSelection = { + onSelect: (record, selected) => { + }, + onSelectAll: (selected, selectedRows) => { + }, + onChange: (selectedRowKeys, selectedRows) => { + setSelectedKeys(selectedRows); + } + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + setDropdownVisible(false); + }; + + const renderSelectedOption = (orderBy) => { + switch (orderBy) { + case 'remain_quota': + return '按剩余额度排序'; + case 'used_quota': + return '按已用额度排序'; + default: + return '默认排序'; + } + }; + + return ( + <> + +
+ + {/* */} + + + + `第 ${page.currentStart} - ${page.currentEnd} 条,共 ${tokens.length} 条`, + onPageSizeChange: (size) => { + setPageSize(size); + setActivePage(1); + }, + onPageChange: handlePageChange + }} loading={loading} rowSelection={rowSelection} onRow={handleRow}> +
+ + + setDropdownVisible(visible)} + render={ + + handleOrderByChange('', { value: '' })}>默认排序 + handleOrderByChange('', { value: 'remain_quota' })}>按剩余额度排序 + handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序 + + } + > + + + + ); +}; + +export default TokensTable; diff --git a/web/air/src/components/UsersTable.js b/web/air/src/components/UsersTable.js new file mode 100644 index 00000000..4fc16ba5 --- /dev/null +++ b/web/air/src/components/UsersTable.js @@ -0,0 +1,376 @@ +import React, { useEffect, useState } from 'react'; +import { API, showError, showSuccess } from '../helpers'; +import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip, Dropdown } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumber, renderQuota } from '../helpers/render'; +import AddUser from '../pages/User/AddUser'; +import EditUser from '../pages/User/EditUser'; + +function renderRole(role) { + switch (role) { + case 1: + return 普通用户; + case 10: + return 管理员; + case 100: + return 超级管理员; + default: + return 未知身份; + } +} + +const UsersTable = () => { + const columns = [{ + title: 'ID', dataIndex: 'id' + }, { + title: '用户名', dataIndex: 'username' + }, { + title: '分组', dataIndex: 'group', render: (text, record, index) => { + return (
+ {renderGroup(text)} +
); + } + }, { + title: '统计信息', dataIndex: 'info', render: (text, record, index) => { + return (
+ + + {renderQuota(record.quota)} + + + {renderQuota(record.used_quota)} + + + {renderNumber(record.request_count)} + + +
); + } + }, + // { + // title: '邀请信息', dataIndex: 'invite', render: (text, record, index) => { + // return (
+ // + // + // {renderNumber(record.aff_count)} + // + // + // {renderQuota(record.aff_history_quota)} + // + // + // {record.inviter_id === 0 ? : + // {record.inviter_id}} + // + // + //
); + // } + // }, + { + title: '角色', dataIndex: 'role', render: (text, record, index) => { + return (
+ {renderRole(text)} +
); + } + }, + { + title: '状态', dataIndex: 'status', render: (text, record, index) => { + return (
+ {renderStatus(text)} +
); + } + }, + { + title: '', dataIndex: 'operate', render: (text, record, index) => (
+ <> + { + manageUser(record.username, 'promote', record); + }} + > + + + { + manageUser(record.username, 'demote', record); + }} + > + + + {record.status === 1 ? + : + } + + + { + manageUser(record.username, 'delete', record).then(() => { + removeRecord(record.id); + }); + }} + > + + +
) + }]; + + const [users, setUsers] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [userCount, setUserCount] = useState(ITEMS_PER_PAGE); + const [showAddUser, setShowAddUser] = useState(false); + const [showEditUser, setShowEditUser] = useState(false); + const [editingUser, setEditingUser] = useState({ + id: undefined + }); + const [orderBy, setOrderBy] = useState(''); + const [dropdownVisible, setDropdownVisible] = useState(false); + + const setCount = (data) => { + if (data.length >= (activePage) * ITEMS_PER_PAGE) { + setUserCount(data.length + 1); + } else { + setUserCount(data.length); + } + }; + + const removeRecord = key => { + console.log(key); + let newDataSource = [...users]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.id === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setUsers(newDataSource); + } + } + }; + + const loadUsers = async (startIdx) => { + const res = await API.get(`/api/user/?p=${startIdx}&order=${orderBy}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setUsers(data); + setCount(data); + } else { + let newUsers = users; + newUsers.push(...data); + setUsers(newUsers); + setCount(newUsers); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. 
+ await loadUsers(activePage - 1, orderBy); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadUsers(0, orderBy) + .then() + .catch((reason) => { + showError(reason); + }); + }, [orderBy]); + + const manageUser = async (username, action, record) => { + const res = await API.post('/api/user/manage', { + username, action + }); + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let user = res.data.data; + let newUsers = [...users]; + if (action === 'delete') { + + } else { + record.status = user.status; + record.role = user.role; + } + setUsers(newUsers); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return 已激活; + case 2: + return ( + 已封禁 + ); + default: + return ( + 未知状态 + ); + } + }; + + const searchUsers = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadUsers(0); + setActivePage(1); + setOrderBy(''); + return; + } + setSearching(true); + const res = await API.get(`/api/user/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setUsers(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const sortUser = (key) => { + if (users.length === 0) return; + setLoading(true); + let sortedUsers = [...users]; + sortedUsers.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedUsers[0].id === users[0].id) { + sortedUsers.reverse(); + } + setUsers(sortedUsers); + setLoading(false); + }; + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. 
+ loadUsers(page - 1).then(r => { + }); + } + }; + + const pageData = users.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const closeAddUser = () => { + setShowAddUser(false); + }; + + const closeEditUser = () => { + setShowEditUser(false); + setEditingUser({ + id: undefined + }); + }; + + const refresh = async () => { + if (searchKeyword === '') { + await loadUsers(activePage - 1); + } else { + await searchUsers(); + } + }; + + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + setDropdownVisible(false); + }; + + const renderSelectedOption = (orderBy) => { + switch (orderBy) { + case 'quota': + return '按剩余额度排序'; + case 'used_quota': + return '按已用额度排序'; + case 'request_count': + return '按请求次数排序'; + default: + return '默认排序'; + } + }; + + return ( + <> + + +
+ handleKeywordChange(value)} + /> + + + + + setDropdownVisible(visible)} + render={ + + handleOrderByChange('', { value: '' })}>默认排序 + handleOrderByChange('', { value: 'quota' })}>按剩余额度排序 + handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序 + handleOrderByChange('', { value: 'request_count' })}>按请求次数排序 + + } + > + + + + ); +}; + +export default UsersTable; diff --git a/web/air/src/components/WeChatIcon.js b/web/air/src/components/WeChatIcon.js new file mode 100644 index 00000000..22210d95 --- /dev/null +++ b/web/air/src/components/WeChatIcon.js @@ -0,0 +1,24 @@ +import React from 'react'; +import { Icon } from '@douyinfe/semi-ui'; + +const WeChatIcon = () => { + function CustomIcon() { + return + + + ; + } + + return ( +
+ } /> +
+ ); +}; + +export default WeChatIcon; diff --git a/web/air/src/components/utils.js b/web/air/src/components/utils.js new file mode 100644 index 00000000..5363ba5e --- /dev/null +++ b/web/air/src/components/utils.js @@ -0,0 +1,20 @@ +import { API, showError } from '../helpers'; + +export async function getOAuthState() { + const res = await API.get('/api/oauth/state'); + const { success, message, data } = res.data; + if (success) { + return data; + } else { + showError(message); + return ''; + } +} + +export async function onGitHubOAuthClicked(github_client_id) { + const state = await getOAuthState(); + if (!state) return; + window.open( + `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` + ); +} \ No newline at end of file diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js new file mode 100644 index 00000000..4bf035f9 --- /dev/null +++ b/web/air/src/constants/channel.constants.js @@ -0,0 +1,37 @@ +export const CHANNEL_OPTIONS = [ + { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, + { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, + { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, + { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, + { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, + { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, + { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, + { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, + { key: 26, text: '百川大模型', value: 26, color: 'orange' }, + { key: 27, text: 'MiniMax', value: 27, color: 'red' }, + { key: 29, text: 'Groq', value: 29, color: 
'orange' }, + { key: 30, text: 'Ollama', value: 30, color: 'black' }, + { key: 31, text: '零一万物', value: 31, color: 'green' }, + { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, + { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, + { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, + { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, + { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, + { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, + { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, + { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, + { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, + { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, + { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, + { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } +]; + +for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { + CHANNEL_OPTIONS[i].label = CHANNEL_OPTIONS[i].text; +} \ No newline at end of file diff --git a/web/air/src/constants/common.constant.js b/web/air/src/constants/common.constant.js new file mode 100644 index 00000000..1a37d5f6 --- /dev/null +++ b/web/air/src/constants/common.constant.js @@ -0,0 +1 @@ +export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! 
diff --git a/web/air/src/constants/index.js b/web/air/src/constants/index.js new file mode 100644 index 00000000..e83152bc --- /dev/null +++ b/web/air/src/constants/index.js @@ -0,0 +1,4 @@ +export * from './toast.constants'; +export * from './user.constants'; +export * from './common.constant'; +export * from './channel.constants'; \ No newline at end of file diff --git a/web/air/src/constants/toast.constants.js b/web/air/src/constants/toast.constants.js new file mode 100644 index 00000000..50684722 --- /dev/null +++ b/web/air/src/constants/toast.constants.js @@ -0,0 +1,7 @@ +export const toastConstants = { + SUCCESS_TIMEOUT: 1500, + INFO_TIMEOUT: 3000, + ERROR_TIMEOUT: 5000, + WARNING_TIMEOUT: 10000, + NOTICE_TIMEOUT: 20000 +}; diff --git a/web/air/src/constants/user.constants.js b/web/air/src/constants/user.constants.js new file mode 100644 index 00000000..2680d8ef --- /dev/null +++ b/web/air/src/constants/user.constants.js @@ -0,0 +1,19 @@ +export const userConstants = { + REGISTER_REQUEST: 'USERS_REGISTER_REQUEST', + REGISTER_SUCCESS: 'USERS_REGISTER_SUCCESS', + REGISTER_FAILURE: 'USERS_REGISTER_FAILURE', + + LOGIN_REQUEST: 'USERS_LOGIN_REQUEST', + LOGIN_SUCCESS: 'USERS_LOGIN_SUCCESS', + LOGIN_FAILURE: 'USERS_LOGIN_FAILURE', + + LOGOUT: 'USERS_LOGOUT', + + GETALL_REQUEST: 'USERS_GETALL_REQUEST', + GETALL_SUCCESS: 'USERS_GETALL_SUCCESS', + GETALL_FAILURE: 'USERS_GETALL_FAILURE', + + DELETE_REQUEST: 'USERS_DELETE_REQUEST', + DELETE_SUCCESS: 'USERS_DELETE_SUCCESS', + DELETE_FAILURE: 'USERS_DELETE_FAILURE' +}; diff --git a/web/air/src/context/Status/index.js b/web/air/src/context/Status/index.js new file mode 100644 index 00000000..71f0682b --- /dev/null +++ b/web/air/src/context/Status/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from 'react'; +import { initialState, reducer } from './reducer'; + +export const StatusContext = React.createContext({ + state: initialState, + dispatch: () => null, +}); + +export const StatusProvider = ({ 
children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState); + + return ( + + {children} + + ); +}; \ No newline at end of file diff --git a/web/air/src/context/Status/reducer.js b/web/air/src/context/Status/reducer.js new file mode 100644 index 00000000..ec9ac6ae --- /dev/null +++ b/web/air/src/context/Status/reducer.js @@ -0,0 +1,20 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'set': + return { + ...state, + status: action.payload, + }; + case 'unset': + return { + ...state, + status: undefined, + }; + default: + return state; + } +}; + +export const initialState = { + status: undefined, +}; diff --git a/web/air/src/context/User/index.js b/web/air/src/context/User/index.js new file mode 100644 index 00000000..c6671591 --- /dev/null +++ b/web/air/src/context/User/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from "react" +import { reducer, initialState } from "./reducer" + +export const UserContext = React.createContext({ + state: initialState, + dispatch: () => null +}) + +export const UserProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState) + + return ( + + { children } + + ) +} \ No newline at end of file diff --git a/web/air/src/context/User/reducer.js b/web/air/src/context/User/reducer.js new file mode 100644 index 00000000..9ed1d809 --- /dev/null +++ b/web/air/src/context/User/reducer.js @@ -0,0 +1,21 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'login': + return { + ...state, + user: action.payload + }; + case 'logout': + return { + ...state, + user: undefined + }; + + default: + return state; + } +}; + +export const initialState = { + user: undefined +}; \ No newline at end of file diff --git a/web/air/src/helpers/api.js b/web/air/src/helpers/api.js new file mode 100644 index 00000000..35fdb1e9 --- /dev/null +++ b/web/air/src/helpers/api.js @@ -0,0 +1,13 @@ +import { showError } from './utils'; 
+import axios from 'axios'; + +export const API = axios.create({ + baseURL: process.env.REACT_APP_SERVER ? process.env.REACT_APP_SERVER : '', +}); + +API.interceptors.response.use( + (response) => response, + (error) => { + showError(error); + } +); diff --git a/web/air/src/helpers/auth-header.js b/web/air/src/helpers/auth-header.js new file mode 100644 index 00000000..a8fe5f5a --- /dev/null +++ b/web/air/src/helpers/auth-header.js @@ -0,0 +1,10 @@ +export function authHeader() { + // return authorization header with jwt token + let user = JSON.parse(localStorage.getItem('user')); + + if (user && user.token) { + return { 'Authorization': 'Bearer ' + user.token }; + } else { + return {}; + } +} \ No newline at end of file diff --git a/web/air/src/helpers/history.js b/web/air/src/helpers/history.js new file mode 100644 index 00000000..629039e5 --- /dev/null +++ b/web/air/src/helpers/history.js @@ -0,0 +1,3 @@ +import { createBrowserHistory } from 'history'; + +export const history = createBrowserHistory(); \ No newline at end of file diff --git a/web/air/src/helpers/index.js b/web/air/src/helpers/index.js new file mode 100644 index 00000000..505a8cf9 --- /dev/null +++ b/web/air/src/helpers/index.js @@ -0,0 +1,4 @@ +export * from './history'; +export * from './auth-header'; +export * from './utils'; +export * from './api'; \ No newline at end of file diff --git a/web/air/src/helpers/render.js b/web/air/src/helpers/render.js new file mode 100644 index 00000000..62fb0dcd --- /dev/null +++ b/web/air/src/helpers/render.js @@ -0,0 +1,170 @@ +import {Label} from 'semantic-ui-react'; +import {Tag} from "@douyinfe/semi-ui"; + +export function renderText(text, limit) { + if (text.length > limit) { + return text.slice(0, limit - 3) + '...'; + } + return text; +} + +export function renderGroup(group) { + if (group === '') { + return default; + } + let groups = group.split(','); + groups.sort(); + return <> + {groups.map((group) => { + if (group === 'vip' || group === 'pro') { + 
return {group}; + } else if (group === 'svip' || group === 'premium') { + return {group}; + } + if (group === 'default') { + return {group}; + } else { + return {group}; + } + })} + ; +} + +export function renderNumber(num) { + if (num >= 1000000000) { + return (num / 1000000000).toFixed(1) + 'B'; + } else if (num >= 1000000) { + return (num / 1000000).toFixed(1) + 'M'; + } else if (num >= 10000) { + return (num / 1000).toFixed(1) + 'k'; + } else { + return num; + } +} + +export function renderQuotaNumberWithDigit(num, digits = 2) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + num = num.toFixed(digits); + if (displayInCurrency) { + return '$' + num; + } + return num; +} + +export function renderNumberWithPoint(num) { + num = num.toFixed(2); + if (num >= 100000) { + // Convert number to string to manipulate it + let numStr = num.toString(); + // Find the position of the decimal point + let decimalPointIndex = numStr.indexOf('.'); + + let wholePart = numStr; + let decimalPart = ''; + + // If there is a decimal point, split the number into whole and decimal parts + if (decimalPointIndex !== -1) { + wholePart = numStr.slice(0, decimalPointIndex); + decimalPart = numStr.slice(decimalPointIndex); + } + + // Take the first two and last two digits of the whole number part + let shortenedWholePart = wholePart.slice(0, 2) + '..' 
+ wholePart.slice(-2); + + // Return the formatted number + return shortenedWholePart + decimalPart; + } + + // If the number is less than 100,000, return it unmodified + return num; +} + +export function getQuotaPerUnit() { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + return quotaPerUnit; +} + +export function getQuotaWithUnit(quota, digits = 6) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + return (quota / quotaPerUnit).toFixed(digits); +} + +export function renderQuota(quota, digits = 2) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + let displayInCurrency = localStorage.getItem('display_in_currency'); + quotaPerUnit = parseFloat(quotaPerUnit); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return '$' + (quota / quotaPerUnit).toFixed(digits); + } + return renderNumber(quota); +} + +export function renderQuotaWithPrompt(quota, digits) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return `(等价金额:${renderQuota(quota, digits)})`; + } + return ''; +} + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +] + +export const modelColorMap = { + 'dall-e': 'rgb(147,112,219)', // 深紫色 + 'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调 + 'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调 + 'midjourney': 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调 + 'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色 + 'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色 + 'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿 + 'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿 + 'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色 + 'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色 + 'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色 + 'gpt-4': 
'rgb(135,206,235)', // 天蓝色 + 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色 + 'gpt-4-0613': 'rgb(100,149,237)', // 矢车菊蓝 + 'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝 + 'gpt-4-0125-preview': 'rgb(2,177,236)', // 深天蓝 + 'gpt-4-turbo-preview': 'rgb(2,177,255)', // 深天蓝 + 'gpt-4-32k': 'rgb(104,111,238)', // 中紫色 + 'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色 + 'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色 + 'gpt-4-all': 'rgb(65,105,225)', // 皇家蓝 + 'gpt-4-gizmo-*': 'rgb(0,0,255)', // 纯蓝色 + 'gpt-4-vision-preview': 'rgb(25,25,112)', // 午夜蓝 + 'text-ada-001': 'rgb(255,192,203)', // 粉红色 + 'text-babbage-001': 'rgb(255,160,122)', // 浅珊瑚色 + 'text-curie-001': 'rgb(219,112,147)', // 苍紫罗兰色 + 'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色 + 'text-davinci-003': 'rgb(219,112,147)', // 苍紫罗兰色(与Curie相同,表示同一个系列) + 'text-davinci-edit-001': 'rgb(255,105,180)', // 热粉色 + 'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红 + 'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别) + 'text-moderation-latest': 'rgb(255,130,171)', // 强粉色 + 'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(与Babbage相同,表示同一类功能) + 'tts-1': 'rgb(255,140,0)', // 深橙色 + 'tts-1-1106': 'rgb(255,165,0)', // 橙色 + 'tts-1-hd': 'rgb(255,215,0)', // 金色 + 'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别) + 'whisper-1': 'rgb(245,245,220)' // 米色 +} + +export function stringToColor(str) { + let sum = 0; + // 对字符串中的每个字符进行操作 + for (let i = 0; i < str.length; i++) { + // 将字符的ASCII值加到sum中 + sum += str.charCodeAt(i); + } + // 使用模运算得到个位数 + let i = sum % colors.length; + return colors[i]; +} \ No newline at end of file diff --git a/web/air/src/helpers/utils.js b/web/air/src/helpers/utils.js new file mode 100644 index 00000000..580c77ce --- /dev/null +++ b/web/air/src/helpers/utils.js @@ -0,0 +1,233 @@ +import { Toast } from '@douyinfe/semi-ui'; +import { toastConstants } from '../constants'; +import React from 'react'; +import {toast} from "react-toastify"; + +const HTMLToastContent = ({ htmlContent }) => { + return
; +}; +export default HTMLToastContent; +export function isAdmin() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 10; +} + +export function isRoot() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 100; +} + +export function getSystemName() { + let system_name = localStorage.getItem('system_name'); + if (!system_name) return 'One API'; + return system_name; +} + +export function getLogo() { + let logo = localStorage.getItem('logo'); + if (!logo) return '/logo.png'; + return logo +} + +export function getFooterHTML() { + return localStorage.getItem('footer_html'); +} + +export async function copy(text) { + let okay = true; + try { + await navigator.clipboard.writeText(text); + } catch (e) { + okay = false; + console.error(e); + } + return okay; +} + +export function isMobile() { + return window.innerWidth <= 600; +} + +let showErrorOptions = { autoClose: toastConstants.ERROR_TIMEOUT }; +let showWarningOptions = { autoClose: toastConstants.WARNING_TIMEOUT }; +let showSuccessOptions = { autoClose: toastConstants.SUCCESS_TIMEOUT }; +let showInfoOptions = { autoClose: toastConstants.INFO_TIMEOUT }; +let showNoticeOptions = { autoClose: false }; + +if (isMobile()) { + showErrorOptions.position = 'top-center'; + // showErrorOptions.transition = 'flip'; + + showSuccessOptions.position = 'top-center'; + // showSuccessOptions.transition = 'flip'; + + showInfoOptions.position = 'top-center'; + // showInfoOptions.transition = 'flip'; + + showNoticeOptions.position = 'top-center'; + // showNoticeOptions.transition = 'flip'; +} + +export function showError(error) { + console.error(error); + if (error.message) { + if (error.name === 'AxiosError') { + switch (error.response.status) { + case 401: + // toast.error('错误:未登录或登录已过期,请重新登录!', showErrorOptions); + window.location.href = '/login?expired=true'; + break; + case 429: + 
Toast.error('错误:请求次数过多,请稍后再试!'); + break; + case 500: + Toast.error('错误:服务器内部错误,请联系管理员!'); + break; + case 405: + Toast.info('本站仅作演示之用,无服务端!'); + break; + default: + Toast.error('错误:' + error.message); + } + return; + } + Toast.error('错误:' + error.message); + } else { + Toast.error('错误:' + error); + } +} + +export function showWarning(message) { + Toast.warning(message); +} + +export function showSuccess(message) { + Toast.success(message); +} + +export function showInfo(message) { + Toast.info(message); +} + +export function showNotice(message, isHTML = false) { + if (isHTML) { + toast(, showNoticeOptions); + } else { + Toast.info(message); + } +} + +export function openPage(url) { + window.open(url); +} + +export function removeTrailingSlash(url) { + if (url.endsWith('/')) { + return url.slice(0, -1); + } else { + return url; + } +} + +export function timestamp2string(timestamp) { + let date = new Date(timestamp * 1000); + let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + let minute = date.getMinutes().toString(); + let second = date.getSeconds().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if (hour.length === 1) { + hour = '0' + hour; + } + if (minute.length === 1) { + minute = '0' + minute; + } + if (second.length === 1) { + second = '0' + second; + } + return ( + year + + '-' + + month + + '-' + + day + + ' ' + + hour + + ':' + + minute + + ':' + + second + ); +} + +export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') { + let date = new Date(timestamp * 1000); + // let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if 
(hour.length === 1) { + hour = '0' + hour; + } + let str = month + '-' + day + if (dataExportDefaultTime === 'hour') { + str += ' ' + hour + ":00" + } else if (dataExportDefaultTime === 'week') { + let nextWeek = new Date(timestamp * 1000 + 6 * 24 * 60 * 60 * 1000); + let nextMonth = (nextWeek.getMonth() + 1).toString(); + let nextDay = nextWeek.getDate().toString(); + if (nextMonth.length === 1) { + nextMonth = '0' + nextMonth; + } + if (nextDay.length === 1) { + nextDay = '0' + nextDay; + } + str += ' - ' + nextMonth + '-' + nextDay + } + return str; +} + +export function downloadTextAsFile(text, filename) { + let blob = new Blob([text], { type: 'text/plain;charset=utf-8' }); + let url = URL.createObjectURL(blob); + let a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); +} + +export const verifyJSON = (str) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; + +export function shouldShowPrompt(id) { + let prompt = localStorage.getItem(`prompt-${id}`); + return !prompt; + +} + +export function setPromptShown(id) { + localStorage.setItem(`prompt-${id}`, 'true'); +} \ No newline at end of file diff --git a/web/air/src/index.css b/web/air/src/index.css new file mode 100644 index 00000000..271f14e2 --- /dev/null +++ b/web/air/src/index.css @@ -0,0 +1,116 @@ +body { + margin: 0; + padding-top: 55px; + overflow-y: scroll; + font-family: Lato, 'Helvetica Neue', Arial, Helvetica, "Microsoft YaHei", sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + scrollbar-width: none; + color: var(--semi-color-text-0) !important; + background-color: var( --semi-color-bg-0) !important; + height: 100%; +} + +#root { + height: 100%; +} + +@media only screen and (max-width: 767px) { + .semi-table-tbody, .semi-table-row, .semi-table-row-cell { + display: block!important; + width: auto!important; + padding: 2px!important; + } + .semi-table-row-cell { + border-bottom: 0!important; + 
} + .semi-table-tbody>.semi-table-row { + border-bottom: 1px solid rgba(0,0,0,.1); + } + .semi-space { + /*display: block!important;*/ + display: flex; + flex-direction: row; + flex-wrap: wrap; + row-gap: 3px; + column-gap: 10px; + } +} + +.semi-table-tbody > .semi-table-row > .semi-table-row-cell { + padding: 16px 14px; +} + +.channel-table { + .semi-table-tbody > .semi-table-row > .semi-table-row-cell { + padding: 16px 8px; + } +} + +.semi-layout { + height: 100%; +} + +.tableShow { + display: revert; +} + +.tableHiddle { + display: none !important; +} + +body::-webkit-scrollbar { + display: none; +} + +code { + font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; +} + +.semi-navigation-vertical { + /*display: flex;*/ + /*flex-direction: column;*/ +} + +.semi-navigation-item { + margin-bottom: 0; +} + +.semi-navigation-vertical { + /*flex: 0 0 auto;*/ + /*display: flex;*/ + /*flex-direction: column;*/ + /*width: 100%;*/ + height: 100%; + overflow: hidden; +} + +.main-content { + padding: 4px; + height: 100%; +} + +.small-icon .icon { + font-size: 1em !important; +} + +.custom-footer { + font-size: 1.1em; +} + +@media only screen and (max-width: 600px) { + .hide-on-mobile { + display: none !important; + } +} + + +/* 隐藏浏览器默认的滚动条 */ +body { + overflow: hidden; +} + +/* 自定义滚动条样式 */ +body::-webkit-scrollbar { + width: 0; /* 隐藏滚动条的宽度 */ +} \ No newline at end of file diff --git a/web/air/src/index.js b/web/air/src/index.js new file mode 100644 index 00000000..25b1d39e --- /dev/null +++ b/web/air/src/index.js @@ -0,0 +1,54 @@ +import { initVChartSemiTheme } from '@visactor/vchart-semi-theme'; +import React from 'react'; +import ReactDOM from 'react-dom/client'; +import {BrowserRouter} from 'react-router-dom'; +import App from './App'; +import HeaderBar from './components/HeaderBar'; +import Footer from './components/Footer'; +import 'semantic-ui-css/semantic.min.css'; +import './index.css'; +import {UserProvider} from './context/User'; 
+import {ToastContainer} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import {StatusProvider} from './context/Status'; +import {Layout} from "@douyinfe/semi-ui"; +import SiderBar from "./components/SiderBar"; + +// initialization +initVChartSemiTheme({ + isWatchingThemeSwitch: true, +}); + +const root = ReactDOM.createRoot(document.getElementById('root')); +const {Sider, Content, Header} = Layout; +root.render( + + + + + + + + + +
+ +
+ + + + +
+
+
+ +
+
+
+
+
+); diff --git a/web/air/src/pages/About/index.js b/web/air/src/pages/About/index.js new file mode 100644 index 00000000..ec13f151 --- /dev/null +++ b/web/air/src/pages/About/index.js @@ -0,0 +1,58 @@ +import React, { useEffect, useState } from 'react'; +import { Header, Segment } from 'semantic-ui-react'; +import { API, showError } from '../../helpers'; +import { marked } from 'marked'; + +const About = () => { + const [about, setAbout] = useState(''); + const [aboutLoaded, setAboutLoaded] = useState(false); + + const displayAbout = async () => { + setAbout(localStorage.getItem('about') || ''); + const res = await API.get('/api/about'); + const { success, message, data } = res.data; + if (success) { + let aboutContent = data; + if (!data.startsWith('https://')) { + aboutContent = marked.parse(data); + } + setAbout(aboutContent); + localStorage.setItem('about', aboutContent); + } else { + showError(message); + setAbout('加载关于内容失败...'); + } + setAboutLoaded(true); + }; + + useEffect(() => { + displayAbout().then(); + }, []); + + return ( + <> + { + aboutLoaded && about === '' ? <> + +
关于
+

可在设置页面设置关于内容,支持 HTML & Markdown

+ 项目仓库地址: + + https://github.com/songquanpeng/one-api + +
+ : <> + { + about.startsWith('https://') ?