diff --git a/.github/workflows/docker-image-amd64-en.yml b/.github/workflows/docker-image-amd64-en.yml index af488256..31c01e80 100644 --- a/.github/workflows/docker-image-amd64-en.yml +++ b/.github/workflows/docker-image-amd64-en.yml @@ -3,7 +3,7 @@ name: Publish Docker image (amd64, English) on: push: tags: - - '*' + - 'v*.*.*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index 2079d31f..1b9983c6 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -3,7 +3,7 @@ name: Publish Docker image (amd64) on: push: tags: - - '*' + - 'v*.*.*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index 39d1a401..dc2b4b97 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -3,7 +3,7 @@ name: Publish Docker image (arm64) on: push: tags: - - '*' + - 'v*.*.*' - '!*-alpha*' workflow_dispatch: inputs: diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index 6f30a1d5..161c41e3 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -5,7 +5,7 @@ permissions: on: push: tags: - - '*' + - 'v*.*.*' - '!*-alpha*' workflow_dispatch: inputs: diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 359c2c92..94b3e47b 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -5,7 +5,7 @@ permissions: on: push: tags: - - '*' + - 'v*.*.*' - '!*-alpha*' workflow_dispatch: inputs: diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 4e99b75c..18641ae8 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -5,7 +5,7 @@ permissions: on: push: tags: - - '*' + - 'v*.*.*' - '!*-alpha*' workflow_dispatch: inputs: diff --git a/.gitignore b/.gitignore index ae288bfa..5a433f55 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ logs data /web/node_modules /.history +cmd.md diff --git a/README.en.md b/README.en.md index eec0047b..bce47353 100644 --- a/README.en.md +++ b/README.en.md @@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the + Example: `SESSION_SECRET=random_string` 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` -4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. +4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL. + + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` +5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. +6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. + Example: `SYNC_FREQUENCY=60` -6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. +7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. + Example: `NODE_TYPE=slave` -7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. +8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. + Example: `CHANNEL_UPDATE_FREQUENCY=1440` -8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. +9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. + Example: `CHANNEL_TEST_FREQUENCY=1440` -9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. +10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. + Example: `POLLING_INTERVAL=5` ### Command Line Parameters diff --git a/README.ja.md b/README.ja.md index e9149d71..c15915ec 100644 --- a/README.ja.md +++ b/README.ja.md @@ -242,17 +242,18 @@ graph LR + 例: `SESSION_SECRET=random_string` 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` -4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 +4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。 +5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 +6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 + 例: `SYNC_FREQUENCY=60` -6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 +7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 + 例: `NODE_TYPE=slave` -7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 +8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 + 例: `CHANNEL_UPDATE_FREQUENCY=1440` -8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 +9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 + 例: `CHANNEL_TEST_FREQUENCY=1440` -9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 +10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 + 例: `POLLING_INTERVAL=5` ### コマンドラインパラメータ diff --git a/README.md b/README.md index 0ba659c4..40f6e4e0 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ## 功能 1. 支持多种大模型: + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) - + [x] [Anthropic Claude 系列模型](https://anthropic.com) + + [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude) + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [Mistral 系列模型](https://mistral.ai/) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) @@ -81,13 +81,20 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Groq](https://wow.groq.com/) + [x] [Ollama](https://github.com/ollama/ollama) + [x] [零一万物](https://platform.lingyiwanwu.com/) + + [x] [阶跃星辰](https://platform.stepfun.com/) + + [x] [Coze](https://www.coze.com/) + + [x] [Cohere](https://cohere.com/) + + [x] [DeepSeek](https://www.deepseek.com/) + + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) + + [x] [DeepL](https://www.deepl.com/) + + [x] [together.ai](https://www.together.ai/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 -6. 支持**令牌管理**,设置令牌的过期时间和额度。 +6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 -8. 支持**通道管理**,批量创建通道。 +8. 支持**渠道管理**,批量创建渠道。 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 10. 支持渠道**设置模型列表**。 11. 支持**查看额度明细**。 @@ -101,10 +108,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 19. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 +20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 21. 支持 Cloudflare Turnstile 用户校验。 22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + + 支持使用飞书进行授权登录。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 @@ -349,38 +357,41 @@ graph LR + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 -4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 +4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 +5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 +6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`MEMORY_CACHE_ENABLED=true` -6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 +7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 + 例子:`SYNC_FREQUENCY=60` -7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 +8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + 例子:`NODE_TYPE=slave` -8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 +9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 - + 例子:`CHANNEL_TEST_FREQUENCY=1440` -10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 +10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 +11. 例子:`CHANNEL_TEST_FREQUENCY=1440` +12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` -11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 +13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 -12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 +14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` -13. 请求频率限制: +15. 请求频率限制: + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 -14. 编码器缓存设置: +16. 编码器缓存设置: + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 -15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 -16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 -17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 -18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 -19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 -20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 -21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +18. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +19. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 +20. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 +21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 @@ -419,7 +430,7 @@ https://openai.justsong.cn + 检查你的接口地址和 API Key 有没有填对。 + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 报错:`当前分组负载已饱和,请稍后再试` - + 上游通道 429 了。 + + 上游渠道 429 了。 7. 升级之后我的数据会丢失吗? + 如果使用 MySQL,不会。 + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 @@ -427,8 +438,8 @@ https://openai.justsong.cn + 一般情况下不需要,系统将在初始化的时候自动调整。 + 如果需要的话,我会在更新日志中说明,并给出脚本。 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? - + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 - + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 + + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 + + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 diff --git a/common/config/config.go b/common/config/config.go index a261523d..0864d844 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -4,6 +4,7 @@ import ( "github.com/songquanpeng/one-api/common/env" "os" "strconv" + "strings" "sync" "time" @@ -51,9 +52,9 @@ var EmailDomainWhitelist = []string{ "foxmail.com", } -var DebugEnabled = os.Getenv("DEBUG") == "true" -var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" -var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" +var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true" +var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true" +var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true" var LogConsumeEnabled = true @@ -66,6 +67,9 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var LarkClientId = "" +var LarkClientSecret = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" @@ -136,3 +140,7 @@ var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) + +var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") + +var GeminiVersion = env.String("GEMINI_VERSION", "v1") diff --git a/common/constants.go b/common/constants.go index 849bdce7..87221b61 100644 --- a/common/constants.go +++ b/common/constants.go @@ -4,116 +4,3 @@ import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change - -const ( - RoleGuestUser = 0 - RoleCommonUser = 1 - RoleAdminUser = 10 - RoleRootUser = 100 -) - -const ( - UserStatusEnabled = 1 // don't use 0, 0 is the default value! - UserStatusDisabled = 2 // also don't use 0 - UserStatusDeleted = 3 -) - -const ( - TokenStatusEnabled = 1 // don't use 0, 0 is the default value! - TokenStatusDisabled = 2 // also don't use 0 - TokenStatusExpired = 3 - TokenStatusExhausted = 4 -) - -const ( - RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! - RedemptionCodeStatusDisabled = 2 // also don't use 0 - RedemptionCodeStatusUsed = 3 // also don't use 0 -) - -const ( - ChannelStatusUnknown = 0 - ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! - ChannelStatusManuallyDisabled = 2 // also don't use 0 - ChannelStatusAutoDisabled = 3 -) - -const ( - ChannelTypeUnknown = iota - ChannelTypeOpenAI - ChannelTypeAPI2D - ChannelTypeAzure - ChannelTypeCloseAI - ChannelTypeOpenAISB - ChannelTypeOpenAIMax - ChannelTypeOhMyGPT - ChannelTypeCustom - ChannelTypeAILS - ChannelTypeAIProxy - ChannelTypePaLM - ChannelTypeAPI2GPT - ChannelTypeAIGC2D - ChannelTypeAnthropic - ChannelTypeBaidu - ChannelTypeZhipu - ChannelTypeAli - ChannelTypeXunfei - ChannelType360 - ChannelTypeOpenRouter - ChannelTypeAIProxyLibrary - ChannelTypeFastGPT - ChannelTypeTencent - ChannelTypeGemini - ChannelTypeMoonshot - ChannelTypeBaichuan - ChannelTypeMinimax - ChannelTypeMistral - ChannelTypeGroq - ChannelTypeOllama - ChannelTypeLingYiWanWu - - ChannelTypeDummy -) - -var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "https://generativelanguage.googleapis.com", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", // 23 - "https://generativelanguage.googleapis.com", // 24 - "https://api.moonshot.cn", // 25 - "https://api.baichuan-ai.com", // 26 - "https://api.minimax.chat", // 27 - "https://api.mistral.ai", // 28 - "https://api.groq.com/openai", // 29 - "http://localhost:11434", // 30 - "https://api.lingyiwanwu.com", // 31 -} - -const ( - ConfigKeyPrefix = "cfg_" - - ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" - ConfigKeyLibraryID = ConfigKeyPrefix + "library_id" - ConfigKeyPlugin = ConfigKeyPrefix + "plugin" -) diff --git a/common/conv/any.go b/common/conv/any.go new file mode 100644 index 00000000..467e8bb7 --- /dev/null +++ b/common/conv/any.go @@ -0,0 +1,6 @@ +package conv + +func AsString(v any) string { + str, _ := v.(string) + return str +} diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go new file mode 100644 index 00000000..6c640870 --- /dev/null +++ b/common/ctxkey/key.go @@ -0,0 +1,22 @@ +package ctxkey + +const ( + Config = "config" + Id = "id" + Username = "username" + Role = "role" + Status = "status" + Channel = "channel" + ChannelId = "channel_id" + SpecificChannelId = "specific_channel_id" + RequestModel = "request_model" + ConvertedRequest = "converted_request" + OriginalModel = "original_model" + Group = "group" + ModelMapping = "model_mapping" + ChannelName = "channel_name" + TokenId = "token_id" + TokenName = "token_name" + BaseURL = "base_url" + AvailableModels = "available_models" +) diff --git a/common/helper/helper.go b/common/helper/helper.go index db41ac74..e06dfb6e 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -2,16 +2,15 @@ package helper import ( "fmt" - "github.com/google/uuid" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/random" "html/template" "log" - "math/rand" "net" "os/exec" "runtime" "strconv" "strings" - "time" ) func OpenBrowser(url string) { @@ -79,31 +78,6 @@ func Bytes2Size(num int64) string { return numStr + " " + unit } -func Seconds2Time(num int) (time string) { - if num/31104000 > 0 { - time += strconv.Itoa(num/31104000) + " 年 " - num %= 31104000 - } - if num/2592000 > 0 { - time += strconv.Itoa(num/2592000) + " 个月 " - num %= 2592000 - } - if num/86400 > 0 { - time += strconv.Itoa(num/86400) + " 天 " - num %= 86400 - } - if num/3600 > 0 { - time += strconv.Itoa(num/3600) + " 小时 " - num %= 3600 - } - if num/60 > 0 { - time += strconv.Itoa(num/60) + " 分钟 " - num %= 60 - } - time += strconv.Itoa(num) + " 秒" - return -} - func Interface2String(inter interface{}) string { switch inter := inter.(type) { case string: @@ -128,65 +102,13 @@ func IntMax(a int, b int) int { } } -func GetUUID() string { - code := uuid.New().String() - code = strings.Replace(code, "-", "", -1) - return code -} - -const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -const keyNumbers = "0123456789" - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, 48) - for i := 0; i < 16; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - uuid_ := GetUUID() - for i := 0; i < 32; i++ { - c := uuid_[i] - if i%2 == 0 && c >= 'a' && c <= 'z' { - c = c - 'a' + 'A' - } - key[i+16] = c - } - return string(key) -} - -func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - -func GetRandomNumberString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyNumbers[rand.Intn(len(keyNumbers))] - } - return string(key) -} - -func GetTimestamp() int64 { - return time.Now().Unix() -} - -func GetTimeString() string { - now := time.Now() - return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) -} - func GenRequestID() string { - return GetTimeString() + GetRandomNumberString(8) + return GetTimeString() + random.GetRandomNumberString(8) +} + +func GetResponseID(c *gin.Context) string { + logID := c.GetString(RequestIdKey) + return fmt.Sprintf("chatcmpl-%s", logID) } func Max(a int, b int) int { diff --git a/common/helper/key.go b/common/helper/key.go new file mode 100644 index 00000000..17aee2e0 --- /dev/null +++ b/common/helper/key.go @@ -0,0 +1,5 @@ +package helper + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) diff --git a/common/helper/time.go b/common/helper/time.go new file mode 100644 index 00000000..302746db --- /dev/null +++ b/common/helper/time.go @@ -0,0 +1,15 @@ +package helper + +import ( + "fmt" + "time" +) + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} diff --git a/common/image/image.go b/common/image/image.go index de8fefd3..12f0adff 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -16,7 +16,7 @@ import ( ) // Regex to match data URL pattern -var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) +var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) func IsImageUrl(url string) (bool, error) { resp, err := http.Head(url) diff --git a/common/logger/constants.go b/common/logger/constants.go index 78d32062..49df31ec 100644 --- a/common/logger/constants.go +++ b/common/logger/constants.go @@ -1,7 +1,3 @@ package logger -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) - var LogDir string diff --git a/common/logger/logger.go b/common/logger/logger.go index 957d8a11..c3dcd89d 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -3,15 +3,16 @@ package logger import ( "context" "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "io" "log" "os" "path/filepath" "sync" "time" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" ) const ( @@ -21,28 +22,20 @@ const ( loggerError = "ERR" ) -var setupLogLock sync.Mutex -var setupLogWorking bool +var setupLogOnce sync.Once func SetupLogger() { - if LogDir != "" { - ok := setupLogLock.TryLock() - if !ok { - log.Println("setup log is already working") - return + setupLogOnce.Do(func() { + if LogDir != "" { + logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) } - defer func() { - setupLogLock.Unlock() - setupLogWorking = false - }() - logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) - fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) - } + }) } func SysLog(s string) { @@ -94,18 +87,13 @@ func logHelper(ctx context.Context, level string, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(helper.RequestIdKey) if id == nil { id = helper.GenRequestID() } now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) - if !setupLogWorking { - setupLogWorking = true - go func() { - SetupLogger() - }() - } + SetupLogger() } func FatalLog(v ...any) { diff --git a/common/network/ip.go b/common/network/ip.go new file mode 100644 index 00000000..0fbe5e6f --- /dev/null +++ b/common/network/ip.go @@ -0,0 +1,52 @@ +package network + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "net" + "strings" +) + +func splitSubnets(subnets string) []string { + res := strings.Split(subnets, ",") + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + } + return res +} + +func isValidSubnet(subnet string) error { + _, _, err := net.ParseCIDR(subnet) + if err != nil { + return fmt.Errorf("failed to parse subnet: %w", err) + } + return nil +} + +func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { + _, ipNet, err := net.ParseCIDR(subnet) + if err != nil { + logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) + return false + } + return ipNet.Contains(net.ParseIP(ip)) +} + +func IsValidSubnets(subnets string) error { + for _, subnet := range splitSubnets(subnets) { + if err := isValidSubnet(subnet); err != nil { + return err + } + } + return nil +} + +func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { + for _, subnet := range splitSubnets(subnets) { + if isIpInSubnet(ctx, ip, subnet) { + return true + } + } + return false +} diff --git a/common/network/ip_test.go b/common/network/ip_test.go new file mode 100644 index 00000000..6c593458 --- /dev/null +++ b/common/network/ip_test.go @@ -0,0 +1,19 @@ +package network + +import ( + "context" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestIsIpInSubnet(t *testing.T) { + ctx := context.Background() + ip1 := "192.168.0.5" + ip2 := "125.216.250.89" + subnet := "192.168.0.0/24" + Convey("TestIsIpInSubnet", t, func() { + So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) + So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) + }) +} diff --git a/common/random.go b/common/random.go deleted file mode 100644 index 44bd2856..00000000 --- a/common/random.go +++ /dev/null @@ -1,8 +0,0 @@ -package common - -import "math/rand" - -// RandRange returns a random number between min and max (max is not included) -func RandRange(min, max int) int { - return min + rand.Intn(max-min) -} diff --git a/common/random/main.go b/common/random/main.go new file mode 100644 index 00000000..dbb772cd --- /dev/null +++ b/common/random/main.go @@ -0,0 +1,61 @@ +package random + +import ( + "github.com/google/uuid" + "math/rand" + "strings" + "time" +) + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const keyNumbers = "0123456789" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetRandomNumberString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyNumbers[rand.Intn(len(keyNumbers))] + } + return string(key) +} + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/controller/github.go b/controller/auth/github.go similarity index 94% rename from controller/github.go rename to controller/auth/github.go index 7d7fa106..15542655 100644 --- a/controller/github.go +++ b/controller/auth/github.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "bytes" @@ -7,10 +7,10 @@ import ( "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { user.DisplayName = "GitHub User" } user.Email = githubUser.Email - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func GitHubBind(c *gin.Context) { @@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := helper.GetRandomString(12) + state := random.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/auth/lark.go b/controller/auth/lark.go new file mode 100644 index 00000000..eb06dde9 --- /dev/null +++ b/controller/auth/lark.go @@ -0,0 +1,200 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type LarkOAuthResponse struct { + AccessToken string `json:"access_token"` +} + +type LarkUser struct { + Name string `json:"name"` + OpenID string `json:"open_id"` +} + +func getLarkUserInfoByCode(code string) (*LarkUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.LarkClientId, + "client_secret": config.LarkClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse LarkOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + var larkUser LarkUser + err = json.NewDecoder(res2.Body).Decode(&larkUser) + if err != nil { + return nil, err + } + return &larkUser, nil +} + +func LarkOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + LarkBind(c) + return + } + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + err := user.FillUserByLarkId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) + if larkUser.Name != "" { + user.DisplayName = larkUser.Name + } else { + user.DisplayName = "Lark User" + } + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func LarkBind(c *gin.Context) { + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该飞书账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.LarkId = larkUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/wechat.go b/controller/auth/wechat.go similarity index 91% rename from controller/wechat.go rename to controller/auth/wechat.go index 74be5604..a561aec0 100644 --- a/controller/wechat.go +++ b/controller/auth/wechat.go @@ -1,12 +1,13 @@ -package controller +package auth import ( "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -83,8 +84,8 @@ func WeChatAuth(c *gin.Context) { if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -102,14 +103,14 @@ func WeChatAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func WeChatBind(c *gin.Context) { @@ -136,7 +137,7 @@ func WeChatBind(c *gin.Context) { }) return } - id := c.GetInt("id") + id := c.GetInt(ctxkey.Id) user := model.User{ Id: id, } diff --git a/controller/billing.go b/controller/billing.go index dd518678..0d03e4c1 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -3,6 +3,7 @@ package controller import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" relaymodel "github.com/songquanpeng/one-api/relay/model" ) @@ -14,13 +15,13 @@ func GetSubscription(c *gin.Context) { var token *model.Token var expiredTime int64 if config.DisplayTokenStatEnabled { - tokenId := c.GetInt("token_id") + tokenId := c.GetInt(ctxkey.TokenId) token, err = model.GetTokenById(tokenId) expiredTime = token.ExpiredTime remainQuota = token.RemainQuota usedQuota = token.UsedQuota } else { - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) remainQuota, err = model.GetUserQuota(userId) if err != nil { usedQuota, err = model.GetUserUsedQuota(userId) @@ -64,11 +65,11 @@ func GetUsage(c *gin.Context) { var err error var token *model.Token if config.DisplayTokenStatEnabled { - tokenId := c.GetInt("token_id") + tokenId := c.GetInt(ctxkey.TokenId) token, err = model.GetTokenById(tokenId) quota = token.UsedQuota } else { - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) quota, err = model.GetUserUsedQuota(userId) } if err != nil { diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 03c97349..b7ac61fd 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,12 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/client" "io" "net/http" "strconv" @@ -96,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := util.HTTPClient.Do(req) + res, err := client.HTTPClient.Do(req) if err != nil { return nil, err } @@ -204,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { } func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { channel.BaseURL = &baseURL } switch channel.Type { - case common.ChannelTypeOpenAI: + case channeltype.OpenAI: if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - case common.ChannelTypeAzure: + case channeltype.Azure: return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: + case channeltype.Custom: baseURL = channel.GetBaseURL() - case common.ChannelTypeCloseAI: + case channeltype.CloseAI: return updateChannelCloseAIBalance(channel) - case common.ChannelTypeOpenAISB: + case channeltype.OpenAISB: return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: + case channeltype.AIProxy: return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: + case channeltype.API2GPT: return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: + case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) default: return 0, errors.New("尚未实现") @@ -301,11 +301,11 @@ func updateAllChannelsBalance() error { return err } for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { continue } // TODO: support Azure - if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { continue } balance, err := updateChannelBalance(channel) diff --git a/controller/channel-test.go b/controller/channel-test.go index 67ac91d0..b8c41819 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,17 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/common/message" - "github.com/songquanpeng/one-api/middleware" - "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" - relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "net/http/httptest" @@ -25,6 +14,20 @@ import ( "sync" "time" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/middleware" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "github.com/gin-gonic/gin" ) @@ -53,27 +56,37 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error } c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Content-Type", "application/json") - c.Set("channel", channel.Type) - c.Set("base_url", channel.GetBaseURL()) + c.Set(ctxkey.Channel, channel.Type) + c.Set(ctxkey.BaseURL, channel.GetBaseURL()) + cfg, _ := channel.LoadConfig() + c.Set(ctxkey.Config, cfg) middleware.SetupContextForSelectedChannel(c, channel, "") - meta := util.GetRelayMeta(c) - apiType := constant.ChannelType2APIType(channel.Type) - adaptor := helper.GetAdaptor(apiType) + meta := meta.GetByContext(c) + apiType := channeltype.ToAPIType(channel.Type) + adaptor := relay.GetAdaptor(apiType) if adaptor == nil { return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } adaptor.Init(meta) - modelName := adaptor.GetModelList()[0] - if !strings.Contains(channel.Models, modelName) { + var modelName string + modelList := adaptor.GetModelList() + modelMap := channel.GetModelMapping() + if len(modelList) != 0 { + modelName = modelList[0] + } + if modelName == "" || !strings.Contains(channel.Models, modelName) { modelNames := strings.Split(channel.Models, ",") if len(modelNames) > 0 { modelName = modelNames[0] } + if modelMap != nil && modelMap[modelName] != "" { + modelName = modelMap[modelName] + } } request := buildTestRequest() request.Model = modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName - convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil } @@ -81,14 +94,15 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error if err != nil { return err, nil } + logger.SysLog(string(jsonData)) requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { return err, nil } - if resp.StatusCode != http.StatusOK { - err := util.RelayErrorHandler(resp) + if resp != nil && resp.StatusCode != http.StatusOK { + err := controller.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } usage, respErr := adaptor.DoResponse(c, resp, meta) @@ -171,7 +185,7 @@ func testChannels(notify bool, scope string) error { } go func() { for _, channel := range channels { - isChannelEnabled := channel.Status == common.ChannelStatusEnabled + isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() err, openaiErr := testChannel(channel) tok := time.Now() @@ -184,10 +198,10 @@ func testChannels(notify bool, scope string) error { _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) } } - if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) { monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } - if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { + if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) { monitor.EnableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) @@ -197,7 +211,7 @@ func testChannels(notify bool, scope string) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := message.Notify(message.ByAll, "通道测试完成", "", "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常") if err != nil { logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } diff --git a/controller/group.go b/controller/group.go index 128a3527..6f02394f 100644 --- a/controller/group.go +++ b/controller/group.go @@ -2,13 +2,13 @@ package controller import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "net/http" ) func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName := range common.GroupRatio { + for groupName := range billingratio.GroupRatio { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ diff --git a/controller/log.go b/controller/log.go index 4e582982..8cfe090f 100644 --- a/controller/log.go +++ b/controller/log.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -72,7 +73,7 @@ func GetUserLogs(c *gin.Context) { if p < 0 { p = 0 } - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) @@ -114,7 +115,7 @@ func SearchAllLogs(c *gin.Context) { func SearchUserLogs(c *gin.Context) { keyword := c.Query("keyword") - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) logs, err := model.SearchUserLogs(userId, keyword) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -153,7 +154,7 @@ func GetLogsStat(c *gin.Context) { } func GetLogsSelfStat(c *gin.Context) { - username := c.GetString("username") + username := c.GetString(ctxkey.Username) logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) diff --git a/controller/misc.go b/controller/misc.go index f27fdb12..2928b8fb 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) { "email_verification": config.EmailVerificationEnabled, "github_oauth": config.GitHubOAuthEnabled, "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, "system_name": config.SystemName, "logo": config.Logo, "footer_html": config.Footer, diff --git a/controller/model.go b/controller/model.go index 4c5476b4..dcbe709e 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,13 +3,16 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/model" + relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -39,8 +42,8 @@ type OpenAIModels struct { Parent *string `json:"parent"` } -var openAIModels []OpenAIModels -var openAIModelsMap map[string]OpenAIModels +var models []OpenAIModels +var modelsMap map[string]OpenAIModels var channelId2Models map[int][]string func init() { @@ -60,15 +63,15 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < apitype.Dummy; i++ { + if i == apitype.AIProxyLibrary { continue } - adaptor := helper.GetAdaptor(i) + adaptor := relay.GetAdaptor(i) channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { - openAIModels = append(openAIModels, OpenAIModels{ + models = append(models, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -80,12 +83,12 @@ func init() { } } for _, channelType := range openai.CompatibleChannels { - if channelType == common.ChannelTypeAzure { + if channelType == channeltype.Azure { continue } channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) for _, modelName := range channelModelList { - openAIModels = append(openAIModels, OpenAIModels{ + models = append(models, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -96,14 +99,14 @@ func init() { }) } } - openAIModelsMap = make(map[string]OpenAIModels) - for _, model := range openAIModels { - openAIModelsMap[model.Id] = model + modelsMap = make(map[string]OpenAIModels) + for _, model := range models { + modelsMap[model.Id] = model } channelId2Models = make(map[int][]string) - for i := 1; i < common.ChannelTypeDummy; i++ { - adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) - meta := &util.RelayMeta{ + for i := 1; i < channeltype.Dummy; i++ { + adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) + meta := &meta.Meta{ ChannelType: i, } adaptor.Init(meta) @@ -119,16 +122,55 @@ func DashboardListModels(c *gin.Context) { }) } -func ListModels(c *gin.Context) { +func ListAllModels(c *gin.Context) { c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": models, + }) +} + +func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString(ctxkey.AvailableModels) != "" { + availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") + } else { + userId := c.GetInt(ctxkey.Id) + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + availableOpenAIModels := make([]OpenAIModels, 0) + for _, model := range models { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } + c.JSON(200, gin.H{ + "object": "list", + "data": availableOpenAIModels, }) } func RetrieveModel(c *gin.Context) { modelId := c.Param("model") - if model, ok := openAIModelsMap[modelId]; ok { + if model, ok := modelsMap[modelId]; ok { c.JSON(200, model) } else { Error := relaymodel.Error{ @@ -142,3 +184,30 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetUserAvailableModels(c *gin.Context) { + ctx := c.Request.Context() + id := c.GetInt(ctxkey.Id) + userGroup, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + models, err := model.CacheGetGroupModels(ctx, userGroup) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) + return +} diff --git a/controller/redemption.go b/controller/redemption.go index 31c9348d..1d0ffbad 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -3,7 +3,9 @@ package controller import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -106,9 +108,9 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := helper.GetUUID() + key := random.GetUUID() cleanRedemption := model.Redemption{ - UserId: c.GetInt("id"), + UserId: c.GetInt(ctxkey.Id), Name: redemption.Name, Key: key, CreatedTime: helper.GetTimestamp(), diff --git a/controller/relay.go b/controller/relay.go index 98db0f98..eae71421 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -9,31 +9,30 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - + "github.com/songquanpeng/one-api/relay/relaymode" "io" ) // https://platform.openai.com/docs/api-reference/chat -func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { +func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) - case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: fallthrough - case constant.RelayModeAudioTranslation: + case relaymode.AudioTranslation: fallthrough - case constant.RelayModeAudioTranscription: + case relaymode.AudioTranscription: err = controller.RelayAudioHelper(c, relayMode) default: err = controller.RelayTextHelper(c) @@ -43,23 +42,23 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relaymode.GetByPath(c.Request.URL.Path) if config.DebugEnabled { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } - channelId := c.GetInt("channel_id") - bizErr := relay(c, relayMode) + channelId := c.GetInt(ctxkey.ChannelId) + bizErr := relayHelper(c, relayMode) if bizErr == nil { monitor.Emit(channelId, true) return } lastFailedChannelId := channelId - channelName := c.GetString("channel_name") - group := c.GetString("group") - originalModel := c.GetString("original_model") + channelName := c.GetString(ctxkey.ChannelName) + group := c.GetString(ctxkey.Group) + originalModel := c.GetString(ctxkey.OriginalModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) - requestId := c.GetString(logger.RequestIdKey) + requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) @@ -68,7 +67,7 @@ func Relay(c *gin.Context) { for i := retryTimes; i > 0; i-- { channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) if err != nil { - logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) + logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err) break } @@ -79,13 +78,13 @@ func Relay(c *gin.Context) { middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - bizErr = relay(c, relayMode) + bizErr = relayHelper(c, relayMode) if bizErr == nil { return } - channelId := c.GetInt("channel_id") + channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId - channelName := c.GetString("channel_name") + channelName := c.GetString(ctxkey.ChannelName) go processChannelRelayError(ctx, channelId, channelName, bizErr) } if bizErr != nil { @@ -100,7 +99,7 @@ func Relay(c *gin.Context) { } func shouldRetry(c *gin.Context, statusCode int) bool { - if _, ok := c.Get("specific_channel_id"); ok { + if _, ok := c.Get(ctxkey.SpecificChannelId); ok { return false } if statusCode == http.StatusTooManyRequests { @@ -122,7 +121,7 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors - if util.ShouldDisableChannel(&err.Error, err.StatusCode) { + if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { monitor.DisableChannel(channelId, channelName, err.Message) } else { monitor.Emit(channelId, false) diff --git a/controller/token.go b/controller/token.go index 6012c482..97a1b313 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,22 +1,28 @@ package controller import ( + "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/network" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" ) func GetAllTokens(c *gin.Context) { - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) p, _ := strconv.Atoi(c.Query("p")) if p < 0 { p = 0 } - tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) + + order := c.Query("order") + tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -49,7 +55,7 @@ func GetNameByToken(c *gin.Context) { return } func SearchTokens(c *gin.Context) { - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) keyword := c.Query("keyword") tokens, err := model.SearchUserTokens(userId, keyword) if err != nil { @@ -69,7 +75,7 @@ func SearchTokens(c *gin.Context) { func GetToken(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -94,8 +100,8 @@ func GetToken(c *gin.Context) { } func GetTokenStatus(c *gin.Context) { - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") + tokenId := c.GetInt(ctxkey.TokenId) + userId := c.GetInt(ctxkey.Id) token, err := model.GetTokenByIds(tokenId, userId) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -117,6 +123,19 @@ func GetTokenStatus(c *gin.Context) { }) } +func validateToken(c *gin.Context, token model.Token) error { + if len(token.Name) > 30 { + return fmt.Errorf("令牌名称过长") + } + if token.Subnet != nil && *token.Subnet != "" { + err := network.IsValidSubnets(*token.Subnet) + if err != nil { + return fmt.Errorf("无效的网段:%s", err.Error()) + } + } + return nil +} + func AddToken(c *gin.Context) { token := model.Token{} err := c.ShouldBindJSON(&token) @@ -127,22 +146,26 @@ func AddToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } + cleanToken := model.Token{ - UserId: c.GetInt("id"), + UserId: c.GetInt(ctxkey.Id), Name: token.Name, - Key: helper.GenerateKey(), + Key: random.GenerateKey(), CreatedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, + Subnet: token.Subnet, } err = cleanToken.Insert() if err != nil { @@ -155,13 +178,14 @@ func AddToken(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", + "data": cleanToken, }) return } func DeleteToken(c *gin.Context) { id, _ := strconv.Atoi(c.Param("id")) - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) err := model.DeleteTokenById(id, userId) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -178,7 +202,7 @@ func DeleteToken(c *gin.Context) { } func UpdateToken(c *gin.Context) { - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) statusOnly := c.Query("status_only") token := model.Token{} err := c.ShouldBindJSON(&token) @@ -189,10 +213,11 @@ func UpdateToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } @@ -204,15 +229,15 @@ func UpdateToken(c *gin.Context) { }) return } - if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if token.Status == model.TokenStatusEnabled { + if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", @@ -228,6 +253,8 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models + cleanToken.Subnet = token.Subnet } err = cleanToken.Update() if err != nil { diff --git a/controller/user.go b/controller/user.go index c11b940e..af90acf6 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,7 +5,8 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -58,11 +59,11 @@ func Login(c *gin.Context) { }) return } - setupLogin(&user, c) + SetupLogin(&user, c) } // setup session & cookies and then return user info -func setupLogin(user *model.User, c *gin.Context) { +func SetupLogin(user *model.User, c *gin.Context) { session := sessions.Default(c) session.Set("id", user.Id) session.Set("username", user.Username) @@ -184,7 +185,10 @@ func GetAllUsers(c *gin.Context) { if p < 0 { p = 0 } - users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) + + order := c.DefaultQuery("order", "") + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -192,12 +196,12 @@ func GetAllUsers(c *gin.Context) { }) return } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": users, }) - return } func SearchUsers(c *gin.Context) { @@ -235,8 +239,8 @@ func GetUser(c *gin.Context) { }) return } - myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + myRole := c.GetInt(ctxkey.Role) + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权获取同级或更高等级用户的信息", @@ -252,7 +256,7 @@ func GetUser(c *gin.Context) { } func GetUserDashboard(c *gin.Context) { - id := c.GetInt("id") + id := c.GetInt(ctxkey.Id) now := time.Now() startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() @@ -275,7 +279,7 @@ func GetUserDashboard(c *gin.Context) { } func GenerateAccessToken(c *gin.Context) { - id := c.GetInt("id") + id := c.GetInt(ctxkey.Id) user, err := model.GetUserById(id, true) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -284,7 +288,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = helper.GetUUID() + user.AccessToken = random.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -311,7 +315,7 @@ func GenerateAccessToken(c *gin.Context) { } func GetAffCode(c *gin.Context) { - id := c.GetInt("id") + id := c.GetInt(ctxkey.Id) user, err := model.GetUserById(id, true) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -321,7 +325,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = helper.GetRandomString(4) + user.AffCode = random.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -339,7 +343,7 @@ func GetAffCode(c *gin.Context) { } func GetSelf(c *gin.Context) { - id := c.GetInt("id") + id := c.GetInt(ctxkey.Id) user, err := model.GetUserById(id, false) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -384,15 +388,15 @@ func UpdateUser(c *gin.Context) { }) return } - myRole := c.GetInt("role") - if myRole <= originUser.Role && myRole != common.RoleRootUser { + myRole := c.GetInt(ctxkey.Role) + if myRole <= originUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", }) return } - if myRole <= updatedUser.Role && myRole != common.RoleRootUser { + if myRole <= updatedUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", @@ -442,7 +446,7 @@ func UpdateSelf(c *gin.Context) { } cleanUser := model.User{ - Id: c.GetInt("id"), + Id: c.GetInt(ctxkey.Id), Username: user.Username, Password: user.Password, DisplayName: user.DisplayName, @@ -506,7 +510,7 @@ func DeleteSelf(c *gin.Context) { id := c.GetInt("id") user, _ := model.GetUserById(id, false) - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不能删除超级管理员账户", @@ -608,7 +612,7 @@ func ManageUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", @@ -617,8 +621,8 @@ func ManageUser(c *gin.Context) { } switch req.Action { case "disable": - user.Status = common.UserStatusDisabled - if user.Role == common.RoleRootUser { + user.Status = model.UserStatusDisabled + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法禁用超级管理员用户", @@ -626,9 +630,9 @@ func ManageUser(c *gin.Context) { return } case "enable": - user.Status = common.UserStatusEnabled + user.Status = model.UserStatusEnabled case "delete": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法删除超级管理员用户", @@ -643,37 +647,37 @@ func ManageUser(c *gin.Context) { return } case "promote": - if myRole != common.RoleRootUser { + if myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "普通管理员用户无法提升其他用户为管理员", }) return } - if user.Role >= common.RoleAdminUser { + if user.Role >= model.RoleAdminUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是管理员", }) return } - user.Role = common.RoleAdminUser + user.Role = model.RoleAdminUser case "demote": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法降级超级管理员用户", }) return } - if user.Role == common.RoleCommonUser { + if user.Role == model.RoleCommonUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是普通用户", }) return } - user.Role = common.RoleCommonUser + user.Role = model.RoleCommonUser } if err := user.Update(false); err != nil { @@ -727,7 +731,7 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ @@ -767,3 +771,38 @@ func TopUp(c *gin.Context) { }) return } + +type adminTopUpRequest struct { + UserId int `json:"user_id"` + Quota int `json:"quota"` + Remark string `json:"remark"` +} + +func AdminTopUp(c *gin.Context) { + req := adminTopUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if req.Remark == "" { + req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) + } + model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 00000000..0b7ddf5a --- /dev/null +++ b/docs/API.md @@ -0,0 +1,53 @@ +# 使用 API 操控 & 扩展 One API +> 欢迎提交 PR 在此放上你的拓展项目。 + +例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 + +又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 + +## 鉴权 +One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: + +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c) + +之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) + +## 请求格式与响应格式 +One API 使用 JSON 格式进行请求和响应。 + +对于响应体,一般格式如下: +```json +{ + "message": "请求信息", + "success": true, + "data": {} +} +``` + +## API 列表 +> 当前 API 列表不全,请自行通过浏览器抓取前端请求 + +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 + +### 获取当前登录用户信息 +**GET** `/api/user/self` + +### 为给定用户充值额度 +**POST** `/api/topup` +```json +{ + "user_id": 1, + "quota": 100000, + "remark": "充值 100000 额度" +} +``` + +## 其他 +### 充值链接上的附加参数 +One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: +`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` + +你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 + +注意,不是所有主题都支持该功能,欢迎 PR 补齐。 \ No newline at end of file diff --git a/go.mod b/go.mod index f9ed96d3..1754ea58 100644 --- a/go.mod +++ b/go.mod @@ -1,66 +1,84 @@ module github.com/songquanpeng/one-api // +heroku goVersion go1.18 -go 1.18 +go 1.20 require ( - github.com/gin-contrib/cors v1.4.0 - github.com/gin-contrib/gzip v0.0.6 - github.com/gin-contrib/sessions v0.0.5 - github.com/gin-contrib/static v0.0.1 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 + github.com/gin-contrib/cors v1.7.1 + github.com/gin-contrib/gzip v1.0.0 + github.com/gin-contrib/sessions v1.0.0 + github.com/gin-contrib/static v1.1.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-playground/validator/v10 v10.14.0 + github.com/go-playground/validator/v10 v10.19.0 github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/google/uuid v1.3.0 - github.com/gorilla/websocket v1.5.0 - github.com/pkoukk/tiktoken-go v0.1.5 - github.com/stretchr/testify v1.8.3 - golang.org/x/crypto v0.17.0 - golang.org/x/image v0.14.0 - gorm.io/driver/mysql v1.4.3 - gorm.io/driver/postgres v1.5.2 - gorm.io/driver/sqlite v1.4.3 - gorm.io/gorm v1.25.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.1 + github.com/jinzhu/copier v0.4.0 + github.com/pkg/errors v0.9.1 + github.com/pkoukk/tiktoken-go v0.1.6 + github.com/smartystreets/goconvey v1.8.1 + github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.22.0 + golang.org/x/image v0.15.0 + gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 + gorm.io/driver/sqlite v1.5.5 + gorm.io/gorm v1.25.9 ) require ( - github.com/bytedance/sonic v1.9.1 // indirect - github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/smithy-go v1.20.2 // indirect + github.com/bytedance/sonic v1.11.5 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.3 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dlclark/regexp2 v1.10.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-sql-driver/mysql v1.6.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/gorilla/context v1.1.1 // indirect - github.com/gorilla/securecookie v1.1.1 // indirect - github.com/gorilla/sessions v1.2.1 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/gorilla/context v1.1.2 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect + github.com/gorilla/sessions v1.2.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.4 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/leodido/go-urn v1.2.4 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pelletier/go-toml/v2 v2.2.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/net v0.17.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.15.0 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + golang.org/x/arch v0.7.0 // indirect + golang.org/x/net v0.24.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 9cf056e5..b98b377a 100644 --- a/go.sum +++ b/go.sum @@ -1,210 +1,188 @@ -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= -github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k= +github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw= +github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg= +github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= -github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= -github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= -github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs= -github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= -github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= -github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE= -github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY= +github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= +github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs= +github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps= +github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk= +github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE= +github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y= +github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U= -github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs= -github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= -github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k= +github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= -github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= -github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= -github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= +github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= +github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= +github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= -github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= -github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= -github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= -github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= -github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= +github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= +github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= -github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= -golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= +golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= +golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= -gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= -gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= -gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= -gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= -gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= -gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= -gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= -gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= -gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= +gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8= +gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/i18n/en.json b/i18n/en.json index 54728e2f..b7f1bd3e 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -8,12 +8,12 @@ "确认删除": "Confirm Delete", "确认绑定": "Confirm Binding", "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", - "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", - "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", + "\"渠道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", + "渠道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", "测试已在运行中": "Test is already running", "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", - "通道测试完成": "Channel test completed", - "通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", + "渠道测试完成": "Channel test completed", + "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", @@ -119,11 +119,11 @@ " 个月 ": " M ", " 年 ": " y ", "未测试": "Not tested", - "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", - "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", - "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", - "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", - "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", + "渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", + "已成功开始测试所有渠道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", + "已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", + "渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", + "已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!", "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", "名称": "Name", "分组": "Group", @@ -141,9 +141,9 @@ "启用": "Enable", "编辑": "Edit", "添加新的渠道": "Add a new channel", - "测试所有通道": "Test all channels", - "测试所有已启用通道": "Test all enabled channels", - "更新所有已启用通道余额": "Update the balance of all enabled channels", + "测试所有渠道": "Test all channels", + "测试所有已启用渠道": "Test all enabled channels", + "更新所有已启用渠道余额": "Update the balance of all enabled channels", "刷新": "Refresh", "处理中...": "Processing...", "绑定成功!": "Binding succeeded!", @@ -207,11 +207,11 @@ "监控设置": "Monitoring Settings", "最长响应时间": "Longest Response Time", "单位秒": "Unit in seconds", - "当运行通道全部测试时": "When all operating channels are tested", - "超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded", + "当运行渠道全部测试时": "When all operating channels are tested", + "超过此时间将自动禁用渠道": "Channels will be automatically disabled if this time is exceeded", "额度提醒阈值": "Quota reminder threshold", "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", - "失败时自动禁用通道": "Automatically disable the channel when it fails", + "失败时自动禁用渠道": "Automatically disable the channel when it fails", "保存监控设置": "Save Monitoring Settings", "额度设置": "Quota Settings", "新用户初始额度": "Initial quota for new users", @@ -405,7 +405,7 @@ "镜像": "Mirror", "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", "模型": "Model", - "请选择该通道所支持的模型": "Please select the model supported by the channel", + "请选择该渠道所支持的模型": "Please select the model supported by the channel", "填入基础模型": "Fill in the basic model", "填入所有模型": "Fill in all models", "清除所有模型": "Clear all models", @@ -515,7 +515,7 @@ "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", "Homepage URL 填": "Fill in the Homepage URL", "Authorization callback URL 填": "Fill in the Authorization callback URL", - "请为通道命名": "Please name the channel", + "请为渠道命名": "Please name the channel", "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", "模型重定向": "Model redirection", "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", diff --git a/main.go b/main.go index b20c6daf..bdcdcd61 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/router" "os" "strconv" @@ -71,7 +71,7 @@ func main() { } if config.MemoryCacheEnabled { logger.SysLog("memory cache enabled") - logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) + logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) model.InitChannelCache() } if config.MemoryCacheEnabled { diff --git a/middleware/auth.go b/middleware/auth.go index 30997efd..5cba490a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,10 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -43,7 +45,7 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { + if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", @@ -70,24 +72,25 @@ func authHelper(c *gin.Context, minRole int) { func UserAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleCommonUser) + authHelper(c, model.RoleCommonUser) } } func AdminAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleAdminUser) + authHelper(c, model.RoleAdminUser) } } func RootAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleRootUser) + authHelper(c, model.RoleRootUser) } } func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { + ctx := c.Request.Context() key := c.Request.Header.Get("Authorization") key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(key, "sk-") @@ -98,6 +101,12 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } + if token.Subnet != nil && *token.Subnet != "" { + if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) + return + } + } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { abortWithMessage(c, http.StatusInternalServerError, err.Error()) @@ -107,12 +116,25 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } - c.Set("id", token.UserId) - c.Set("token_id", token.Id) - c.Set("token_name", token.Name) + requestModel, err := getRequestModel(c) + if err != nil && shouldCheckModel(c) { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set(ctxkey.RequestModel, requestModel) + if token.Models != nil && *token.Models != "" { + c.Set(ctxkey.AvailableModels, *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } + c.Set(ctxkey.Id, token.UserId) + c.Set(ctxkey.TokenId, token.Id) + c.Set(ctxkey.TokenName, token.Name) if len(parts) > 1 { if model.IsAdmin(token.UserId) { - c.Set("specific_channel_id", parts[1]) + c.Set(ctxkey.SpecificChannelId, parts[1]) } else { abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return @@ -121,3 +143,19 @@ func TokenAuth() func(c *gin.Context) { c.Next() } } + +func shouldCheckModel(c *gin.Context) bool { + if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + return true + } + return false +} diff --git a/middleware/distributor.go b/middleware/distributor.go index f57d3cfc..d0fd7ba5 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,14 +2,13 @@ package middleware import ( "fmt" - "github.com/songquanpeng/one-api/common" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" "net/http" "strconv" - "strings" - - "github.com/gin-gonic/gin" ) type ModelRequest struct { @@ -18,12 +17,12 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) userGroup, _ := model.CacheGetUserGroup(userId) - c.Set("group", userGroup) + c.Set(ctxkey.Group, userGroup) var requestModel string var channel *model.Channel - channelId, ok := c.Get("specific_channel_id") + channelId, ok := c.Get(ctxkey.SpecificChannelId) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -35,47 +34,16 @@ func Distribute() func(c *gin.Context) { abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { - // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + requestModel = c.GetString(ctxkey.RequestModel) + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求") - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" - } - } - - if strings.HasPrefix(modelRequest.Model, "gpt-4-gizmo") { - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, "gpt-4-gizmo",false) - } else { - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model,false) - } - - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) if channel != nil { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" @@ -90,28 +58,36 @@ func Distribute() func(c *gin.Context) { } func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.GetModelMapping()) - c.Set("original_model", modelName) // for retry + c.Set(ctxkey.Channel, channel.Type) + c.Set(ctxkey.ChannelId, channel.Id) + c.Set(ctxkey.ChannelName, channel.Name) + c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) + c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) + c.Set(ctxkey.BaseURL, channel.GetBaseURL()) + cfg, _ := channel.LoadConfig() // this is for backward compatibility switch channel.Type { - case common.ChannelTypeAzure: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeXunfei: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeGemini: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeAIProxyLibrary: - c.Set(common.ConfigKeyLibraryID, channel.Other) - case common.ChannelTypeAli: - c.Set(common.ConfigKeyPlugin, channel.Other) - } - cfg, _ := channel.LoadConfig() - for k, v := range cfg { - c.Set(common.ConfigKeyPrefix+k, v) + case channeltype.Azure: + if cfg.APIVersion == "" { + cfg.APIVersion = channel.Other + } + case channeltype.Xunfei: + if cfg.APIVersion == "" { + cfg.APIVersion = channel.Other + } + case channeltype.Gemini: + if cfg.APIVersion == "" { + cfg.APIVersion = channel.Other + } + case channeltype.AIProxyLibrary: + if cfg.LibraryID == "" { + cfg.LibraryID = channel.Other + } + case channeltype.Ali: + if cfg.Plugin == "" { + cfg.Plugin = channel.Other + } } + c.Set(ctxkey.Config, cfg) } diff --git a/middleware/logger.go b/middleware/logger.go index 6aae4f23..191364f8 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,14 +3,14 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/helper" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[logger.RequestIdKey].(string) + requestID = param.Keys[helper.RequestIdKey].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/request-id.go b/middleware/request-id.go index a4c49ddb..bef09e32 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -4,16 +4,15 @@ import ( "context" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/logger" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := helper.GenRequestID() - c.Set(logger.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) + c.Set(helper.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) - c.Header(logger.RequestIdKey, id) + c.Header(helper.RequestIdKey, id) c.Next() } } diff --git a/middleware/utils.go b/middleware/utils.go index bc14c367..4d2f8092 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,18 +1,60 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)), "type": "one_api_error", }, }) c.Abort() logger.Error(c.Request.Context(), message) } + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/ability.go b/model/ability.go index 7127abc3..2db72518 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,7 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" + "gorm.io/gorm" + "sort" "strings" ) @@ -13,7 +16,7 @@ type Ability struct { Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { ability := Ability{} groupCol := "`group`" trueVal := "1" @@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { } var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) - channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + var channelQuery *gorm.DB + if ignoreFirstPriority { + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + } else { + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + } if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { @@ -49,7 +57,7 @@ func (channel *Channel) AddAbilities() error { Group: group, Model: model, ChannelId: channel.Id, - Enabled: channel.Status == common.ChannelStatusEnabled, + Enabled: channel.Status == ChannelStatusEnabled, Priority: channel.Priority, } abilities = append(abilities, ability) @@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/cache.go b/model/cache.go index dd20d857..cfb0f8a4 100644 --- a/model/cache.go +++ b/model/cache.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "math/rand" "sort" "strconv" @@ -21,6 +22,7 @@ var ( UserId2GroupCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -146,13 +148,32 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex func InitChannelCache() { newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) for _, channel := range channels { newChannelId2channel[channel.Id] = channel } @@ -205,7 +226,7 @@ func SyncChannelCache(frequency int) { func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { if !config.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority) } channelSyncLock.RLock() defer channelSyncLock.RUnlock() @@ -227,7 +248,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior idx := rand.Intn(endIdx) if ignoreFirstPriority { if endIdx < len(channels) { // which means there are more than one priority - idx = common.RandRange(endIdx, len(channels)) + idx = random.RandRange(endIdx, len(channels)) } } return channels[idx], nil diff --git a/model/channel.go b/model/channel.go index fc4905b1..ec52683e 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,13 +3,19 @@ package model import ( "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" ) +const ( + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! + ChannelStatusManuallyDisabled = 2 // also don't use 0 + ChannelStatusAutoDisabled = 3 +) + type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` @@ -32,6 +38,16 @@ type Channel struct { Config string `json:"config"` } +type ChannelConfig struct { + Region string `json:"region,omitempty"` + SK string `json:"sk,omitempty"` + AK string `json:"ak,omitempty"` + UserID string `json:"user_id,omitempty"` + APIVersion string `json:"api_version,omitempty"` + LibraryID string `json:"library_id,omitempty"` + Plugin string `json:"plugin,omitempty"` +} + func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { var channels []*Channel var err error @@ -39,7 +55,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { case "all": err = DB.Order("id desc").Find(&channels).Error case "disabled": - err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error + err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error default: err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } @@ -155,20 +171,20 @@ func (channel *Channel) Delete() error { return err } -func (channel *Channel) LoadConfig() (map[string]string, error) { +func (channel *Channel) LoadConfig() (ChannelConfig, error) { + var cfg ChannelConfig if channel.Config == "" { - return nil, nil + return cfg, nil } - cfg := make(map[string]string) err := json.Unmarshal([]byte(channel.Config), &cfg) if err != nil { - return nil, err + return cfg, err } return cfg, nil } func UpdateChannelStatusById(id int, status int) { - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) if err != nil { logger.SysError("failed to update ability status: " + err.Error()) } @@ -199,6 +215,6 @@ func DeleteChannelByStatus(status int64) (int64, error) { } func DeleteDisabledChannel() (int64, error) { - result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } diff --git a/model/log.go b/model/log.go index 9bde778f..b2fd0ea2 100644 --- a/model/log.go +++ b/model/log.go @@ -7,7 +7,6 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "gorm.io/gorm" ) @@ -51,6 +50,21 @@ func RecordLog(userId int, logType int, content string) { } } +func RecordTopupLog(userId int, content string, quota int) { + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: LogTypeTopup, + Content: content, + Quota: quota, + } + err := LOG_DB.Create(log).Error + if err != nil { + logger.SysError("failed to record log: " + err.Error()) + } +} + func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { diff --git a/model/main.go b/model/main.go index ca7a35b2..4b5323c4 100644 --- a/model/main.go +++ b/model/main.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -23,7 +24,7 @@ func CreateRootAccountIfNeed() error { var user User //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -31,13 +32,29 @@ func CreateRootAccountIfNeed() error { rootUser := User{ Username: "root", Password: hashedPassword, - Role: common.RoleRootUser, - Status: common.UserStatusEnabled, + Role: RoleRootUser, + Status: UserStatusEnabled, DisplayName: "Root User", - AccessToken: helper.GetUUID(), - Quota: 100000000, + AccessToken: random.GetUUID(), + Quota: 500000000000000, } DB.Create(&rootUser) + if config.InitialRootToken != "" { + logger.SysLog("creating initial root token as requested") + token := Token{ + Id: 1, + UserId: rootUser.Id, + Key: config.InitialRootToken, + Status: TokenStatusEnabled, + Name: "Initial Root Token", + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), + ExpiredTime: -1, + RemainQuota: 500000000000000, + UnlimitedQuota: true, + } + DB.Create(&token) + } } return nil } diff --git a/model/option.go b/model/option.go index 1d1c28b4..bed8d4c3 100644 --- a/model/option.go +++ b/model/option.go @@ -1,9 +1,9 @@ package model import ( - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "strconv" "strings" "time" @@ -66,9 +66,9 @@ func InitOptionMap() { config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) - config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() + config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() + config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() config.OptionMap["TopUpLink"] = config.TopUpLink config.OptionMap["ChatLink"] = config.ChatLink config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) @@ -82,7 +82,7 @@ func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { if option.Key == "ModelRatio" { - option.Value = common.AddNewMissingRatio(option.Value) + option.Value = billingratio.AddNewMissingRatio(option.Value) } err := updateOptionMap(option.Key, option.Value) if err != nil { @@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { config.GitHubClientId = value case "GitHubClientSecret": config.GitHubClientSecret = value + case "LarkClientId": + config.LarkClientId = value + case "LarkClientSecret": + config.LarkClientSecret = value case "Footer": config.Footer = value case "SystemName": @@ -205,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) { case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": - err = common.UpdateModelRatioByJSONString(value) + err = billingratio.UpdateModelRatioByJSONString(value) case "GroupRatio": - err = common.UpdateGroupRatioByJSONString(value) + err = billingratio.UpdateGroupRatioByJSONString(value) case "CompletionRatio": - err = common.UpdateCompletionRatioByJSONString(value) + err = billingratio.UpdateCompletionRatioByJSONString(value) case "TopUpLink": config.TopUpLink = value case "ChatLink": diff --git a/model/redemption.go b/model/redemption.go index e0ae68e2..45871a71 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -8,13 +8,19 @@ import ( "gorm.io/gorm" ) +const ( + RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! + RedemptionCodeStatusDisabled = 2 // also don't use 0 + RedemptionCodeStatusUsed = 3 // also don't use 0 +) + type Redemption struct { Id int `json:"id"` UserId int `json:"user_id"` Key string `json:"key" gorm:"type:char(32);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` - Quota int64 `json:"quota" gorm:"default:100"` + Quota int64 `json:"quota" gorm:"bigint;default:100"` CreatedTime int64 `json:"created_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` Count int `json:"count" gorm:"-:all"` // only for api request @@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) { if err != nil { return errors.New("无效的兑换码") } - if redemption.Status != common.RedemptionCodeStatusEnabled { + if redemption.Status != RedemptionCodeStatusEnabled { return errors.New("该兑换码已被使用") } err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error @@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) { return err } redemption.RedeemedTime = helper.GetTimestamp() - redemption.Status = common.RedemptionCodeStatusUsed + redemption.Status = RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err }) diff --git a/model/token.go b/model/token.go index 4cc6cf98..abc54d35 100644 --- a/model/token.go +++ b/model/token.go @@ -13,24 +13,44 @@ import ( "gorm.io/gorm" ) +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + type Token struct { - Id int `json:"id"` - UserId int `json:"user_id"` - Key string `json:"key" gorm:"type:char(48);uniqueIndex"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index" ` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - AccessedTime int64 `json:"accessed_time" gorm:"bigint"` - ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int64 `json:"remain_quota" gorm:"default:0"` - UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"default:''"` // allowed models + Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } -func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { +func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { var tokens []*Token var err error - err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error + query := DB.Where("user_id = ?", userId) + + switch order { + case "remain_quota": + query = query.Order("unlimited_quota desc, remain_quota desc") + case "used_quota": + query = query.Order("used_quota desc") + default: + query = query.Order("id desc") + } + + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error return tokens, err } @@ -59,17 +79,17 @@ func ValidateUserToken(key string) (token *Token, err error) { } return nil, errors.New("令牌验证失败") } - if token.Status == common.TokenStatusExhausted { - return nil, errors.New("该令牌额度已用尽") - } else if token.Status == common.TokenStatusExpired { + if token.Status == TokenStatusExhausted { + return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) + } else if token.Status == TokenStatusExpired { return nil, errors.New("该令牌已过期") } - if token.Status != common.TokenStatusEnabled { + if token.Status != TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if !common.RedisEnabled { - token.Status = common.TokenStatusExpired + token.Status = TokenStatusExpired err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -80,7 +100,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { // in this case, we can make sure the token is exhausted - token.Status = common.TokenStatusExhausted + token.Status = TokenStatusExhausted err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -120,7 +140,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error return err } diff --git a/model/user.go b/model/user.go index e325394b..1dc633b1 100644 --- a/model/user.go +++ b/model/user.go @@ -6,12 +6,25 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "gorm.io/gorm" "strings" ) +const ( + RoleGuestUser = 0 + RoleCommonUser = 1 + RoleAdminUser = 10 + RoleRootUser = 100 +) + +const ( + UserStatusEnabled = 1 // don't use 0, 0 is the default value! + UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 +) + // User if you add sensitive fields, don't forget to clean them in setupLogin function. // Otherwise, the sensitive information will be saved on local storage in plain text! type User struct { @@ -24,11 +37,12 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + LarkId string `json:"lark_id" gorm:"column:lark_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management - Quota int64 `json:"quota" gorm:"type:int;default:0"` - UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota - RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number + Quota int64 `json:"quota" gorm:"bigint;default:0"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota + RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number Group string `json:"group" gorm:"type:varchar(32);default:'default'"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` @@ -40,8 +54,21 @@ func GetMaxUserId() int { return user.Id } -func GetAllUsers(startIdx int, num int) (users []*User, err error) { - err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error +func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) + + switch order { + case "quota": + query = query.Order("quota desc") + case "used_quota": + query = query.Order("used_quota desc") + case "request_count": + query = query.Order("request_count desc") + default: + query = query.Order("id desc") + } + + err = query.Find(&users).Error return users, err } @@ -94,8 +121,8 @@ func (user *User) Insert(inviterId int) error { } } user.Quota = config.QuotaForNewUser - user.AccessToken = helper.GetUUID() - user.AffCode = helper.GetRandomString(4) + user.AccessToken = random.GetUUID() + user.AffCode = random.GetRandomString(4) result := DB.Create(user) if result.Error != nil { return result.Error @@ -124,9 +151,9 @@ func (user *User) Update(updatePassword bool) error { return err } } - if user.Status == common.UserStatusDisabled { + if user.Status == UserStatusDisabled { blacklist.BanUser(user.Id) - } else if user.Status == common.UserStatusEnabled { + } else if user.Status == UserStatusEnabled { blacklist.UnbanUser(user.Id) } err = DB.Model(user).Updates(user).Error @@ -138,8 +165,8 @@ func (user *User) Delete() error { return errors.New("id 为空!") } blacklist.BanUser(user.Id) - user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) - user.Status = common.UserStatusDeleted + user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) + user.Status = UserStatusDeleted err := DB.Model(user).Updates(user).Error return err } @@ -163,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) { } } okay := common.ValidatePasswordAndHash(password, user.Password) - if !okay || user.Status != common.UserStatusEnabled { + if !okay || user.Status != UserStatusEnabled { return errors.New("用户名或密码错误,或用户已被封禁") } return nil @@ -193,6 +220,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByLarkId() error { + if user.LarkId == "" { + return errors.New("lark id 为空!") + } + DB.Where(User{LarkId: user.LarkId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -221,6 +256,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsLarkIdAlreadyTaken(githubId string) bool { + return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } @@ -247,7 +286,7 @@ func IsAdmin(userId int) bool { logger.SysError("no such user " + err.Error()) return false } - return user.Role >= common.RoleAdminUser + return user.Role >= RoleAdminUser } func IsUserEnabled(userId int) (bool, error) { @@ -259,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) { if err != nil { return false, err } - return user.Status == common.UserStatusEnabled, nil + return user.Status == UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -332,7 +371,7 @@ func decreaseUserQuota(id int, quota int64) (err error) { } func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) return email } diff --git a/monitor/channel.go b/monitor/channel.go index 597ab11a..7e5dc58a 100644 --- a/monitor/channel.go +++ b/monitor/channel.go @@ -2,7 +2,6 @@ package monitor import ( "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/message" @@ -29,27 +28,27 @@ func notifyRootUser(subject string, content string) { // DisableChannel disable & notify func DisableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) notifyRootUser(subject, content) } func MetricDisableChannel(channelId int, successRate float64) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) - subject := fmt.Sprintf("通道 #%d 已被禁用", channelId) - content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", - config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) + subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) + content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", + channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) notifyRootUser(subject, content) } // EnableChannel enable & notify func EnableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) - subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) notifyRootUser(subject, content) } diff --git a/monitor/manage.go b/monitor/manage.go new file mode 100644 index 00000000..946e78af --- /dev/null +++ b/monitor/manage.go @@ -0,0 +1,62 @@ +package monitor + +import ( + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" + "net/http" + "strings" +) + +func ShouldDisableChannel(err *model.Error, statusCode int) bool { + if !config.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { + return true + } + //if strings.Contains(err.Message, "quota") { + // return true + //} + if strings.Contains(err.Message, "credit") { + return true + } + if strings.Contains(err.Message, "balance") { + return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *model.Error) bool { + if !config.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} diff --git a/relay/adaptor.go b/relay/adaptor.go new file mode 100644 index 00000000..794a84a6 --- /dev/null +++ b/relay/adaptor.go @@ -0,0 +1,60 @@ +package relay + +import ( + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/aiproxy" + "github.com/songquanpeng/one-api/relay/adaptor/ali" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws" + "github.com/songquanpeng/one-api/relay/adaptor/baidu" + "github.com/songquanpeng/one-api/relay/adaptor/cloudflare" + "github.com/songquanpeng/one-api/relay/adaptor/cohere" + "github.com/songquanpeng/one-api/relay/adaptor/coze" + "github.com/songquanpeng/one-api/relay/adaptor/deepl" + "github.com/songquanpeng/one-api/relay/adaptor/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/ollama" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/palm" + "github.com/songquanpeng/one-api/relay/adaptor/tencent" + "github.com/songquanpeng/one-api/relay/adaptor/xunfei" + "github.com/songquanpeng/one-api/relay/adaptor/zhipu" + "github.com/songquanpeng/one-api/relay/apitype" +) + +func GetAdaptor(apiType int) adaptor.Adaptor { + switch apiType { + case apitype.AIProxyLibrary: + return &aiproxy.Adaptor{} + case apitype.Ali: + return &ali.Adaptor{} + case apitype.Anthropic: + return &anthropic.Adaptor{} + case apitype.AwsClaude: + return &aws.Adaptor{} + case apitype.Baidu: + return &baidu.Adaptor{} + case apitype.Gemini: + return &gemini.Adaptor{} + case apitype.OpenAI: + return &openai.Adaptor{} + case apitype.PaLM: + return &palm.Adaptor{} + case apitype.Tencent: + return &tencent.Adaptor{} + case apitype.Xunfei: + return &xunfei.Adaptor{} + case apitype.Zhipu: + return &zhipu.Adaptor{} + case apitype.Ollama: + return &ollama.Adaptor{} + case apitype.Coze: + return &coze.Adaptor{} + case apitype.Cohere: + return &cohere.Adaptor{} + case apitype.Cloudflare: + return &cloudflare.Adaptor{} + case apitype.DeepL: + return &deepl.Adaptor{} + } + return nil +} diff --git a/relay/channel/ai360/constants.go b/relay/adaptor/ai360/constants.go similarity index 100% rename from relay/channel/ai360/constants.go rename to relay/adaptor/ai360/constants.go diff --git a/relay/channel/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go similarity index 54% rename from relay/channel/aiproxy/adaptor.go rename to relay/adaptor/aiproxy/adaptor.go index 2b4e3022..42d49c0a 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -4,27 +4,27 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) type Adaptor struct { + meta *meta.Meta } -func (a *Adaptor) Init(meta *util.RelayMeta) { - +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } @@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } aiProxyLibraryRequest := ConvertRequest(*request) - aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) + aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID return aiProxyLibraryRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/aiproxy/constants.go b/relay/adaptor/aiproxy/constants.go similarity index 60% rename from relay/channel/aiproxy/constants.go rename to relay/adaptor/aiproxy/constants.go index c4df51c4..818d2709 100644 --- a/relay/channel/aiproxy/constants.go +++ b/relay/adaptor/aiproxy/constants.go @@ -1,6 +1,6 @@ package aiproxy -import "github.com/songquanpeng/one-api/relay/channel/openai" +import "github.com/songquanpeng/one-api/relay/adaptor/openai" var ModelList = []string{""} diff --git a/relay/channel/aiproxy/main.go b/relay/adaptor/aiproxy/main.go similarity index 95% rename from relay/channel/aiproxy/main.go rename to relay/adaptor/aiproxy/main.go index 96972407..01a568f6 100644 --- a/relay/channel/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -8,7 +8,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -53,7 +54,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon FinishReason: "stop", } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -66,7 +67,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.FinishReason = &constant.StopFinishReason return &openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "", @@ -78,7 +79,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = response.Content return &openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: response.Model, diff --git a/relay/channel/aiproxy/model.go b/relay/adaptor/aiproxy/model.go similarity index 100% rename from relay/channel/aiproxy/model.go rename to relay/adaptor/aiproxy/model.go diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go new file mode 100644 index 00000000..4aa8a11a --- /dev/null +++ b/relay/adaptor/ali/adaptor.go @@ -0,0 +1,105 @@ +package ali + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" +) + +// https://help.aliyun.com/zh/dashscope/developer-reference/api-details + +type Adaptor struct { + meta *meta.Meta +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + fullRequestURL := "" + switch meta.Mode { + case relaymode.Embeddings: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) + case relaymode.ImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) + default: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) + } + + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + if meta.IsStream { + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-DashScope-SSE", "enable") + } + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + + if meta.Mode == relaymode.ImagesGenerations { + req.Header.Set("X-DashScope-Async", "enable") + } + if a.meta.Config.Plugin != "" { + req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin) + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + aliEmbeddingRequest := ConvertEmbeddingRequest(*request) + return aliEmbeddingRequest, nil + default: + aliRequest := ConvertRequest(*request) + return aliRequest, nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + aliRequest := ConvertImageRequest(*request) + return aliRequest, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + case relaymode.ImagesGenerations: + err, usage = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ali" +} diff --git a/relay/channel/ali/constants.go b/relay/adaptor/ali/constants.go similarity index 65% rename from relay/channel/ali/constants.go rename to relay/adaptor/ali/constants.go index 16bcfca4..3f24ce2e 100644 --- a/relay/channel/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -3,4 +3,5 @@ package ali var ModelList = []string{ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", "text-embedding-v1", + "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", } diff --git a/relay/adaptor/ali/image.go b/relay/adaptor/ali/image.go new file mode 100644 index 00000000..8261803d --- /dev/null +++ b/relay/adaptor/ali/image.go @@ -0,0 +1,192 @@ +package ali + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" + "time" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse TaskResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + logger.SysError("aliAsyncTask err: " + string(responseBody)) + return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} + +func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { + url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) + + var aliResponse TaskResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + logger.SysError("aliAsyncTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response TaskResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) { + waitSeconds := 2 + step := 0 + maxStep := 20 + + var taskResponse TaskResponse + var responseBody []byte + + for { + step++ + rsp, err, body := asyncTask(taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { + imageResponse := openai.ImageResponse{ + Created: helper.GetTimestamp(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + // 读取 data.Url 的图片数据并转存到 b64Json + imageData, err := getImageData(data.Url) + if err != nil { + // 处理获取图片数据失败的情况 + logger.SysError("getImageData Error getting image data: " + err.Error()) + continue + } + + // 将图片数据转为 Base64 编码的字符串 + b64Json = Base64Encode(imageData) + } else { + // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, openai.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func getImageData(url string) ([]byte, error) { + response, err := http.Get(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + imageData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return imageData, nil +} + +func Base64Encode(data []byte) string { + b64Json := base64.StdEncoding.EncodeToString(data) + return b64Json +} diff --git a/relay/channel/ali/main.go b/relay/adaptor/ali/main.go similarity index 89% rename from relay/channel/ali/main.go rename to relay/adaptor/ali/main.go index 62115d58..0462c26b 100644 --- a/relay/channel/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -7,7 +7,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" "io" "net/http" @@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { MaxTokens: request.MaxTokens, Temperature: request.Temperature, TopP: request.TopP, + TopK: request.TopK, + ResultFormat: "message", + Tools: request.Tools, }, } } @@ -63,6 +66,17 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque } } +func ConvertImageRequest(request model.ImageRequest) *ImageRequest { + var imageRequest ImageRequest + imageRequest.Input.Prompt = request.Prompt + imageRequest.Model = request.Model + imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) + imageRequest.Parameters.N = request.N + imageRequest.ResponseFormat = request.ResponseFormat + + return &imageRequest +} + func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) @@ -117,19 +131,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR } func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: response.Output.Text, - }, - FinishReason: response.Output.FinishReason, - } fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, + Choices: response.Output.Choices, Usage: model.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, @@ -140,10 +146,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { } func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + if len(aliResponse.Output.Choices) == 0 { + return nil + } + aliChoice := aliResponse.Output.Choices[0] var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = aliResponse.Output.Text - if aliResponse.Output.FinishReason != "null" { - finishReason := aliResponse.Output.FinishReason + choice.Delta = aliChoice.Message + if aliChoice.FinishReason != "null" { + finishReason := aliChoice.FinishReason choice.FinishReason = &finishReason } response := openai.ChatCompletionsStreamResponse{ @@ -204,6 +214,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + return true + } //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) @@ -226,6 +239,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := c.Request.Context() var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -235,6 +249,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + logger.Debugf(ctx, "response body: %s\n", responseBody) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go new file mode 100644 index 00000000..450b5f52 --- /dev/null +++ b/relay/adaptor/ali/model.go @@ -0,0 +1,154 @@ +package ali + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type Input struct { + //Prompt string `json:"prompt"` + Messages []Message `json:"messages"` +} + +type Parameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type ImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} + +type TaskResponse struct { + StatusCode int `json:"status_code,omitempty"` + RequestId string `json:"request_id,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Output struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Results []struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + } `json:"results,omitempty"` + TaskMetrics struct { + Total int `json:"TOTAL,omitempty"` + Succeeded int `json:"SUCCEEDED,omitempty"` + Failed int `json:"FAILED,omitempty"` + } `json:"task_metrics,omitempty"` + } `json:"output,omitempty"` + Usage Usage `json:"usage"` +} + +type Header struct { + Action string `json:"action,omitempty"` + Streaming string `json:"streaming,omitempty"` + TaskID string `json:"task_id,omitempty"` + Event string `json:"event,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Attributes any `json:"attributes,omitempty"` +} + +type Payload struct { + Model string `json:"model,omitempty"` + Task string `json:"task,omitempty"` + TaskGroup string `json:"task_group,omitempty"` + Function string `json:"function,omitempty"` + Parameters struct { + SampleRate int `json:"sample_rate,omitempty"` + Rate float64 `json:"rate,omitempty"` + Format string `json:"format,omitempty"` + } `json:"parameters,omitempty"` + Input struct { + Text string `json:"text,omitempty"` + } `json:"input,omitempty"` + Usage struct { + Characters int `json:"characters,omitempty"` + } `json:"usage,omitempty"` +} + +type WSSMessage struct { + Header Header `json:"header,omitempty"` + Payload Payload `json:"payload,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type EmbeddingResponse struct { + Output struct { + Embeddings []Embedding `json:"embeddings"` + } `json:"output"` + Usage Usage `json:"usage"` + Error +} + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Output struct { + //Text string `json:"text"` + //FinishReason string `json:"finish_reason"` + Choices []openai.TextResponseChoice `json:"choices"` +} + +type ChatResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + Error +} diff --git a/relay/channel/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go similarity index 62% rename from relay/channel/anthropic/adaptor.go rename to relay/adaptor/anthropic/adaptor.go index a165b35c..b1136e84 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -4,9 +4,9 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -14,16 +14,16 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-api-key", meta.APIKey) anthropicVersion := c.Request.Header.Get("anthropic-version") if anthropicVersion == "" { @@ -41,11 +41,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/anthropic/constants.go b/relay/adaptor/anthropic/constants.go similarity index 100% rename from relay/channel/anthropic/constants.go rename to relay/adaptor/anthropic/constants.go diff --git a/relay/channel/anthropic/main.go b/relay/adaptor/anthropic/main.go similarity index 94% rename from relay/channel/anthropic/main.go rename to relay/adaptor/anthropic/main.go index 3eeb0b2c..a8de185c 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -4,16 +4,17 @@ import ( "bufio" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) func stopReasonClaude2OpenAI(reason *string) string { @@ -38,6 +39,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { MaxTokens: textRequest.MaxTokens, Temperature: textRequest.Temperature, TopP: textRequest.TopP, + TopK: textRequest.TopK, Stream: textRequest.Stream, } if claudeRequest.MaxTokens == 0 { @@ -90,7 +92,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { } // https://docs.anthropic.com/claude/reference/messages-streaming -func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { +func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { var response *Response var responseText string var stopReason string @@ -128,7 +130,7 @@ func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo return &openaiResponse, response } -func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { +func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { var responseText string if len(claudeResponse.Content) > 0 { responseText = claudeResponse.Content[0].Text @@ -175,10 +177,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC if len(data) < 6 { continue } - if !strings.HasPrefix(data, "data: ") { + if !strings.HasPrefix(data, "data:") { continue } - data = strings.TrimPrefix(data, "data: ") + data = strings.TrimPrefix(data, "data:") dataChan <- data } stopChan <- true @@ -191,14 +193,14 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC select { case data := <-dataChan: // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") + data = strings.TrimSpace(data) var claudeResponse StreamResponse err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { logger.SysError("error unmarshalling stream response: " + err.Error()) return true } - response, meta := streamResponseClaude2OpenAI(&claudeResponse) + response, meta := StreamResponseClaude2OpenAI(&claudeResponse) if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens @@ -253,7 +255,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st StatusCode: resp.StatusCode, }, nil } - fullTextResponse := responseClaude2OpenAI(&claudeResponse) + fullTextResponse := ResponseClaude2OpenAI(&claudeResponse) fullTextResponse.Model = modelName usage := model.Usage{ PromptTokens: claudeResponse.Usage.InputTokens, diff --git a/relay/channel/anthropic/model.go b/relay/adaptor/anthropic/model.go similarity index 100% rename from relay/channel/anthropic/model.go rename to relay/adaptor/anthropic/model.go diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go new file mode 100644 index 00000000..7245d3d9 --- /dev/null +++ b/relay/adaptor/aws/adapter.go @@ -0,0 +1,82 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/songquanpeng/one-api/common/ctxkey" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ adaptor.Adaptor = new(Adaptor) + +type Adaptor struct { + meta *meta.Meta + awsClient *bedrockruntime.Client +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta + a.awsClient = bedrockruntime.New(bedrockruntime.Options{ + Region: meta.Config.Region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), + }) +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + claudeReq := anthropic.ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, claudeReq) + return claudeReq, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, a.awsClient) + } else { + err, usage = Handler(c, a.awsClient, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() (models []string) { + for n := range awsModelIDMap { + models = append(models, n) + } + return +} + +func (a *Adaptor) GetChannelName() string { + return "aws" +} diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go new file mode 100644 index 00000000..0776f985 --- /dev/null +++ b/relay/adaptor/aws/main.go @@ -0,0 +1,191 @@ +// Package aws provides the AWS adaptor for the relay service. +package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/ctxkey" + "io" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func wrapErr(err error) *relaymodel.ErrorWithStatusCode { + return &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: fmt.Sprintf("%s", err.Error()), + }, + } +} + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var awsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := awsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*anthropic.Request) + awsClaudeReq := &Request{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + claudeResponse := new(anthropic.Response) + err = json.Unmarshal(awsResp.Body, claudeResponse) + if err != nil { + return wrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) + openaiResp.Model = modelName + usage := relaymodel.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + createdTime := helper.GetTimestamp() + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*anthropic.Request) + + awsClaudeReq := &Request{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + var id string + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + claudeResp := new(anthropic.StreamResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } + if response == nil { + return true + } + response.Id = id + response.Model = c.GetString(ctxkey.OriginalModel) + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/model.go new file mode 100644 index 00000000..bcbfb584 --- /dev/null +++ b/relay/adaptor/aws/model.go @@ -0,0 +1,17 @@ +package aws + +import "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + +// Request is the request to AWS Claude +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +type Request struct { + // AnthropicVersion should be "bedrock-2023-05-31" + AnthropicVersion string `json:"anthropic_version"` + Messages []anthropic.Message `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} diff --git a/relay/channel/baichuan/constants.go b/relay/adaptor/baichuan/constants.go similarity index 100% rename from relay/channel/baichuan/constants.go rename to relay/adaptor/baichuan/constants.go diff --git a/relay/channel/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go similarity index 63% rename from relay/channel/baidu/adaptor.go rename to relay/adaptor/baidu/adaptor.go index 2d2e24f6..15306b95 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/adaptor/baidu/adaptor.go @@ -3,25 +3,25 @@ package baidu import ( "errors" "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t suffix := "chat/" if strings.HasPrefix(meta.ActualModelName, "Embedding") { @@ -38,16 +38,34 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { suffix += "completions_pro" case "ERNIE-Bot-4": suffix += "completions_pro" - case "ERNIE-3.5-8K": - suffix += "completions" - case "ERNIE-Bot-8K": - suffix += "ernie_bot_8k" case "ERNIE-Bot": suffix += "completions" - case "ERNIE-Speed": - suffix += "ernie_speed" case "ERNIE-Bot-turbo": suffix += "eb-instant" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-4.0-8K": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += "completions" + case "ERNIE-3.5-8K-0205": + suffix += "ernie-3.5-8k-0205" + case "ERNIE-3.5-8K-1222": + suffix += "ernie-3.5-8k-1222" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-3.5-4K-0205": + suffix += "ernie-3.5-4k-0205" + case "ERNIE-Speed-8K": + suffix += "ernie_speed" + case "ERNIE-Speed-128K": + suffix += "ernie-speed-128k" + case "ERNIE-Lite-8K-0922": + suffix += "eb-instant" + case "ERNIE-Lite-8K-0308": + suffix += "ernie-lite-8k" + case "ERNIE-Tiny-8K": + suffix += "ernie-tiny-8k" case "BLOOMZ-7B": suffix += "bloomz_7b1" case "Embedding-V1": @@ -59,7 +77,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { case "tao-8k": suffix += "tao_8k" default: - suffix += meta.ActualModelName + suffix += strings.ToLower(meta.ActualModelName) } fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) var accessToken string @@ -71,8 +89,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } @@ -82,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } switch relayMode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, nil default: @@ -91,16 +109,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: err, usage = EmbeddingHandler(c, resp) default: err, usage = Handler(c, resp) diff --git a/relay/adaptor/baidu/constants.go b/relay/adaptor/baidu/constants.go new file mode 100644 index 00000000..f952adc6 --- /dev/null +++ b/relay/adaptor/baidu/constants.go @@ -0,0 +1,20 @@ +package baidu + +var ModelList = []string{ + "ERNIE-4.0-8K", + "ERNIE-3.5-8K", + "ERNIE-3.5-8K-0205", + "ERNIE-3.5-8K-1222", + "ERNIE-Bot-8K", + "ERNIE-3.5-4K-0205", + "ERNIE-Speed-8K", + "ERNIE-Speed-128K", + "ERNIE-Lite-8K-0922", + "ERNIE-Lite-8K-0308", + "ERNIE-Tiny-8K", + "BLOOMZ-7B", + "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", +} diff --git a/relay/channel/baidu/main.go b/relay/adaptor/baidu/main.go similarity index 98% rename from relay/channel/baidu/main.go rename to relay/adaptor/baidu/main.go index 9ca9e47d..6df5ce84 100644 --- a/relay/channel/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -8,10 +8,10 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/client" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -305,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) { } req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - res, err := util.ImpatientHTTPClient.Do(req) + res, err := client.ImpatientHTTPClient.Do(req) if err != nil { return nil, err } diff --git a/relay/channel/baidu/model.go b/relay/adaptor/baidu/model.go similarity index 100% rename from relay/channel/baidu/model.go rename to relay/adaptor/baidu/model.go diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go new file mode 100644 index 00000000..6ff6b0d3 --- /dev/null +++ b/relay/adaptor/cloudflare/adaptor.go @@ -0,0 +1,66 @@ +package cloudflare + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. +func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements adaptor.Adaptor. + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "cloudflare" +} diff --git a/relay/adaptor/cloudflare/constant.go b/relay/adaptor/cloudflare/constant.go new file mode 100644 index 00000000..dee79a76 --- /dev/null +++ b/relay/adaptor/cloudflare/constant.go @@ -0,0 +1,36 @@ +package cloudflare + +var ModelList = []string{ + "@cf/meta/llama-2-7b-chat-fp16", + "@cf/meta/llama-2-7b-chat-int8", + "@cf/mistral/mistral-7b-instruct-v0.1", + "@hf/thebloke/deepseek-coder-6.7b-base-awq", + "@hf/thebloke/deepseek-coder-6.7b-instruct-awq", + "@cf/deepseek-ai/deepseek-math-7b-base", + "@cf/deepseek-ai/deepseek-math-7b-instruct", + "@cf/thebloke/discolm-german-7b-v1-awq", + "@cf/tiiuae/falcon-7b-instruct", + "@cf/google/gemma-2b-it-lora", + "@hf/google/gemma-7b-it", + "@cf/google/gemma-7b-it-lora", + "@hf/nousresearch/hermes-2-pro-mistral-7b", + "@hf/thebloke/llama-2-13b-chat-awq", + "@cf/meta-llama/llama-2-7b-chat-hf-lora", + "@cf/meta/llama-3-8b-instruct", + "@hf/thebloke/llamaguard-7b-awq", + "@hf/thebloke/mistral-7b-instruct-v0.1-awq", + "@hf/mistralai/mistral-7b-instruct-v0.2", + "@cf/mistral/mistral-7b-instruct-v0.2-lora", + "@hf/thebloke/neural-chat-7b-v3-1-awq", + "@cf/openchat/openchat-3.5-0106", + "@hf/thebloke/openhermes-2.5-mistral-7b-awq", + "@cf/microsoft/phi-2", + "@cf/qwen/qwen1.5-0.5b-chat", + "@cf/qwen/qwen1.5-1.8b-chat", + "@cf/qwen/qwen1.5-14b-chat-awq", + "@cf/qwen/qwen1.5-7b-chat-awq", + "@cf/defog/sqlcoder-7b-2", + "@hf/nexusflow/starling-lm-7b-beta", + "@cf/tinyllama/tinyllama-1.1b-chat-v1.0", + "@hf/thebloke/zephyr-7b-beta-awq", +} diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go new file mode 100644 index 00000000..e85bbc25 --- /dev/null +++ b/relay/adaptor/cloudflare/main.go @@ -0,0 +1,152 @@ +package cloudflare + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + lastMessage := textRequest.Messages[len(textRequest.Messages)-1] + return &Request{ + MaxTokens: textRequest.MaxTokens, + Prompt: lastMessage.StringContent(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: cloudflareResponse.Result.Response, + }, + FinishReason: "stop", + } + fullTextResponse := openai.TextResponse{ + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = cloudflareResponse.Response + choice.Delta.Role = "assistant" + openaiResponse := openai.ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + Created: helper.GetTimestamp(), + } + return &openaiResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + responseModel := c.GetString("original_model") + var responseText string + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var cloudflareResponse StreamResponse + err := json.Unmarshal([]byte(data), &cloudflareResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) + if response == nil { + return true + } + responseText += cloudflareResponse.Response + response.Id = id + response.Model = responseModel + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + _ = resp.Body.Close() + usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) + return nil, usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cloudflareResponse Response + err = json.Unmarshal(responseBody, &cloudflareResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) + fullTextResponse.Model = modelName + usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) + fullTextResponse.Usage = *usage + fullTextResponse.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, usage +} diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go new file mode 100644 index 00000000..0664ecd1 --- /dev/null +++ b/relay/adaptor/cloudflare/model.go @@ -0,0 +1,25 @@ +package cloudflare + +type Request struct { + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +type Result struct { + Response string `json:"response"` +} + +type Response struct { + Result Result `json:"result"` + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` +} + +type StreamResponse struct { + Response string `json:"response"` +} diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go new file mode 100644 index 00000000..6fdb1b04 --- /dev/null +++ b/relay/adaptor/cohere/adaptor.go @@ -0,0 +1,64 @@ +package cohere + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct{} + +// ConvertImageRequest implements adaptor.Adaptor. +func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements adaptor.Adaptor. + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "Cohere" +} diff --git a/relay/adaptor/cohere/constant.go b/relay/adaptor/cohere/constant.go new file mode 100644 index 00000000..9e70652c --- /dev/null +++ b/relay/adaptor/cohere/constant.go @@ -0,0 +1,14 @@ +package cohere + +var ModelList = []string{ + "command", "command-nightly", + "command-light", "command-light-nightly", + "command-r", "command-r-plus", +} + +func init() { + num := len(ModelList) + for i := 0; i < num; i++ { + ModelList = append(ModelList, ModelList[i]+"-internet") + } +} diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go new file mode 100644 index 00000000..4bc3fa8d --- /dev/null +++ b/relay/adaptor/cohere/main.go @@ -0,0 +1,241 @@ +package cohere + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +var ( + WebSearchConnector = Connector{ID: "web-search"} +) + +func stopReasonCohere2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "COMPLETE": + return "stop" + default: + return *reason + } +} + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + cohereRequest := Request{ + Model: textRequest.Model, + Message: "", + MaxTokens: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + P: textRequest.TopP, + K: textRequest.TopK, + Stream: textRequest.Stream, + FrequencyPenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.FrequencyPenalty, + Seed: int(textRequest.Seed), + } + if cohereRequest.Model == "" { + cohereRequest.Model = "command-r" + } + if strings.HasSuffix(cohereRequest.Model, "-internet") { + cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet") + cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector) + } + for _, message := range textRequest.Messages { + if message.Role == "user" { + cohereRequest.Message = message.Content.(string) + } else { + var role string + if message.Role == "assistant" { + role = "CHATBOT" + } else if message.Role == "system" { + role = "SYSTEM" + } else { + role = "USER" + } + cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{ + Role: role, + Message: message.Content.(string), + }) + } + } + return &cohereRequest +} + +func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var finishReason string + + switch cohereResponse.EventType { + case "stream-start": + return nil, nil + case "text-generation": + responseText += cohereResponse.Text + case "stream-end": + usage := cohereResponse.Response.Meta.Tokens + response = &Response{ + Meta: Meta{ + Tokens: Usage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + }, + }, + } + finishReason = *cohereResponse.Response.FinishReason + default: + return nil, nil + } + + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = responseText + choice.Delta.Role = "assistant" + if finishReason != "" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response +} + +func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: cohereResponse.Text, + Name: nil, + }, + FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason), + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID), + Model: "model", + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + var usage model.Usage + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var cohereResponse StreamResponse + err := json.Unmarshal([]byte(data), &cohereResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, meta := StreamResponseCohere2OpenAI(&cohereResponse) + if meta != nil { + usage.PromptTokens += meta.Meta.Tokens.InputTokens + usage.CompletionTokens += meta.Meta.Tokens.OutputTokens + return true + } + if response == nil { + return true + } + response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) + response.Model = c.GetString("original_model") + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + _ = resp.Body.Close() + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cohereResponse Response + err = json.Unmarshal(responseBody, &cohereResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if cohereResponse.ResponseID == "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: cohereResponse.Message, + Type: cohereResponse.Message, + Param: "", + Code: resp.StatusCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseCohere2OpenAI(&cohereResponse) + fullTextResponse.Model = modelName + usage := model.Usage{ + PromptTokens: cohereResponse.Meta.Tokens.InputTokens, + CompletionTokens: cohereResponse.Meta.Tokens.OutputTokens, + TotalTokens: cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go new file mode 100644 index 00000000..64fa9c94 --- /dev/null +++ b/relay/adaptor/cohere/model.go @@ -0,0 +1,147 @@ +package cohere + +type Request struct { + Message string `json:"message" required:"true"` + Model string `json:"model,omitempty"` // 默认值为"command-r" + Stream bool `json:"stream,omitempty"` // 默认值为false + Preamble string `json:"preamble,omitempty"` + ChatHistory []ChatMessage `json:"chat_history,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" + Connectors []Connector `json:"connectors,omitempty"` + Documents []Document `json:"documents,omitempty"` + Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3 + MaxTokens int `json:"max_tokens,omitempty"` + MaxInputTokens int `json:"max_input_tokens,omitempty"` + K int `json:"k,omitempty"` // 默认值为0 + P float64 `json:"p,omitempty"` // 默认值为0.75 + Seed int `json:"seed,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + Tools []Tool `json:"tools,omitempty"` + ToolResults []ToolResult `json:"tool_results,omitempty"` +} + +type ChatMessage struct { + Role string `json:"role" required:"true"` + Message string `json:"message" required:"true"` +} + +type Tool struct { + Name string `json:"name" required:"true"` + Description string `json:"description" required:"true"` + ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"` +} + +type ParameterSpec struct { + Description string `json:"description"` + Type string `json:"type" required:"true"` + Required bool `json:"required"` +} + +type ToolResult struct { + Call ToolCall `json:"call"` + Outputs []map[string]interface{} `json:"outputs"` +} + +type ToolCall struct { + Name string `json:"name" required:"true"` + Parameters map[string]interface{} `json:"parameters" required:"true"` +} + +type StreamResponse struct { + IsFinished bool `json:"is_finished"` + EventType string `json:"event_type"` + GenerationID string `json:"generation_id,omitempty"` + SearchQueries []*SearchQuery `json:"search_queries,omitempty"` + SearchResults []*SearchResult `json:"search_results,omitempty"` + Documents []*Document `json:"documents,omitempty"` + Text string `json:"text,omitempty"` + Citations []*Citation `json:"citations,omitempty"` + Response *Response `json:"response,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` +} + +type SearchQuery struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` +} + +type SearchResult struct { + SearchQuery *SearchQuery `json:"search_query"` + DocumentIDs []string `json:"document_ids"` + Connector *Connector `json:"connector"` +} + +type Connector struct { + ID string `json:"id"` +} + +type Document struct { + ID string `json:"id"` + Snippet string `json:"snippet"` + Timestamp string `json:"timestamp"` + Title string `json:"title"` + URL string `json:"url"` +} + +type Citation struct { + Start int `json:"start"` + End int `json:"end"` + Text string `json:"text"` + DocumentIDs []string `json:"document_ids"` +} + +type Response struct { + ResponseID string `json:"response_id"` + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ChatHistory []*Message `json:"chat_history"` + FinishReason *string `json:"finish_reason"` + Meta Meta `json:"meta"` + Citations []*Citation `json:"citations"` + Documents []*Document `json:"documents"` + SearchResults []*SearchResult `json:"search_results"` + SearchQueries []*SearchQuery `json:"search_queries"` + Message string `json:"message"` +} + +type Message struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type Version struct { + Version string `json:"version"` +} + +type Units struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type ChatEntry struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type Meta struct { + APIVersion APIVersion `json:"api_version"` + BilledUnits BilledUnits `json:"billed_units"` + Tokens Usage `json:"tokens"` +} + +type APIVersion struct { + Version string `json:"version"` +} + +type BilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git a/relay/channel/common.go b/relay/adaptor/common.go similarity index 80% rename from relay/channel/common.go rename to relay/adaptor/common.go index c6e1abf2..82a5160e 100644 --- a/relay/channel/common.go +++ b/relay/adaptor/common.go @@ -1,15 +1,16 @@ -package channel +package adaptor import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/relay/meta" "io" "net/http" ) -func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { +func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) if meta.IsStream && c.Request.Header.Get("Accept") == "" { @@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela } } -func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(meta) if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) @@ -38,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod } func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { - resp, err := util.HTTPClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return nil, err } diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go new file mode 100644 index 00000000..44f560e8 --- /dev/null +++ b/relay/adaptor/coze/adaptor.go @@ -0,0 +1,75 @@ +package coze + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor struct { + meta *meta.Meta +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/open_api/v2/chat", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + request.User = a.meta.Config.UserID + return ConvertRequest(*request), nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + var responseText *string + if meta.IsStream { + err, responseText = StreamHandler(c, resp) + } else { + err, responseText = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + if responseText != nil { + usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.PromptTokens) + } else { + usage = &model.Usage{} + } + usage.PromptTokens = meta.PromptTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "coze" +} diff --git a/relay/adaptor/coze/constant/contenttype/define.go b/relay/adaptor/coze/constant/contenttype/define.go new file mode 100644 index 00000000..69c876bc --- /dev/null +++ b/relay/adaptor/coze/constant/contenttype/define.go @@ -0,0 +1,5 @@ +package contenttype + +const ( + Text = "text" +) diff --git a/relay/adaptor/coze/constant/event/define.go b/relay/adaptor/coze/constant/event/define.go new file mode 100644 index 00000000..c03e8c17 --- /dev/null +++ b/relay/adaptor/coze/constant/event/define.go @@ -0,0 +1,7 @@ +package event + +const ( + Message = "message" + Done = "done" + Error = "error" +) diff --git a/relay/adaptor/coze/constant/messagetype/define.go b/relay/adaptor/coze/constant/messagetype/define.go new file mode 100644 index 00000000..6c1c25db --- /dev/null +++ b/relay/adaptor/coze/constant/messagetype/define.go @@ -0,0 +1,6 @@ +package messagetype + +const ( + Answer = "answer" + FollowUp = "follow_up" +) diff --git a/relay/adaptor/coze/constants.go b/relay/adaptor/coze/constants.go new file mode 100644 index 00000000..d20fd875 --- /dev/null +++ b/relay/adaptor/coze/constants.go @@ -0,0 +1,3 @@ +package coze + +var ModelList = []string{} diff --git a/relay/adaptor/coze/helper.go b/relay/adaptor/coze/helper.go new file mode 100644 index 00000000..0396afcb --- /dev/null +++ b/relay/adaptor/coze/helper.go @@ -0,0 +1,10 @@ +package coze + +import "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event" + +func event2StopReason(e *string) string { + if e == nil || *e == event.Message { + return "" + } + return "stop" +} diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go new file mode 100644 index 00000000..721c5d13 --- /dev/null +++ b/relay/adaptor/coze/main.go @@ -0,0 +1,215 @@ +package coze + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +// https://www.coze.com/open + +func stopReasonCoze2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "end_turn": + return "stop" + case "stop_sequence": + return "stop" + case "max_tokens": + return "length" + default: + return *reason + } +} + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + cozeRequest := Request{ + Stream: textRequest.Stream, + User: textRequest.User, + BotId: strings.TrimPrefix(textRequest.Model, "bot-"), + } + for i, message := range textRequest.Messages { + if i == len(textRequest.Messages)-1 { + cozeRequest.Query = message.StringContent() + continue + } + cozeMessage := Message{ + Role: message.Role, + Content: message.StringContent(), + } + cozeRequest.ChatHistory = append(cozeRequest.ChatHistory, cozeMessage) + } + return &cozeRequest +} + +func StreamResponseCoze2OpenAI(cozeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var stopReason string + var choice openai.ChatCompletionsStreamResponseChoice + + if cozeResponse.Message != nil { + if cozeResponse.Message.Type != messagetype.Answer { + return nil, nil + } + choice.Delta.Content = cozeResponse.Message.Content + } + choice.Delta.Role = "assistant" + finishReason := stopReasonCoze2OpenAI(&stopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + openaiResponse.Id = cozeResponse.ConversationId + return &openaiResponse, response +} + +func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse { + var responseText string + for _, message := range cozeResponse.Messages { + if message.Type == messagetype.Answer { + responseText = message.Content + break + } + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + }, + FinishReason: "stop", + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", cozeResponse.ConversationId), + Model: "coze-bot", + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) { + var responseText string + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { + continue + } + if !strings.HasPrefix(data, "data:") { + continue + } + data = strings.TrimPrefix(data, "data:") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + var modelName string + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var cozeResponse StreamResponse + err := json.Unmarshal([]byte(data), &cozeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, _ := StreamResponseCoze2OpenAI(&cozeResponse) + if response == nil { + return true + } + for _, choice := range response.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + response.Model = modelName + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + _ = resp.Body.Close() + return nil, &responseText +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cozeResponse Response + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if cozeResponse.Code != 0 { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: cozeResponse.Msg, + Code: cozeResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseCoze2OpenAI(&cozeResponse) + fullTextResponse.Model = modelName + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + var responseText string + if len(fullTextResponse.Choices) > 0 { + responseText = fullTextResponse.Choices[0].Message.StringContent() + } + return nil, &responseText +} diff --git a/relay/adaptor/coze/model.go b/relay/adaptor/coze/model.go new file mode 100644 index 00000000..d0afecfe --- /dev/null +++ b/relay/adaptor/coze/model.go @@ -0,0 +1,38 @@ +package coze + +type Message struct { + Role string `json:"role"` + Type string `json:"type"` + Content string `json:"content"` + ContentType string `json:"content_type"` +} + +type ErrorInformation struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +type Request struct { + ConversationId string `json:"conversation_id,omitempty"` + BotId string `json:"bot_id"` + User string `json:"user"` + Query string `json:"query"` + ChatHistory []Message `json:"chat_history,omitempty"` + Stream bool `json:"stream"` +} + +type Response struct { + ConversationId string `json:"conversation_id,omitempty"` + Messages []Message `json:"messages,omitempty"` + Code int `json:"code,omitempty"` + Msg string `json:"msg,omitempty"` +} + +type StreamResponse struct { + Event string `json:"event,omitempty"` + Message *Message `json:"message,omitempty"` + IsFinish bool `json:"is_finish,omitempty"` + Index int `json:"index,omitempty"` + ConversationId string `json:"conversation_id,omitempty"` + ErrorInformation *ErrorInformation `json:"error_information,omitempty"` +} diff --git a/relay/adaptor/deepl/adaptor.go b/relay/adaptor/deepl/adaptor.go new file mode 100644 index 00000000..d018a096 --- /dev/null +++ b/relay/adaptor/deepl/adaptor.go @@ -0,0 +1,73 @@ +package deepl + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor struct { + meta *meta.Meta + promptText string +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v2/translate", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "DeepL-Auth-Key "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + convertedRequest, text := ConvertRequest(*request) + a.promptText = text + return convertedRequest, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err = StreamHandler(c, resp, meta.ActualModelName) + } else { + err = Handler(c, resp, meta.ActualModelName) + } + promptTokens := len(a.promptText) + usage = &model.Usage{ + PromptTokens: promptTokens, + TotalTokens: promptTokens, + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "deepl" +} diff --git a/relay/adaptor/deepl/constants.go b/relay/adaptor/deepl/constants.go new file mode 100644 index 00000000..6a4f2545 --- /dev/null +++ b/relay/adaptor/deepl/constants.go @@ -0,0 +1,9 @@ +package deepl + +// https://developers.deepl.com/docs/api-reference/glossaries + +var ModelList = []string{ + "deepl-zh", + "deepl-en", + "deepl-ja", +} diff --git a/relay/adaptor/deepl/helper.go b/relay/adaptor/deepl/helper.go new file mode 100644 index 00000000..6d3a914b --- /dev/null +++ b/relay/adaptor/deepl/helper.go @@ -0,0 +1,11 @@ +package deepl + +import "strings" + +func parseLangFromModelName(modelName string) string { + parts := strings.Split(modelName, "-") + if len(parts) == 1 { + return "ZH" + } + return parts[1] +} diff --git a/relay/adaptor/deepl/main.go b/relay/adaptor/deepl/main.go new file mode 100644 index 00000000..f8bbae14 --- /dev/null +++ b/relay/adaptor/deepl/main.go @@ -0,0 +1,137 @@ +package deepl + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/constant/finishreason" + "github.com/songquanpeng/one-api/relay/constant/role" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +// https://developers.deepl.com/docs/getting-started/your-first-api-request + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) (*Request, string) { + var text string + if len(textRequest.Messages) != 0 { + text = textRequest.Messages[len(textRequest.Messages)-1].StringContent() + } + deeplRequest := Request{ + TargetLang: parseLangFromModelName(textRequest.Model), + Text: []string{text}, + } + return &deeplRequest, text +} + +func StreamResponseDeepL2OpenAI(deeplResponse *Response) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + if len(deeplResponse.Translations) != 0 { + choice.Delta.Content = deeplResponse.Translations[0].Text + } + choice.Delta.Role = role.Assistant + choice.FinishReason = &constant.StopFinishReason + openaiResponse := openai.ChatCompletionsStreamResponse{ + Object: constant.StreamObject, + Created: helper.GetTimestamp(), + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &openaiResponse +} + +func ResponseDeepL2OpenAI(deeplResponse *Response) *openai.TextResponse { + var responseText string + if len(deeplResponse.Translations) != 0 { + responseText = deeplResponse.Translations[0].Text + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: role.Assistant, + Content: responseText, + Name: nil, + }, + FinishReason: finishreason.Stop, + } + fullTextResponse := openai.TextResponse{ + Object: constant.NonStreamObject, + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var deeplResponse Response + err = json.Unmarshal(responseBody, &deeplResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + fullTextResponse := StreamResponseDeepL2OpenAI(&deeplResponse) + fullTextResponse.Model = modelName + fullTextResponse.Id = helper.GetResponseID(c) + jsonData, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + if jsonData != nil { + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)}) + jsonData = nil + return true + } + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + }) + _ = resp.Body.Close() + return nil +} + +func Handler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var deeplResponse Response + err = json.Unmarshal(responseBody, &deeplResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + if deeplResponse.Message != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: deeplResponse.Message, + Code: "deepl_error", + }, + StatusCode: resp.StatusCode, + } + } + fullTextResponse := ResponseDeepL2OpenAI(&deeplResponse) + fullTextResponse.Model = modelName + fullTextResponse.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil +} diff --git a/relay/adaptor/deepl/model.go b/relay/adaptor/deepl/model.go new file mode 100644 index 00000000..3f823d21 --- /dev/null +++ b/relay/adaptor/deepl/model.go @@ -0,0 +1,16 @@ +package deepl + +type Request struct { + Text []string `json:"text"` + TargetLang string `json:"target_lang"` +} + +type Translation struct { + DetectedSourceLanguage string `json:"detected_source_language,omitempty"` + Text string `json:"text,omitempty"` +} + +type Response struct { + Translations []Translation `json:"translations,omitempty"` + Message string `json:"message,omitempty"` +} diff --git a/relay/adaptor/deepseek/constants.go b/relay/adaptor/deepseek/constants.go new file mode 100644 index 00000000..ad840bc2 --- /dev/null +++ b/relay/adaptor/deepseek/constants.go @@ -0,0 +1,6 @@ +package deepseek + +var ModelList = []string{ + "deepseek-chat", + "deepseek-coder", +} diff --git a/relay/channel/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go similarity index 61% rename from relay/channel/gemini/adaptor.go rename to relay/adaptor/gemini/adaptor.go index f3305e5d..a4dcae93 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -3,33 +3,35 @@ package gemini import ( "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/helper" - channelhelper "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - version := helper.AssignOrDefault(meta.APIVersion, "v1") +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) action := "generateContent" if meta.IsStream { - action = "streamGenerateContent" + action = "streamGenerateContent?alt=sse" } return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channelhelper.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil @@ -42,11 +44,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/gemini/constants.go b/relay/adaptor/gemini/constants.go similarity index 71% rename from relay/channel/gemini/constants.go rename to relay/adaptor/gemini/constants.go index e8d3a155..32e7c240 100644 --- a/relay/channel/gemini/constants.go +++ b/relay/adaptor/gemini/constants.go @@ -3,6 +3,6 @@ package gemini // https://ai.google.dev/models/gemini var ModelList = []string{ - "gemini-pro", "gemini-1.0-pro-001", + "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", "gemini-pro-vision", "gemini-1.0-pro-vision-001", } diff --git a/relay/channel/gemini/main.go b/relay/adaptor/gemini/main.go similarity index 81% rename from relay/channel/gemini/main.go rename to relay/adaptor/gemini/main.go index c24694c8..faccc4cb 100644 --- a/relay/channel/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,17 +4,19 @@ import ( "bufio" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" "github.com/gin-gonic/gin" ) @@ -53,7 +55,17 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { MaxOutputTokens: textRequest.MaxTokens, }, } - if textRequest.Functions != nil { + if textRequest.Tools != nil { + functions := make([]model.Function, 0, len(textRequest.Tools)) + for _, tool := range textRequest.Tools { + functions = append(functions, tool.Function) + } + geminiRequest.Tools = []ChatTools{ + { + FunctionDeclarations: functions, + }, + } + } else if textRequest.Functions != nil { geminiRequest.Tools = []ChatTools{ { FunctionDeclarations: textRequest.Functions, @@ -153,9 +165,33 @@ type ChatPromptFeedback struct { SafetyRatings []ChatSafetyRating `json:"safetyRatings"` } +func getToolCalls(candidate *ChatCandidate) []model.Tool { + var toolCalls []model.Tool + + item := candidate.Content.Parts[0] + if item.FunctionCall == nil { + return toolCalls + } + argsBytes, err := json.Marshal(item.FunctionCall.Arguments) + if err != nil { + logger.FatalLog("getToolCalls failed: " + err.Error()) + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", random.GetUUID()), + Type: "function", + Function: model.Function{ + Arguments: string(argsBytes), + Name: item.FunctionCall.FunctionName, + }, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), @@ -164,13 +200,19 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: i, Message: model.Message{ - Role: "assistant", - Content: "", + Role: "assistant", }, FinishReason: constant.StopFinishReason, } if len(candidate.Content.Parts) > 0 { - choice.Message.Content = candidate.Content.Parts[0].Text + if candidate.Content.Parts[0].FunctionCall != nil { + choice.Message.ToolCalls = getToolCalls(&candidate) + } else { + choice.Message.Content = candidate.Content.Parts[0].Text + } + } else { + choice.Message.Content = "" + choice.FinishReason = candidate.FinishReason } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } @@ -190,8 +232,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" - dataChan := make(chan string) - stopChan := make(chan bool) scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -205,14 +245,16 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) + dataChan := make(chan string) + stopChan := make(chan bool) go func() { for scanner.Scan() { data := scanner.Text() data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "\"text\": \"") { + if !strings.HasPrefix(data, "data: ") { continue } - data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimPrefix(data, "data: ") data = strings.TrimSuffix(data, "\"") dataChan <- data } @@ -222,23 +264,17 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - // this is used to prevent annoying \ related format bug - data = fmt.Sprintf("{\"content\": \"%s\"}", data) - type dummyStruct struct { - Content string `json:"content"` + var geminiResponse ChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true } - var dummy dummyStruct - err := json.Unmarshal([]byte(data), &dummy) - responseText += dummy.Content - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = dummy.Content - response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), - Object: "chat.completion.chunk", - Created: helper.GetTimestamp(), - Model: "gemini-pro", - Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + response := streamResponseGeminiChat2OpenAI(&geminiResponse) + if response == nil { + return true } + responseText += response.Choices[0].Delta.StringContent() jsonResponse, err := json.Marshal(response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error()) diff --git a/relay/channel/gemini/model.go b/relay/adaptor/gemini/model.go similarity index 75% rename from relay/channel/gemini/model.go rename to relay/adaptor/gemini/model.go index d1e3c4fd..47b74fbc 100644 --- a/relay/channel/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -12,9 +12,15 @@ type InlineData struct { Data string `json:"data"` } +type FunctionCall struct { + FunctionName string `json:"name"` + Arguments any `json:"args"` +} + type Part struct { - Text string `json:"text,omitempty"` - InlineData *InlineData `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` + InlineData *InlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` } type ChatContent struct { @@ -28,7 +34,7 @@ type ChatSafetySettings struct { } type ChatTools struct { - FunctionDeclarations any `json:"functionDeclarations,omitempty"` + FunctionDeclarations any `json:"function_declarations,omitempty"` } type ChatGenerationConfig struct { diff --git a/relay/channel/groq/constants.go b/relay/adaptor/groq/constants.go similarity index 80% rename from relay/channel/groq/constants.go rename to relay/adaptor/groq/constants.go index fc9a9ebd..1aa2574b 100644 --- a/relay/channel/groq/constants.go +++ b/relay/adaptor/groq/constants.go @@ -7,4 +7,6 @@ var ModelList = []string{ "llama2-7b-2048", "llama2-70b-4096", "mixtral-8x7b-32768", + "llama3-8b-8192", + "llama3-70b-8192", } diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go new file mode 100644 index 00000000..01b2e2cb --- /dev/null +++ b/relay/adaptor/interface.go @@ -0,0 +1,21 @@ +package adaptor + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor interface { + Init(meta *meta.Meta) + GetRequestURL(meta *meta.Meta) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(request *model.ImageRequest) (any, error) + DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) + GetModelList() []string + GetChannelName() string +} diff --git a/relay/channel/lingyiwanwu/constants.go b/relay/adaptor/lingyiwanwu/constants.go similarity index 100% rename from relay/channel/lingyiwanwu/constants.go rename to relay/adaptor/lingyiwanwu/constants.go diff --git a/relay/adaptor/minimax/constants.go b/relay/adaptor/minimax/constants.go new file mode 100644 index 00000000..1b2fc104 --- /dev/null +++ b/relay/adaptor/minimax/constants.go @@ -0,0 +1,11 @@ +package minimax + +// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd + +var ModelList = []string{ + "abab6.5-chat", + "abab6.5s-chat", + "abab6-chat", + "abab5.5-chat", + "abab5.5s-chat", +} diff --git a/relay/adaptor/minimax/main.go b/relay/adaptor/minimax/main.go new file mode 100644 index 00000000..fc9b5d26 --- /dev/null +++ b/relay/adaptor/minimax/main.go @@ -0,0 +1,14 @@ +package minimax + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + if meta.Mode == relaymode.ChatCompletions { + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) +} diff --git a/relay/channel/mistral/constants.go b/relay/adaptor/mistral/constants.go similarity index 100% rename from relay/channel/mistral/constants.go rename to relay/adaptor/mistral/constants.go diff --git a/relay/channel/moonshot/constants.go b/relay/adaptor/moonshot/constants.go similarity index 100% rename from relay/channel/moonshot/constants.go rename to relay/adaptor/moonshot/constants.go diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go new file mode 100644 index 00000000..66702c5d --- /dev/null +++ b/relay/adaptor/ollama/adaptor.go @@ -0,0 +1,82 @@ +package ollama + +import ( + "errors" + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + // https://github.com/ollama/ollama/blob/main/docs/api.md + fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + if meta.Mode == relaymode.Embeddings { + fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) + } + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) + return ollamaEmbeddingRequest, nil + default: + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ollama" +} diff --git a/relay/adaptor/ollama/constants.go b/relay/adaptor/ollama/constants.go new file mode 100644 index 00000000..d9dc72a8 --- /dev/null +++ b/relay/adaptor/ollama/constants.go @@ -0,0 +1,11 @@ +package ollama + +var ModelList = []string{ + "codellama:7b-instruct", + "llama2:7b", + "llama2:latest", + "llama3:latest", + "phi3:latest", + "qwen:0.5b-chat", + "qwen:7b", +} diff --git a/relay/channel/ollama/main.go b/relay/adaptor/ollama/main.go similarity index 67% rename from relay/channel/ollama/main.go rename to relay/adaptor/ollama/main.go index 7ec646a3..c5fe08e6 100644 --- a/relay/channel/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -5,16 +5,19 @@ import ( "context" "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/common/random" "io" "net/http" "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/image" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" ) func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { @@ -30,9 +33,22 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { Stream: request.Stream, } for _, message := range request.Messages { + openaiContent := message.ParseContent() + var imageUrls []string + var contentText string + for _, part := range openaiContent { + switch part.Type { + case model.ContentTypeText: + contentText = part.Text + case model.ContentTypeImageURL: + _, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + imageUrls = append(imageUrls, data) + } + } ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ Role: message.Role, - Content: message.StringContent(), + Content: contentText, + Images: imageUrls, }) } return &ollamaRequest @@ -50,7 +66,8 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { choice.FinishReason = "stop" } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Model: response.Model, Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -71,7 +88,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl choice.FinishReason = &constant.StopFinishReason } response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: ollamaResponse.Model, @@ -139,6 +156,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return nil, &usage } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Model: request.Model, + Prompt: strings.Join(request.ParseInput(), " "), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var ollamaResponse EmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, 1), + Model: "text-embedding-v1", + Usage: model.Usage{TotalTokens: 0}, + } + + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: 0, + Embedding: response.Embedding, + }) + return &openAIEmbeddingResponse +} + func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { ctx := context.TODO() var ollamaResponse ChatResponse diff --git a/relay/channel/ollama/model.go b/relay/adaptor/ollama/model.go similarity index 86% rename from relay/channel/ollama/model.go rename to relay/adaptor/ollama/model.go index a8ef1ffc..8baf56a0 100644 --- a/relay/channel/ollama/model.go +++ b/relay/adaptor/ollama/model.go @@ -35,3 +35,13 @@ type ChatResponse struct { EvalDuration int `json:"eval_duration,omitempty"` Error string `json:"error,omitempty"` } + +type EmbeddingRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type EmbeddingResponse struct { + Error string `json:"error,omitempty"` + Embedding []float64 `json:"embedding,omitempty"` +} diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go new file mode 100644 index 00000000..2e2e4100 --- /dev/null +++ b/relay/adaptor/openai/adaptor.go @@ -0,0 +1,115 @@ +package openai + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + "strings" +) + +type Adaptor struct { + ChannelType int +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.ChannelType = meta.ChannelType +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.ChannelType { + case channeltype.Azure: + if meta.Mode == relaymode.ImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) + return fullRequestURL, nil + } + + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + requestURL := strings.Split(meta.RequestURLPath, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion) + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := meta.ActualModelName + model_ = strings.Replace(model_, ".", "", -1) + //https://github.com/songquanpeng/one-api/issues/1191 + // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil + case channeltype.Minimax: + return minimax.GetRequestURL(meta) + default: + return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + if meta.ChannelType == channeltype.Azure { + req.Header.Set("api-key", meta.APIKey) + return nil + } + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + if meta.ChannelType == channeltype.OpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText, usage = StreamHandler(c, resp, meta.Mode) + if usage == nil || usage.TotalTokens == 0 { + usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } + if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens + usage.PromptTokens = meta.PromptTokens + usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens + } + } else { + switch meta.Mode { + case relaymode.ImagesGenerations: + err, _ = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + _, modelList := GetCompatibleChannelMeta(a.ChannelType) + return modelList +} + +func (a *Adaptor) GetChannelName() string { + channelName, _ := GetCompatibleChannelMeta(a.ChannelType) + return channelName +} diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go new file mode 100644 index 00000000..0116a2eb --- /dev/null +++ b/relay/adaptor/openai/compatible.go @@ -0,0 +1,58 @@ +package openai + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/ai360" + "github.com/songquanpeng/one-api/relay/adaptor/baichuan" + "github.com/songquanpeng/one-api/relay/adaptor/deepseek" + "github.com/songquanpeng/one-api/relay/adaptor/groq" + "github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu" + "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/adaptor/mistral" + "github.com/songquanpeng/one-api/relay/adaptor/moonshot" + "github.com/songquanpeng/one-api/relay/adaptor/stepfun" + "github.com/songquanpeng/one-api/relay/adaptor/togetherai" + "github.com/songquanpeng/one-api/relay/channeltype" +) + +var CompatibleChannels = []int{ + channeltype.Azure, + channeltype.AI360, + channeltype.Moonshot, + channeltype.Baichuan, + channeltype.Minimax, + channeltype.Mistral, + channeltype.Groq, + channeltype.LingYiWanWu, + channeltype.StepFun, + channeltype.DeepSeek, + channeltype.TogetherAI, +} + +func GetCompatibleChannelMeta(channelType int) (string, []string) { + switch channelType { + case channeltype.Azure: + return "azure", ModelList + case channeltype.AI360: + return "360", ai360.ModelList + case channeltype.Moonshot: + return "moonshot", moonshot.ModelList + case channeltype.Baichuan: + return "baichuan", baichuan.ModelList + case channeltype.Minimax: + return "minimax", minimax.ModelList + case channeltype.Mistral: + return "mistralai", mistral.ModelList + case channeltype.Groq: + return "groq", groq.ModelList + case channeltype.LingYiWanWu: + return "lingyiwanwu", lingyiwanwu.ModelList + case channeltype.StepFun: + return "stepfun", stepfun.ModelList + case channeltype.DeepSeek: + return "deepseek", deepseek.ModelList + case channeltype.TogetherAI: + return "together.ai", togetherai.ModelList + default: + return "openai", ModelList + } +} diff --git a/relay/channel/openai/constants.go b/relay/adaptor/openai/constants.go similarity index 92% rename from relay/channel/openai/constants.go rename to relay/adaptor/openai/constants.go index ea236ea1..2ffff007 100644 --- a/relay/channel/openai/constants.go +++ b/relay/adaptor/openai/constants.go @@ -6,7 +6,7 @@ var ModelList = []string{ "gpt-3.5-turbo-instruct", "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-4-turbo-preview", + "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-vision-preview", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go new file mode 100644 index 00000000..7d73303b --- /dev/null +++ b/relay/adaptor/openai/helper.go @@ -0,0 +1,30 @@ +package openai + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/model" + "strings" +) + +func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { + usage := &model.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case channeltype.OpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case channeltype.Azure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go new file mode 100644 index 00000000..0f89618a --- /dev/null +++ b/relay/adaptor/openai/image.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var imageResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &imageResponse) + if err != nil { + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/channel/openai/main.go b/relay/adaptor/openai/main.go similarity index 68% rename from relay/channel/openai/main.go rename to relay/adaptor/openai/main.go index d47cd164..72c675e1 100644 --- a/relay/channel/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -6,14 +6,21 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" ) +const ( + dataPrefix = "data: " + done = "[DONE]" + dataPrefixLength = len(dataPrefix) +) + func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" scanner := bufio.NewScanner(resp.Body) @@ -35,39 +42,46 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E go func() { for scanner.Scan() { data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format + if len(data) < dataPrefixLength { // ignore blank line or wrong format continue } - if data[:6] != "data: " && data[:6] != "[DONE]" { + if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { continue } - dataChan <- data - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case constant.RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - continue // just ignore the error - } - for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content - } - if streamResponse.Usage != nil { - usage = streamResponse.Usage - } - case constant.RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } + if strings.HasPrefix(data[dataPrefixLength:], done) { + dataChan <- data + continue + } + switch relayMode { + case relaymode.ChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + dataChan <- data // if error happened, pass the data to client + continue // just ignore the error + } + if len(streamResponse.Choices) == 0 { + // but for empty choice, we should not pass it to client, this is for azure + continue // just ignore empty choice + } + dataChan <- data + for _, choice := range streamResponse.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + if streamResponse.Usage != nil { + usage = streamResponse.Usage + } + case relaymode.Completions: + dataChan <- data + var streamResponse CompletionsStreamResponse + err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + for _, choice := range streamResponse.Choices { + responseText += choice.Text } } } diff --git a/relay/channel/openai/model.go b/relay/adaptor/openai/model.go similarity index 88% rename from relay/channel/openai/model.go rename to relay/adaptor/openai/model.go index 6c0b2c53..4c974de4 100644 --- a/relay/channel/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -110,20 +110,22 @@ type EmbeddingResponse struct { model.Usage `json:"usage"` } +type ImageData struct { + Url string `json:"url,omitempty"` + B64Json string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` +} + type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } + Created int64 `json:"created"` + Data []ImageData `json:"data"` + //model.Usage `json:"usage"` } type ChatCompletionsStreamResponseChoice struct { - Index int `json:"index"` - Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index"` + Delta model.Message `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { @@ -132,7 +134,7 @@ type ChatCompletionsStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"` - Usage *model.Usage `json:"usage"` + Usage *model.Usage `json:"usage,omitempty"` } type CompletionsStreamResponse struct { diff --git a/relay/channel/openai/token.go b/relay/adaptor/openai/token.go similarity index 97% rename from relay/channel/openai/token.go rename to relay/adaptor/openai/token.go index 0720425f..bb9c38a9 100644 --- a/relay/channel/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/pkoukk/tiktoken-go" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/model" "math" "strings" @@ -28,7 +28,7 @@ func InitTokenEncoders() { if err != nil { logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - for model := range common.ModelRatio { + for model := range billingratio.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { @@ -206,3 +206,7 @@ func CountTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text) } + +func CountToken(text string) int { + return CountTokenInput(text, "gpt-3.5-turbo") +} diff --git a/relay/channel/openai/util.go b/relay/adaptor/openai/util.go similarity index 100% rename from relay/channel/openai/util.go rename to relay/adaptor/openai/util.go diff --git a/relay/channel/palm/adaptor.go b/relay/adaptor/palm/adaptor.go similarity index 59% rename from relay/channel/palm/adaptor.go rename to relay/adaptor/palm/adaptor.go index efd0620c..98aa3e18 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,16 +15,16 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil } @@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/palm/constants.go b/relay/adaptor/palm/constants.go similarity index 100% rename from relay/channel/palm/constants.go rename to relay/adaptor/palm/constants.go diff --git a/relay/channel/palm/model.go b/relay/adaptor/palm/model.go similarity index 100% rename from relay/channel/palm/model.go rename to relay/adaptor/palm/model.go diff --git a/relay/channel/palm/palm.go b/relay/adaptor/palm/palm.go similarity index 97% rename from relay/channel/palm/palm.go rename to relay/adaptor/palm/palm.go index 56738544..1e60e7cd 100644 --- a/relay/channel/palm/palm.go +++ b/relay/adaptor/palm/palm.go @@ -7,7 +7,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) createdTime := helper.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go new file mode 100644 index 00000000..a82e562b --- /dev/null +++ b/relay/adaptor/stepfun/constants.go @@ -0,0 +1,7 @@ +package stepfun + +var ModelList = []string{ + "step-1-32k", + "step-1v-32k", + "step-1-200k", +} diff --git a/relay/channel/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go similarity index 66% rename from relay/channel/tencent/adaptor.go rename to relay/adaptor/tencent/adaptor.go index f348674e..a97476d6 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -19,16 +19,16 @@ type Adaptor struct { Sign string } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", a.Sign) req.Header.Set("X-TC-Action", meta.ActualModelName) return nil @@ -52,11 +52,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return tencentRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/tencent/constants.go b/relay/adaptor/tencent/constants.go similarity index 100% rename from relay/channel/tencent/constants.go rename to relay/adaptor/tencent/constants.go diff --git a/relay/channel/tencent/main.go b/relay/adaptor/tencent/main.go similarity index 95% rename from relay/channel/tencent/main.go rename to relay/adaptor/tencent/main.go index cfdc0bfd..2ca5724e 100644 --- a/relay/channel/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -10,9 +10,11 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -40,7 +42,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { return &ChatRequest{ Timestamp: helper.GetTimestamp(), Expired: helper.GetTimestamp() + 24*60*60, - QueryID: helper.GetUUID(), + QueryID: random.GetUUID(), Temperature: request.Temperature, TopP: request.TopP, Stream: stream, @@ -70,7 +72,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "tencent-hunyuan", @@ -129,7 +131,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content + responseText += conv.AsString(response.Choices[0].Delta.Content) } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/channel/tencent/model.go b/relay/adaptor/tencent/model.go similarity index 100% rename from relay/channel/tencent/model.go rename to relay/adaptor/tencent/model.go diff --git a/relay/adaptor/togetherai/constants.go b/relay/adaptor/togetherai/constants.go new file mode 100644 index 00000000..0a79fbdc --- /dev/null +++ b/relay/adaptor/togetherai/constants.go @@ -0,0 +1,10 @@ +package togetherai + +// https://docs.together.ai/docs/inference-models + +var ModelList = []string{ + "meta-llama/Llama-3-70b-chat-hf", + "deepseek-ai/deepseek-coder-33b-instruct", + "mistralai/Mixtral-8x22B-Instruct-v0.1", + "Qwen/Qwen1.5-72B-Chat", +} diff --git a/relay/channel/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go similarity index 54% rename from relay/channel/xunfei/adaptor.go rename to relay/adaptor/xunfei/adaptor.go index 92d9d7d6..3af97831 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/adaptor/xunfei/adaptor.go @@ -3,10 +3,10 @@ package xunfei import ( "errors" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -14,18 +14,27 @@ import ( type Adaptor struct { request *model.GeneralOpenAIRequest + meta *meta.Meta } -func (a *Adaptor) Init(meta *util.RelayMeta) { - +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return "", nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + version := parseAPIVersionByModelName(meta.ActualModelName) + if version == "" { + version = a.meta.Config.APIVersion + } + if version == "" { + version = "v1.1" + } + a.meta.Config.APIVersion = version // check DoResponse for auth part return nil } @@ -38,14 +47,21 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} dummyResp.StatusCode = http.StatusOK return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { splits := strings.Split(meta.APIKey, "|") if len(splits) != 3 { return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) @@ -54,9 +70,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) } if meta.IsStream { - err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2]) + err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2]) } else { - err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2]) + err, usage = Handler(c, meta, *a.request, splits[0], splits[1], splits[2]) } return } diff --git a/relay/channel/xunfei/constants.go b/relay/adaptor/xunfei/constants.go similarity index 100% rename from relay/channel/xunfei/constants.go rename to relay/adaptor/xunfei/constants.go diff --git a/relay/channel/xunfei/main.go b/relay/adaptor/xunfei/main.go similarity index 74% rename from relay/channel/xunfei/main.go rename to relay/adaptor/xunfei/main.go index f89aea2b..c3e768b7 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -11,8 +11,10 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "io" "net/http" @@ -26,7 +28,11 @@ import ( func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) + var lastToolCalls []model.Tool for _, message := range request.Messages { + if message.ToolCalls != nil { + lastToolCalls = message.ToolCalls + } messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), @@ -39,9 +45,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages + if len(lastToolCalls) != 0 { + for _, toolCall := range lastToolCalls { + xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) + } + } + return &xunfeiRequest } +func getToolCalls(response *ChatResponse) []model.Tool { + var toolCalls []model.Tool + if len(response.Payload.Choices.Text) == 0 { + return toolCalls + } + item := response.Payload.Choices.Text[0] + if item.FunctionCall == nil { + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", random.GetUUID()), + Type: "function", + Function: *item.FunctionCall, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []ChatResponseTextItem{ @@ -53,13 +83,14 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + ToolCalls: getToolCalls(response), }, FinishReason: constant.StopFinishReason, } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -78,11 +109,12 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &constant.StopFinishReason } response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "SparkDesk", @@ -117,11 +149,11 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { - domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) +func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } common.SetEventStreamHeaders(c) var usage model.Usage @@ -147,11 +179,11 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId return nil, &usage } -func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { - domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) +func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } var usage model.Usage var content string @@ -171,11 +203,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin } } if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ - { - Content: "", - }, - } + return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil } xunfeiResponse.Payload.Choices.Text[0].Content = content @@ -202,15 +230,21 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, if err != nil { return nil, nil, err } + _, msg, err := conn.ReadMessage() + if err != nil { + return nil, nil, err + } dataChan := make(chan ChatResponse) stopChan := make(chan bool) go func() { for { - _, msg, err := conn.ReadMessage() - if err != nil { - logger.SysError("error reading stream response: " + err.Error()) - break + if msg == nil { + _, msg, err = conn.ReadMessage() + if err != nil { + logger.SysError("error reading stream response: " + err.Error()) + break + } } var response ChatResponse err = json.Unmarshal(msg, &response) @@ -218,6 +252,7 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, logger.SysError("error unmarshalling stream response: " + err.Error()) break } + msg = nil dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() @@ -233,25 +268,12 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, return dataChan, stopChan, nil } -func getAPIVersion(c *gin.Context, modelName string) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion != "" { - return apiVersion - } +func parseAPIVersionByModelName(modelName string) string { parts := strings.Split(modelName, "-") if len(parts) == 2 { - apiVersion = parts[1] - return apiVersion - + return parts[1] } - apiVersion = c.GetString(common.ConfigKeyAPIVersion) - if apiVersion != "" { - return apiVersion - } - apiVersion = "v1.1" - logger.SysLog("api_version not found, using default: " + apiVersion) - return apiVersion + return "" } // https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E @@ -269,8 +291,7 @@ func apiVersion2domain(apiVersion string) string { return "general" + apiVersion } -func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) { - apiVersion := getAPIVersion(c, modelName) +func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { domain := apiVersion2domain(apiVersion) authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) return domain, authUrl diff --git a/relay/channel/xunfei/model.go b/relay/adaptor/xunfei/model.go similarity index 81% rename from relay/channel/xunfei/model.go rename to relay/adaptor/xunfei/model.go index 1266739d..97a43154 100644 --- a/relay/channel/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -26,13 +26,18 @@ type ChatRequest struct { Message struct { Text []Message `json:"text"` } `json:"message"` + Functions struct { + Text []model.Function `json:"text,omitempty"` + } `json:"functions,omitempty"` } `json:"payload"` } type ChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` + ContentType string `json:"content_type"` + FunctionCall *model.Function `json:"function_call"` } type ChatResponse struct { diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go new file mode 100644 index 00000000..78b01fb3 --- /dev/null +++ b/relay/adaptor/zhipu/adaptor.go @@ -0,0 +1,149 @@ +package zhipu + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "math" + "net/http" + "strings" +) + +type Adaptor struct { + APIVersion string +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) SetVersionByModeName(modelName string) { + if strings.HasPrefix(modelName, "glm-") { + a.APIVersion = "v4" + } else { + a.APIVersion = "v3" + } +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.Mode { + case relaymode.ImagesGenerations: + return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil + } + a.SetVersionByModeName(meta.ActualModelName) + if a.APIVersion == "v4" { + return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil + } + method := "invoke" + if meta.IsStream { + method = "sse-invoke" + } + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + token := GetToken(meta.APIKey) + req.Header.Set("Authorization", token) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, err + default: + // TopP (0.0, 1.0) + request.TopP = math.Min(0.99, request.TopP) + request.TopP = math.Max(0.01, request.TopP) + + // Temperature (0.0, 1.0) + request.Temperature = math.Min(0.99, request.Temperature) + request.Temperature = math.Max(0.01, request.Temperature) + a.SetVersionByModeName(request.Model) + if a.APIVersion == "v4" { + return request, nil + } + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + newRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + UserId: request.User, + } + return newRequest, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, _, usage = openai.StreamHandler(c, resp, meta.Mode) + } else { + err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingsHandler(c, resp) + return + case relaymode.ImagesGenerations: + err, usage = openai.ImageHandler(c, resp) + return + } + if a.APIVersion == "v4" { + return a.DoResponseV4(c, resp, meta) + } + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + if meta.Mode == relaymode.Embeddings { + err, usage = EmbeddingsHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + } + return +} + +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) (*EmbeddingRequest, error) { + inputs := request.ParseInput() + if len(inputs) != 1 { + return nil, errors.New("invalid input length, zhipu only support one input") + } + return &EmbeddingRequest{ + Model: request.Model, + Input: inputs[0], + }, nil +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "zhipu" +} diff --git a/relay/channel/zhipu/constants.go b/relay/adaptor/zhipu/constants.go similarity index 62% rename from relay/channel/zhipu/constants.go rename to relay/adaptor/zhipu/constants.go index 1655a59d..e1192123 100644 --- a/relay/channel/zhipu/constants.go +++ b/relay/adaptor/zhipu/constants.go @@ -2,5 +2,6 @@ package zhipu var ModelList = []string{ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", - "glm-4", "glm-4v", "glm-3-turbo", + "glm-4", "glm-4v", "glm-3-turbo", "embedding-2", + "cogview-3", } diff --git a/relay/channel/zhipu/main.go b/relay/adaptor/zhipu/main.go similarity index 80% rename from relay/channel/zhipu/main.go rename to relay/adaptor/zhipu/main.go index a46fd537..74a1a05e 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/adaptor/zhipu/main.go @@ -8,7 +8,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var zhipuResponse EmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: response.Model, + Usage: model.Usage{ + PromptTokens: response.PromptTokens, + CompletionTokens: response.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} diff --git a/relay/channel/zhipu/model.go b/relay/adaptor/zhipu/model.go similarity index 65% rename from relay/channel/zhipu/model.go rename to relay/adaptor/zhipu/model.go index b63e1d6f..f91de1dc 100644 --- a/relay/channel/zhipu/model.go +++ b/relay/adaptor/zhipu/model.go @@ -44,3 +44,27 @@ type tokenData struct { Token string ExpiryTime time.Time } + +type EmbeddingRequest struct { + Model string `json:"model"` + Input string `json:"input"` +} + +type EmbeddingResponse struct { + Model string `json:"model"` + Object string `json:"object"` + Embeddings []EmbeddingData `json:"data"` + model.Usage `json:"usage"` +} + +type EmbeddingData struct { + Index int `json:"index"` + Object string `json:"object"` + Embedding []float64 `json:"embedding"` +} + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + UserId string `json:"user_id,omitempty"` +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go new file mode 100644 index 00000000..cf1df694 --- /dev/null +++ b/relay/apitype/define.go @@ -0,0 +1,22 @@ +package apitype + +const ( + OpenAI = iota + Anthropic + PaLM + Baidu + Zhipu + Ali + Xunfei + AIProxyLibrary + Tencent + Gemini + Ollama + AwsClaude + Coze + Cohere + Cloudflare + DeepL + + Dummy // this one is only for count, do not add any channel after this +) diff --git a/relay/billing/billing.go b/relay/billing/billing.go new file mode 100644 index 00000000..a99d37ee --- /dev/null +++ b/relay/billing/billing.go @@ -0,0 +1,42 @@ +package billing + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" +) + +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(ctx) + } +} + +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, userId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} diff --git a/common/group-ratio.go b/relay/billing/ratio/group.go similarity index 97% rename from common/group-ratio.go rename to relay/billing/ratio/group.go index 2de6e810..8e9c5b73 100644 --- a/common/group-ratio.go +++ b/relay/billing/ratio/group.go @@ -1,4 +1,4 @@ -package common +package ratio import ( "encoding/json" diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go new file mode 100644 index 00000000..5a29cddc --- /dev/null +++ b/relay/billing/ratio/image.go @@ -0,0 +1,51 @@ +package ratio + +var ImageSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, + "ali-stable-diffusion-xl": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "ali-stable-diffusion-v1.5": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "wanx-v1": { + "1024x1024": 1, + "720x1280": 1, + "1280x720": 1, + }, +} + +var ImageGenerationAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. + "ali-stable-diffusion-xl": {1, 4}, // Ali + "ali-stable-diffusion-v1.5": {1, 4}, // Ali + "wanx-v1": {1, 4}, // Ali + "cogview-3": {1, 1}, +} + +var ImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, + "ali-stable-diffusion-xl": 4000, + "ali-stable-diffusion-v1.5": 4000, + "wanx-v1": 4000, + "cogview-3": 833, +} diff --git a/common/model-ratio.go b/relay/billing/ratio/model.go similarity index 69% rename from common/model-ratio.go rename to relay/billing/ratio/model.go index 0ef015d0..5dae2abd 100644 --- a/common/model-ratio.go +++ b/relay/billing/ratio/model.go @@ -1,9 +1,10 @@ -package common +package ratio import ( "encoding/json" - "github.com/songquanpeng/one-api/common/logger" "strings" + + "github.com/songquanpeng/one-api/common/logger" ) const ( @@ -31,6 +32,8 @@ var ModelRatio = map[string]float64{ "gpt-4-1106-preview": 5, // $0.01 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens "gpt-3.5-turbo-0301": 0.75, @@ -64,8 +67,8 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, @@ -74,31 +77,48 @@ var ModelRatio = map[string]float64{ "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8k": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "bge-large-zh": 0.002 * RMB, - "bge-large-en": 0.002 * RMB, - "bge-large-8k": 0.002 * RMB, + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Lite-8K-0308": 0.003 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing - "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro-vision-001": 1, + "gemini-1.0-pro-001": 1, + "gemini-1.5-pro": 1, // https://open.bigmodel.cn/pricing - "glm-4": 0.1 * RMB, - "glm-4v": 0.1 * RMB, - "glm-3-turbo": 0.005 * RMB, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, + "embedding-2": 0.0005 * RMB, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "cogview-3": 0.25 * RMB, + // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens "qwen-max": 1.4286, // ¥0.02 / 1k tokens "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "ali-stable-diffusion-xl": 8, + "ali-stable-diffusion-v1.5": 8, + "wanx-v1": 8, "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens @@ -120,6 +140,8 @@ var ModelRatio = map[string]float64{ "Baichuan2-Turbo-192k": 0.016 * RMB, "Baichuan2-53B": 0.02 * RMB, // https://api.minimax.chat/document/price + "abab6.5-chat": 0.03 * RMB, + "abab6.5s-chat": 0.01 * RMB, "abab6-chat": 0.1 * RMB, "abab5.5-chat": 0.015 * RMB, "abab5.5s-chat": 0.005 * RMB, @@ -130,15 +152,35 @@ var ModelRatio = map[string]float64{ "mistral-medium-latest": 2.7 / 1000 * USD, "mistral-large-latest": 8.0 / 1000 * USD, "mistral-embed": 0.1 / 1000 * USD, - // https://wow.groq.com/ - "llama2-70b-4096": 0.7 / 1000 * USD, - "llama2-7b-2048": 0.1 / 1000 * USD, + // https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed + "llama3-70b-8192": 0.59 / 1000 * USD, "mixtral-8x7b-32768": 0.27 / 1000 * USD, + "llama3-8b-8192": 0.05 / 1000 * USD, "gemma-7b-it": 0.1 / 1000 * USD, + "llama2-70b-4096": 0.64 / 1000 * USD, + "llama2-7b-2048": 0.1 / 1000 * USD, // https://platform.lingyiwanwu.com/docs#-计费单元 - "yi-34b-chat-0205": 2.5 / 1000000 * RMB, - "yi-34b-chat-200k": 12.0 / 1000000 * RMB, - "yi-vl-plus": 6.0 / 1000000 * RMB, + "yi-34b-chat-0205": 2.5 / 1000 * RMB, + "yi-34b-chat-200k": 12.0 / 1000 * RMB, + "yi-vl-plus": 6.0 / 1000 * RMB, + // stepfun todo + "step-1v-32k": 0.024 * RMB, + "step-1-32k": 0.024 * RMB, + "step-1-200k": 0.15 * RMB, + // https://cohere.com/pricing + "command": 0.5, + "command-nightly": 0.5, + "command-light": 0.5, + "command-light-nightly": 0.5, + "command-r": 0.5 / 1000 * USD, + "command-r-plus ": 3.0 / 1000 * USD, + // https://platform.deepseek.com/api-docs/pricing/ + "deepseek-chat": 1.0 / 1000 * RMB, + "deepseek-coder": 1.0 / 1000 * RMB, + // https://www.deepl.com/pro?cta=header-prices + "deepl-zh": 25.0 / 1000 * USD, + "deepl-en": 25.0 / 1000 * USD, + "deepl-ja": 25.0 / 1000 * USD, } var CompletionRatio = map[string]float64{} @@ -194,6 +236,9 @@ func GetModelRatio(name string) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } + if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } ratio, ok := ModelRatio[name] if strings.Index(name, "gpt-4-gizmo") != -1 { return ModelRatio["gpt-4-gizmo"] @@ -240,7 +285,7 @@ func GetCompletionRatio(name string) float64 { return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") { - if strings.HasSuffix(name, "preview") { + if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") { return 3 } return 2 @@ -254,9 +299,25 @@ func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "mistral-") { return 3 } + if strings.HasPrefix(name, "gemini-") { + return 3 + } + if strings.HasPrefix(name, "deepseek-") { + return 2 + } switch name { case "llama2-70b-4096": - return 0.8 / 0.7 + return 0.8 / 0.64 + case "llama3-8b-8192": + return 2 + case "llama3-70b-8192": + return 0.79 / 0.59 + case "command", "command-light", "command-nightly", "command-light-nightly": + return 2 + case "command-r": + return 3 + case "command-r-plus": + return 5 } return 1 } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go deleted file mode 100644 index 6a3245ad..00000000 --- a/relay/channel/ali/adaptor.go +++ /dev/null @@ -1,86 +0,0 @@ -package ali - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -// https://help.aliyun.com/zh/dashscope/developer-reference/api-details - -type Adaptor struct { -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) - if meta.Mode == constant.RelayModeEmbeddings { - fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) - } - return fullRequestURL, nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - if meta.IsStream { - req.Header.Set("Accept", "text/event-stream") - } - req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.IsStream { - req.Header.Set("X-DashScope-SSE", "enable") - } - if c.GetString(common.ConfigKeyPlugin) != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) - } - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - switch relayMode { - case constant.RelayModeEmbeddings: - baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) - return baiduEmbeddingRequest, nil - default: - baiduRequest := ConvertRequest(*request) - return baiduRequest, nil - } -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, resp) - } else { - switch meta.Mode { - case constant.RelayModeEmbeddings: - err, usage = EmbeddingHandler(c, resp) - default: - err, usage = Handler(c, resp) - } - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return "ali" -} diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go deleted file mode 100644 index 76e814d1..00000000 --- a/relay/channel/ali/model.go +++ /dev/null @@ -1,73 +0,0 @@ -package ali - -type Message struct { - Content string `json:"content"` - Role string `json:"role"` -} - -type Input struct { - //Prompt string `json:"prompt"` - Messages []Message `json:"messages"` -} - -type Parameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` -} - -type ChatRequest struct { - Model string `json:"model"` - Input Input `json:"input"` - Parameters Parameters `json:"parameters,omitempty"` -} - -type EmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -type Embedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type EmbeddingResponse struct { - Output struct { - Embeddings []Embedding `json:"embeddings"` - } `json:"output"` - Usage Usage `json:"usage"` - Error -} - -type Error struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Output struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type ChatResponse struct { - Output Output `json:"output"` - Usage Usage `json:"usage"` - Error -} diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go deleted file mode 100644 index 45a4e901..00000000 --- a/relay/channel/baidu/constants.go +++ /dev/null @@ -1,13 +0,0 @@ -package baidu - -var ModelList = []string{ - "ERNIE-Bot-4", - "ERNIE-Bot-8K", - "ERNIE-Bot", - "ERNIE-Speed", - "ERNIE-Bot-turbo", - "Embedding-V1", - "bge-large-zh", - "bge-large-en", - "tao-8k", -} diff --git a/relay/channel/interface.go b/relay/channel/interface.go deleted file mode 100644 index e25db83f..00000000 --- a/relay/channel/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package channel - -import ( - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -type Adaptor interface { - Init(meta *util.RelayMeta) - GetRequestURL(meta *util.RelayMeta) (string, error) - SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error - ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) - DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) - GetModelList() []string - GetChannelName() string -} diff --git a/relay/channel/minimax/constants.go b/relay/channel/minimax/constants.go deleted file mode 100644 index c3da5b2d..00000000 --- a/relay/channel/minimax/constants.go +++ /dev/null @@ -1,7 +0,0 @@ -package minimax - -var ModelList = []string{ - "abab5.5s-chat", - "abab5.5-chat", - "abab6-chat", -} diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go deleted file mode 100644 index a01821c2..00000000 --- a/relay/channel/minimax/main.go +++ /dev/null @@ -1,14 +0,0 @@ -package minimax - -import ( - "fmt" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/util" -) - -func GetRequestURL(meta *util.RelayMeta) (string, error) { - if meta.Mode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil - } - return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) -} diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go deleted file mode 100644 index 06c66101..00000000 --- a/relay/channel/ollama/adaptor.go +++ /dev/null @@ -1,65 +0,0 @@ -package ollama - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -type Adaptor struct { -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - // https://github.com/ollama/ollama/blob/main/docs/api.md - fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) - return fullRequestURL, nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - req.Header.Set("Authorization", "Bearer "+meta.APIKey) - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - switch relayMode { - case constant.RelayModeEmbeddings: - return nil, errors.New("not supported") - default: - return ConvertRequest(*request), nil - } -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, resp) - } else { - err, usage = Handler(c, resp) - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return "ollama" -} diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go deleted file mode 100644 index 32f82b2a..00000000 --- a/relay/channel/ollama/constants.go +++ /dev/null @@ -1,5 +0,0 @@ -package ollama - -var ModelList = []string{ - "qwen:0.5b-chat", -} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go deleted file mode 100644 index 47594030..00000000 --- a/relay/channel/openai/adaptor.go +++ /dev/null @@ -1,92 +0,0 @@ -package openai - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" - "strings" -) - -type Adaptor struct { - ChannelType int -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - a.ChannelType = meta.ChannelType -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - switch meta.ChannelType { - case common.ChannelTypeAzure: - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - requestURL := strings.Split(meta.RequestURLPath, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := meta.ActualModelName - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil - case common.ChannelTypeMinimax: - return minimax.GetRequestURL(meta) - default: - return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil - } -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - if meta.ChannelType == common.ChannelTypeAzure { - req.Header.Set("api-key", meta.APIKey) - return nil - } - req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.ChannelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - return request, nil -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - var responseText string - err, responseText, _ = StreamHandler(c, resp, meta.Mode) - usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) - } else { - err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) - } - return -} - -func (a *Adaptor) GetModelList() []string { - _, modelList := GetCompatibleChannelMeta(a.ChannelType) - return modelList -} - -func (a *Adaptor) GetChannelName() string { - channelName, _ := GetCompatibleChannelMeta(a.ChannelType) - return channelName -} diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go deleted file mode 100644 index e4951a34..00000000 --- a/relay/channel/openai/compatible.go +++ /dev/null @@ -1,46 +0,0 @@ -package openai - -import ( - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/baichuan" - "github.com/songquanpeng/one-api/relay/channel/groq" - "github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" - "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/channel/mistral" - "github.com/songquanpeng/one-api/relay/channel/moonshot" -) - -var CompatibleChannels = []int{ - common.ChannelTypeAzure, - common.ChannelType360, - common.ChannelTypeMoonshot, - common.ChannelTypeBaichuan, - common.ChannelTypeMinimax, - common.ChannelTypeMistral, - common.ChannelTypeGroq, - common.ChannelTypeLingYiWanWu, -} - -func GetCompatibleChannelMeta(channelType int) (string, []string) { - switch channelType { - case common.ChannelTypeAzure: - return "azure", ModelList - case common.ChannelType360: - return "360", ai360.ModelList - case common.ChannelTypeMoonshot: - return "moonshot", moonshot.ModelList - case common.ChannelTypeBaichuan: - return "baichuan", baichuan.ModelList - case common.ChannelTypeMinimax: - return "minimax", minimax.ModelList - case common.ChannelTypeMistral: - return "mistralai", mistral.ModelList - case common.ChannelTypeGroq: - return "groq", groq.ModelList - case common.ChannelTypeLingYiWanWu: - return "lingyiwanwu", lingyiwanwu.ModelList - default: - return "openai", ModelList - } -} diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go deleted file mode 100644 index 9bca8cab..00000000 --- a/relay/channel/openai/helper.go +++ /dev/null @@ -1,11 +0,0 @@ -package openai - -import "github.com/songquanpeng/one-api/relay/model" - -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { - usage := &model.Usage{} - usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage -} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go deleted file mode 100644 index 0ca23d59..00000000 --- a/relay/channel/zhipu/adaptor.go +++ /dev/null @@ -1,101 +0,0 @@ -package zhipu - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "math" - "net/http" - "strings" -) - -type Adaptor struct { - APIVersion string -} - -func (a *Adaptor) Init(meta *util.RelayMeta) { - -} - -func (a *Adaptor) SetVersionByModeName(modelName string) { - if strings.HasPrefix(modelName, "glm-") { - a.APIVersion = "v4" - } else { - a.APIVersion = "v3" - } -} - -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - a.SetVersionByModeName(meta.ActualModelName) - if a.APIVersion == "v4" { - return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil - } - method := "invoke" - if meta.IsStream { - method = "sse-invoke" - } - return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - token := GetToken(meta.APIKey) - req.Header.Set("Authorization", token) - return nil -} - -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) - - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) - a.SetVersionByModeName(request.Model) - if a.APIVersion == "v4" { - return request, nil - } - return ConvertRequest(*request), nil -} - -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) -} - -func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, _, usage = openai.StreamHandler(c, resp, meta.Mode) - } else { - err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) - } - return -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if a.APIVersion == "v4" { - return a.DoResponseV4(c, resp, meta) - } - if meta.IsStream { - err, usage = StreamHandler(c, resp) - } else { - err, usage = Handler(c, resp) - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return "zhipu" -} diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go new file mode 100644 index 00000000..60964565 --- /dev/null +++ b/relay/channeltype/define.go @@ -0,0 +1,46 @@ +package channeltype + +const ( + Unknown = iota + OpenAI + API2D + Azure + CloseAI + OpenAISB + OpenAIMax + OhMyGPT + Custom + Ails + AIProxy + PaLM + API2GPT + AIGC2D + Anthropic + Baidu + Zhipu + Ali + Xunfei + AI360 + OpenRouter + AIProxyLibrary + FastGPT + Tencent + Gemini + Moonshot + Baichuan + Minimax + Mistral + Groq + Ollama + LingYiWanWu + StepFun + AwsClaude + Coze + Cohere + DeepSeek + Cloudflare + DeepL + TogetherAI + + Dummy +) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go new file mode 100644 index 00000000..1bb71402 --- /dev/null +++ b/relay/channeltype/helper.go @@ -0,0 +1,41 @@ +package channeltype + +import "github.com/songquanpeng/one-api/relay/apitype" + +func ToAPIType(channelType int) int { + apiType := apitype.OpenAI + switch channelType { + case Anthropic: + apiType = apitype.Anthropic + case Baidu: + apiType = apitype.Baidu + case PaLM: + apiType = apitype.PaLM + case Zhipu: + apiType = apitype.Zhipu + case Ali: + apiType = apitype.Ali + case Xunfei: + apiType = apitype.Xunfei + case AIProxyLibrary: + apiType = apitype.AIProxyLibrary + case Tencent: + apiType = apitype.Tencent + case Gemini: + apiType = apitype.Gemini + case Ollama: + apiType = apitype.Ollama + case AwsClaude: + apiType = apitype.AwsClaude + case Coze: + apiType = apitype.Coze + case Cohere: + apiType = apitype.Cohere + case Cloudflare: + apiType = apitype.Cloudflare + case DeepL: + apiType = apitype.DeepL + } + + return apiType +} diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go new file mode 100644 index 00000000..f5767f47 --- /dev/null +++ b/relay/channeltype/url.go @@ -0,0 +1,50 @@ +package channeltype + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 + "https://api.moonshot.cn", // 25 + "https://api.baichuan-ai.com", // 26 + "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 + "https://api.stepfun.com", // 32 + "", // 33 + "https://api.coze.com", // 34 + "https://api.cohere.ai", // 35 + "https://api.deepseek.com", // 36 + "https://api.cloudflare.com", // 37 + "https://api-free.deepl.com", // 38 + "https://api.together.xyz", // 39 +} + +func init() { + if len(ChannelBaseURLs) != Dummy { + panic("channel base urls length not match") + } +} diff --git a/relay/util/init.go b/relay/client/init.go similarity index 96% rename from relay/util/init.go rename to relay/client/init.go index 03dad31b..4b59cba7 100644 --- a/relay/util/init.go +++ b/relay/client/init.go @@ -1,4 +1,4 @@ -package util +package client import ( "github.com/songquanpeng/one-api/common/config" diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go deleted file mode 100644 index b249f6a2..00000000 --- a/relay/constant/api_type.go +++ /dev/null @@ -1,48 +0,0 @@ -package constant - -import ( - "github.com/songquanpeng/one-api/common" -) - -const ( - APITypeOpenAI = iota - APITypeAnthropic - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini - APITypeOllama - - APITypeDummy // this one is only for count, do not add any channel after this -) - -func ChannelType2APIType(channelType int) int { - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeAnthropic - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - case common.ChannelTypeOllama: - apiType = APITypeOllama - } - return apiType -} diff --git a/relay/constant/common.go b/relay/constant/common.go index b6606cc6..f31477ca 100644 --- a/relay/constant/common.go +++ b/relay/constant/common.go @@ -1,3 +1,5 @@ package constant var StopFinishReason = "stop" +var StreamObject = "chat.completion.chunk" +var NonStreamObject = "chat.completion" diff --git a/relay/constant/finishreason/define.go b/relay/constant/finishreason/define.go new file mode 100644 index 00000000..1ed9c425 --- /dev/null +++ b/relay/constant/finishreason/define.go @@ -0,0 +1,5 @@ +package finishreason + +const ( + Stop = "stop" +) diff --git a/relay/constant/image.go b/relay/constant/image.go deleted file mode 100644 index 5e04895f..00000000 --- a/relay/constant/image.go +++ /dev/null @@ -1,24 +0,0 @@ -package constant - -var DalleSizeRatios = map[string]map[string]float64{ - "dall-e-2": { - "256x256": 1, - "512x512": 1.125, - "1024x1024": 1.25, - }, - "dall-e-3": { - "1024x1024": 1, - "1024x1792": 2, - "1792x1024": 2, - }, -} - -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. -} - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go deleted file mode 100644 index 5e2fe574..00000000 --- a/relay/constant/relay_mode.go +++ /dev/null @@ -1,42 +0,0 @@ -package constant - -import "strings" - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - -func Path2RelayMode(path string) int { - relayMode := RelayModeUnknown - if strings.HasPrefix(path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - return relayMode -} diff --git a/relay/constant/role/define.go b/relay/constant/role/define.go new file mode 100644 index 00000000..972488c5 --- /dev/null +++ b/relay/constant/role/define.go @@ -0,0 +1,5 @@ +package role + +const ( + Assistant = "assistant" +) diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 155954d2..15e74290 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -10,12 +10,17 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" @@ -23,17 +28,18 @@ import ( func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() + meta := meta.GetByContext(c) audioModel := "whisper-1" - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - tokenName := c.GetString("token_name") + tokenId := c.GetInt(ctxkey.TokenId) + channelType := c.GetInt(ctxkey.Channel) + channelId := c.GetInt(ctxkey.ChannelId) + userId := c.GetInt(ctxkey.Id) + group := c.GetString(ctxkey.Group) + tokenName := c.GetString(ctxkey.TokenName) var ttsRequest openai.TextToSpeechRequest - if relayMode == constant.RelayModeAudioSpeech { + if relayMode == relaymode.AudioSpeech { // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid @@ -47,13 +53,13 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := common.GetModelRatio(audioModel) - groupRatio := common.GetGroupRatio(group) + modelRatio := billingratio.GetModelRatio(audioModel) + groupRatio := billingratio.GetGroupRatio(group) ratio := modelRatio * groupRatio var quota int64 var preConsumedQuota int64 switch relayMode { - case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: @@ -83,9 +89,27 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } + succeed := false + defer func() { + if succeed { + return + } + if preConsumedQuota > 0 { + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + }() // map model name - modelMapping := c.GetString("model_mapping") + modelMapping := c.GetString(ctxkey.ModelMapping) if modelMapping != "" { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) @@ -97,17 +121,22 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - baseURL := common.ChannelBaseURLs[channelType] + baseURL := channeltype.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") + if c.GetString(ctxkey.BaseURL) != "" { + baseURL = c.GetString(ctxkey.BaseURL) } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := util.GetAzureAPIVersion(c) - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) + if channelType == channeltype.Azure { + apiVersion := meta.Config.APIVersion + if relayMode == relaymode.AudioTranscription { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + } else if relayMode == relaymode.AudioSpeech { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) + } } requestBody := &bytes.Buffer{} @@ -123,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -135,7 +164,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := util.HTTPClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } @@ -149,7 +178,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != constant.RelayModeAudioSpeech { + if relayMode != relaymode.AudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -188,23 +217,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { - if preConsumedQuota > 0 { - // we need to roll back the pre-consumed quota - defer func(ctx context.Context) { - go func() { - // negative means add quota back for token & user - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) - } - }() - }(c.Request.Context()) - } - return util.RelayErrorHandler(resp) + return RelayErrorHandler(resp) } + succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { - go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) }(c.Request.Context()) for k, v := range resp.Header { diff --git a/relay/controller/error.go b/relay/controller/error.go new file mode 100644 index 00000000..29d4f125 --- /dev/null +++ b/relay/controller/error.go @@ -0,0 +1,101 @@ +package controller + +import ( + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strconv" +) + +type GeneralErrorResponse struct { + Error model.Error `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + +func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) { + if resp == nil { + return &model.ErrorWithStatusCode{ + StatusCode: 500, + Error: model.Error{ + Message: "resp is nil", + Type: "upstream_error", + Code: "bad_response", + }, + } + } + ErrorWithStatusCode = &model.ErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: model.Error{ + Message: "", + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + if config.DebugEnabled { + logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + ErrorWithStatusCode.Error = errResponse.Error + } else { + ErrorWithStatusCode.Error.Message = errResponse.ToMessage() + } + if ErrorWithStatusCode.Error.Message == "" { + ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 600a8d65..dccff486 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -9,12 +9,16 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller/validator" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "math" "net/http" + "strings" ) func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { @@ -23,21 +27,21 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener if err != nil { return nil, err } - if relayMode == constant.RelayModeModerations && textRequest.Model == "" { + if relayMode == relaymode.Moderations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } - if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { + if relayMode == relaymode.Embeddings && textRequest.Model == "" { textRequest.Model = c.Param("model") } - err = util.ValidateTextRequest(textRequest, relayMode) + err = validator.ValidateTextRequest(textRequest, relayMode) if err != nil { return nil, err } return textRequest, nil } -func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { - imageRequest := &openai.ImageRequest{} +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { return nil, err @@ -54,9 +58,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error return imageRequest, nil } -func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func getImageSizeRatio(model string, size string) float64 { + ratio, ok := billingratio.ImageSizeRatios[model][size] + if !ok { + return 1 + } + return ratio +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { // model validation - _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) if !hasValidSize { return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } @@ -64,27 +84,24 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { + if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation if !isWithinRange(imageRequest.Model, imageRequest.N) { // channel not azure - if meta.ChannelType != common.ChannelTypeAzure { + if meta.ChannelType != channeltype.Azure { return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } } return nil } -func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { if imageRequest == nil { return 0, errors.New("imageRequest is nil") } - imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] - if !hasValidSize { - return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) - } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { if imageRequest.Size == "1024x1024" { imageCostRatio *= 2 @@ -97,25 +114,25 @@ func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) - case constant.RelayModeCompletions: + case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) - case constant.RelayModeModerations: + case relaymode.Moderations: return openai.CountTokenInput(textRequest.Input, textRequest.Model) } return 0 } func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { - preConsumedTokens := config.PreConsumedQuota + preConsumedTokens := config.PreConsumedQuota + int64(promptTokens) if textRequest.MaxTokens != 0 { - preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) + preConsumedTokens += int64(textRequest.MaxTokens) } return int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) @@ -144,13 +161,13 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } var quota int64 - completionRatio := common.GetCompletionRatio(textRequest.Model) + completionRatio := billingratio.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) @@ -177,3 +194,34 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) } + +func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { + if mapping == nil { + return modelName, false + } + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName, true + } + return modelName, false +} + +func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { + if resp == nil { + if meta.ChannelType == channeltype.AwsClaude { + return false + } + return true + } + if resp.StatusCode != http.StatusOK { + return true + } + if meta.ChannelType == channeltype.DeepL { + // skip stream check for deepl + return false + } + if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + return true + } + return false +} diff --git a/relay/controller/image.go b/relay/controller/image.go index 20ea0a4c..6620bef5 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,33 +6,32 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" - "strings" - - "github.com/gin-gonic/gin" ) func isWithinRange(element string, value int) bool { - if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { + if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { return false } - min := constant.DalleGenerationImageAmounts[element][0] - max := constant.DalleGenerationImageAmounts[element][1] - + min := billingratio.ImageGenerationAmounts[element][0] + max := billingratio.ImageGenerationAmounts[element][1] return value >= min && value <= max } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) @@ -42,7 +41,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus // map model name var isModelMapped bool meta.OriginModelName = imageRequest.Model - imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) + imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) meta.ActualModelName = imageRequest.Model // model validation @@ -56,17 +55,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) } - requestURL := c.Request.URL.String() - fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) - if meta.ChannelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := util.GetAzureAPIVersion(c) - // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) - } - var requestBody io.Reader - if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body + if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) @@ -76,8 +66,31 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageRequest.Model) - groupRatio := common.GetGroupRatio(meta.Group) + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(meta) + + switch meta.ChannelType { + case channeltype.Ali: + fallthrough + case channeltype.Baidu: + fallthrough + case channeltype.Zhipu: + finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } + + modelRatio := billingratio.GetModelRatio(imageRequest.Model) + groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) @@ -87,40 +100,18 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - token := c.Request.Header.Get("Authorization") - if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := util.HTTPClient.Do(req) + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - var imageResponse openai.ImageResponse - defer func(ctx context.Context) { - if resp.StatusCode != http.StatusOK { + if resp != nil && resp.StatusCode != http.StatusOK { return } + err := model.PostConsumeTokenQuota(meta.TokenId, quota) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) @@ -130,43 +121,21 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus logger.SysError("error update user quota cache: " + err.Error()) } if quota != 0 { - tokenName := c.GetString("token_name") + tokenName := c.GetString(ctxkey.TokenName) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) - channelId := c.GetInt("channel_id") + channelId := c.GetInt(ctxkey.ChannelId) model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &imageResponse) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } return nil } diff --git a/relay/controller/text.go b/relay/controller/text.go index ba008713..6ed19b1d 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -5,21 +5,22 @@ import ( "encoding/json" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" - "strings" ) func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { @@ -31,11 +32,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // map model name var isModelMapped bool meta.OriginModelName = textRequest.Model - textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) + textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(meta.Group) + modelRatio := billingratio.GetModelRatio(textRequest.Model) + groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota promptTokens := getPromptTokens(textRequest, meta.Mode) @@ -46,16 +47,17 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return bizErr } - adaptor := helper.GetAdaptor(meta.APIType) + adaptor := relay.GetAdaptor(meta.APIType) if adaptor == nil { return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(meta) // get request body var requestBody io.Reader - if meta.APIType == constant.APITypeOpenAI { + if meta.APIType == apitype.OpenAI { // no need to convert request for openai - shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan + shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan if shouldResetRequestBody { jsonStr, err := json.Marshal(textRequest) if err != nil { @@ -84,18 +86,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") - if errorHappened { - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) - return util.RelayErrorHandler(resp) + if isErrorHappened(meta, resp) { + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return RelayErrorHandler(resp) } - meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") // do response usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { logger.Errorf(ctx, "respErr is not nil: %+v", respErr) - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } // post-consume quota diff --git a/relay/util/validation.go b/relay/controller/validator/validation.go similarity index 76% rename from relay/util/validation.go rename to relay/controller/validator/validation.go index ef8d840c..8ff520b8 100644 --- a/relay/util/validation.go +++ b/relay/controller/validator/validation.go @@ -1,9 +1,9 @@ -package util +package validator import ( "errors" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "math" ) @@ -15,20 +15,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) return errors.New("model is required") } switch relayMode { - case constant.RelayModeCompletions: + case relaymode.Completions: if textRequest.Prompt == "" { return errors.New("field prompt is required") } - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: if textRequest.Messages == nil || len(textRequest.Messages) == 0 { return errors.New("field messages is required") } - case constant.RelayModeEmbeddings: - case constant.RelayModeModerations: + case relaymode.Embeddings: + case relaymode.Moderations: if textRequest.Input == "" { return errors.New("field input is required") } - case constant.RelayModeEdits: + case relaymode.Edits: if textRequest.Instruction == "" { return errors.New("field instruction is required") } diff --git a/relay/helper/main.go b/relay/helper/main.go deleted file mode 100644 index e7342329..00000000 --- a/relay/helper/main.go +++ /dev/null @@ -1,45 +0,0 @@ -package helper - -import ( - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/aiproxy" - "github.com/songquanpeng/one-api/relay/channel/ali" - "github.com/songquanpeng/one-api/relay/channel/anthropic" - "github.com/songquanpeng/one-api/relay/channel/baidu" - "github.com/songquanpeng/one-api/relay/channel/gemini" - "github.com/songquanpeng/one-api/relay/channel/ollama" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/channel/palm" - "github.com/songquanpeng/one-api/relay/channel/tencent" - "github.com/songquanpeng/one-api/relay/channel/xunfei" - "github.com/songquanpeng/one-api/relay/channel/zhipu" - "github.com/songquanpeng/one-api/relay/constant" -) - -func GetAdaptor(apiType int) channel.Adaptor { - switch apiType { - case constant.APITypeAIProxyLibrary: - return &aiproxy.Adaptor{} - case constant.APITypeAli: - return &ali.Adaptor{} - case constant.APITypeAnthropic: - return &anthropic.Adaptor{} - case constant.APITypeBaidu: - return &baidu.Adaptor{} - case constant.APITypeGemini: - return &gemini.Adaptor{} - case constant.APITypeOpenAI: - return &openai.Adaptor{} - case constant.APITypePaLM: - return &palm.Adaptor{} - case constant.APITypeTencent: - return &tencent.Adaptor{} - case constant.APITypeXunfei: - return &xunfei.Adaptor{} - case constant.APITypeZhipu: - return &zhipu.Adaptor{} - case constant.APITypeOllama: - return &ollama.Adaptor{} - } - return nil -} diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go new file mode 100644 index 00000000..9714ebb5 --- /dev/null +++ b/relay/meta/relay_meta.go @@ -0,0 +1,56 @@ +package meta + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/relaymode" + "strings" +) + +type Meta struct { + Mode int + ChannelType int + ChannelId int + TokenId int + TokenName string + UserId int + Group string + ModelMapping map[string]string + BaseURL string + APIKey string + APIType int + Config model.ChannelConfig + IsStream bool + OriginModelName string + ActualModelName string + RequestURLPath string + PromptTokens int // only for DoResponse +} + +func GetByContext(c *gin.Context) *Meta { + meta := Meta{ + Mode: relaymode.GetByPath(c.Request.URL.Path), + ChannelType: c.GetInt(ctxkey.Channel), + ChannelId: c.GetInt(ctxkey.ChannelId), + TokenId: c.GetInt(ctxkey.TokenId), + TokenName: c.GetString(ctxkey.TokenName), + UserId: c.GetInt(ctxkey.Id), + Group: c.GetString(ctxkey.Group), + ModelMapping: c.GetStringMapString(ctxkey.ModelMapping), + OriginModelName: c.GetString(ctxkey.RequestModel), + BaseURL: c.GetString(ctxkey.BaseURL), + APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + RequestURLPath: c.Request.URL.String(), + } + cfg, ok := c.Get(ctxkey.Config) + if ok { + meta.Config = cfg.(model.ChannelConfig) + } + if meta.BaseURL == "" { + meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] + } + meta.APIType = channeltype.ToAPIType(meta.ChannelType) + return &meta +} diff --git a/relay/model/general.go b/relay/model/general.go index fbcc04e8..30772894 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -5,25 +5,29 @@ type ResponseFormat struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` User string `json:"user,omitempty"` + Prompt any `json:"prompt,omitempty"` + Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/model/image.go b/relay/model/image.go new file mode 100644 index 00000000..bab84256 --- /dev/null +++ b/relay/model/image.go @@ -0,0 +1,12 @@ +package model + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +} diff --git a/relay/model/message.go b/relay/model/message.go index c6c8a271..32a1055b 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,9 +1,10 @@ package model type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go new file mode 100644 index 00000000..253dca35 --- /dev/null +++ b/relay/model/tool.go @@ -0,0 +1,14 @@ +package model + +type Tool struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Function Function `json:"function"` +} + +type Function struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` // request + Arguments any `json:"arguments,omitempty"` // response +} diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go new file mode 100644 index 00000000..96d09438 --- /dev/null +++ b/relay/relaymode/define.go @@ -0,0 +1,14 @@ +package relaymode + +const ( + Unknown = iota + ChatCompletions + Completions + Embeddings + Moderations + ImagesGenerations + Edits + AudioSpeech + AudioTranscription + AudioTranslation +) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go new file mode 100644 index 00000000..926dd42e --- /dev/null +++ b/relay/relaymode/helper.go @@ -0,0 +1,29 @@ +package relaymode + +import "strings" + +func GetByPath(path string) int { + relayMode := Unknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = ChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = Completions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = Embeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = Embeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = Moderations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = ImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = Edits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = AudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = AudioTranscription + } else if strings.HasPrefix(path, "/v1/audio/translations") { + relayMode = AudioTranslation + } + return relayMode +} diff --git a/relay/util/billing.go b/relay/util/billing.go deleted file mode 100644 index 495d011e..00000000 --- a/relay/util/billing.go +++ /dev/null @@ -1,19 +0,0 @@ -package util - -import ( - "context" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" -) - -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(ctx) - } -} diff --git a/relay/util/common.go b/relay/util/common.go deleted file mode 100644 index 535ef680..00000000 --- a/relay/util/common.go +++ /dev/null @@ -1,187 +0,0 @@ -package util - -import ( - "context" - "encoding/json" - "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" - relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" - - "github.com/gin-gonic/gin" -) - -func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { - if !config.AutomaticDisableChannelEnabled { - return false - } - if err == nil { - return false - } - if statusCode == http.StatusUnauthorized { - return true - } - switch err.Type { - case "insufficient_quota": - return true - // https://docs.anthropic.com/claude/reference/errors - case "authentication_error": - return true - case "permission_error": - return true - case "forbidden": - return true - } - if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { - return true - } - if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic - return true - } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { - return true - } - return false -} - -func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool { - if !config.AutomaticEnableChannelEnabled { - return false - } - if err != nil { - return false - } - if openAIErr != nil { - return false - } - return true -} - -type GeneralErrorResponse struct { - Error relaymodel.Error `json:"error"` - Message string `json:"message"` - Msg string `json:"msg"` - Err string `json:"err"` - ErrorMsg string `json:"error_msg"` - Header struct { - Message string `json:"message"` - } `json:"header"` - Response struct { - Error struct { - Message string `json:"message"` - } `json:"error"` - } `json:"response"` -} - -func (e GeneralErrorResponse) ToMessage() string { - if e.Error.Message != "" { - return e.Error.Message - } - if e.Message != "" { - return e.Message - } - if e.Msg != "" { - return e.Msg - } - if e.Err != "" { - return e.Err - } - if e.ErrorMsg != "" { - return e.ErrorMsg - } - if e.Header.Message != "" { - return e.Header.Message - } - if e.Response.Error.Message != "" { - return e.Response.Error.Message - } - return "" -} - -func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) { - ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: relaymodel.Error{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - if config.DebugEnabled { - logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) - } - err = resp.Body.Close() - if err != nil { - return - } - var errResponse GeneralErrorResponse - err = json.Unmarshal(responseBody, &errResponse) - if err != nil { - return - } - if errResponse.Error.Message != "" { - // OpenAI format error, so we override the default one - ErrorWithStatusCode.Error = errResponse.Error - } else { - ErrorWithStatusCode.Error.Message = errResponse.ToMessage() - } - if ErrorWithStatusCode.Error.Message == "" { - ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) - } - return -} - -func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - switch channelType { - case common.ChannelTypeOpenAI: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) - } - } - return fullRequestURL -} - -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - // quotaDelta is remaining quota to be consumed - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(ctx, userId) - if err != nil { - logger.SysError("error update user quota cache: " + err.Error()) - } - // totalQuota is total quota consumed - if totalQuota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) - model.UpdateChannelUsedQuota(channelId, totalQuota) - } - if totalQuota <= 0 { - logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) - } -} - -func GetAzureAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString(common.ConfigKeyAPIVersion) - } - return apiVersion -} diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go deleted file mode 100644 index 39e062a1..00000000 --- a/relay/util/model_mapping.go +++ /dev/null @@ -1,12 +0,0 @@ -package util - -func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) { - if mapping == nil { - return modelName, false - } - mappedModelName := mapping[modelName] - if mappedModelName != "" { - return mappedModelName, true - } - return modelName, false -} diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go deleted file mode 100644 index 31b9d2b4..00000000 --- a/relay/util/relay_meta.go +++ /dev/null @@ -1,55 +0,0 @@ -package util - -import ( - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/constant" - "strings" -) - -type RelayMeta struct { - Mode int - ChannelType int - ChannelId int - TokenId int - TokenName string - UserId int - Group string - ModelMapping map[string]string - BaseURL string - APIVersion string - APIKey string - APIType int - Config map[string]string - IsStream bool - OriginModelName string - ActualModelName string - RequestURLPath string - PromptTokens int // only for DoResponse -} - -func GetRelayMeta(c *gin.Context) *RelayMeta { - meta := RelayMeta{ - Mode: constant.Path2RelayMode(c.Request.URL.Path), - ChannelType: c.GetInt("channel"), - ChannelId: c.GetInt("channel_id"), - TokenId: c.GetInt("token_id"), - TokenName: c.GetString("token_name"), - UserId: c.GetInt("id"), - Group: c.GetString("group"), - ModelMapping: c.GetStringMapString("model_mapping"), - BaseURL: c.GetString("base_url"), - APIVersion: c.GetString(common.ConfigKeyAPIVersion), - APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Config: nil, - RequestURLPath: c.Request.URL.String(), - } - if meta.ChannelType == common.ChannelTypeAzure { - meta.APIVersion = GetAzureAPIVersion(c) - } - if meta.BaseURL == "" { - meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] - } - meta.APIType = constant.ChannelType2APIType(meta.ChannelType) - return &meta -} diff --git a/router/api-router.go b/router/api.go similarity index 92% rename from router/api-router.go rename to router/api.go index 47443375..fa6ff2e2 100644 --- a/router/api-router.go +++ b/router/api.go @@ -2,6 +2,7 @@ package router import ( "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/controller/auth" "github.com/songquanpeng/one-api/middleware" "github.com/gin-contrib/gzip" @@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) - apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) - apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) - apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) - apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) + apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) + apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) userRoute := apiRouter.Group("/user") { @@ -43,6 +46,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) } adminRoute := userRoute.Group("/") @@ -68,7 +72,7 @@ func SetApiRouter(router *gin.Engine) { { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) - channelRoute.GET("/models", controller.ListModels) + channelRoute.GET("/models", controller.ListAllModels) channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestChannels) channelRoute.GET("/test/:id", controller.TestChannel) diff --git a/router/relay-router.go b/router/relay.go similarity index 100% rename from router/relay-router.go rename to router/relay.go diff --git a/router/web-router.go b/router/web.go similarity index 100% rename from router/web-router.go rename to router/web.go diff --git a/web/README.md b/web/README.md index 29f4713e..829271e2 100644 --- a/web/README.md +++ b/web/README.md @@ -2,6 +2,9 @@ > 每个文件夹代表一个主题,欢迎提交你的主题 +> [!WARNING] +> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR + ## 提交新的主题 > 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 diff --git a/web/air/src/components/ChannelsTable.js b/web/air/src/components/ChannelsTable.js index dee21a01..c384d50c 100644 --- a/web/air/src/components/ChannelsTable.js +++ b/web/air/src/components/ChannelsTable.js @@ -437,7 +437,7 @@ const ChannelsTable = () => { if (success) { record.response_time = time * 1000; record.test_time = Date.now() / 1000; - showInfo(`通道 ${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`渠道 ${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); } @@ -447,7 +447,7 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/test?scope=${scope}`); const { success, message } = res.data; if (success) { - showInfo('已成功开始测试通道,请刷新页面查看结果。'); + showInfo('已成功开始测试渠道,请刷新页面查看结果。'); } else { showError(message); } @@ -470,7 +470,7 @@ const ChannelsTable = () => { if (success) { record.balance = balance; record.balance_updated_time = Date.now() / 1000; - showInfo(`通道 ${record.name} 余额更新成功!`); + showInfo(`渠道 ${record.name} 余额更新成功!`); } else { showError(message); } @@ -481,7 +481,7 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/update_balance`); const { success, message } = res.data; if (success) { - showInfo('已更新完毕所有已启用通道余额!'); + showInfo('已更新完毕所有已启用渠道余额!'); } else { showError(message); } @@ -490,7 +490,7 @@ const ChannelsTable = () => { const batchDeleteChannels = async () => { if (selectedChannels.length === 0) { - showError('请先选择要删除的通道!'); + showError('请先选择要删除的渠道!'); return; } setLoading(true); @@ -501,7 +501,7 @@ const ChannelsTable = () => { const res = await API.post(`/api/channel/batch`, { ids: ids }); const { success, message, data } = res.data; if (success) { - showSuccess(`已删除 ${data} 个通道!`); + showSuccess(`已删除 ${data} 个渠道!`); await refresh(); } else { showError(message); @@ -513,7 +513,7 @@ const ChannelsTable = () => { const res = await API.post(`/api/channel/fix`); const { success, message, data } = res.data; if (success) { - showSuccess(`已修复 ${data} 个通道!`); + showSuccess(`已修复 ${data} 个渠道!`); await refresh(); } else { showError(message); @@ -633,7 +633,7 @@ const ChannelsTable = () => { onConfirm={() => { testChannels("all") }} position={isMobile() ? 'top' : 'left'} > - + { okType={'secondary'} onConfirm={updateAllChannelsBalance} > - + */} - + @@ -673,7 +673,7 @@ const ChannelsTable = () => { setEnableBatchDelete(v); }}> { position={'top'} > + style={{ marginRight: 8 }}>删除所选渠道 { value={inputs.ChannelDisableThreshold} type='number' min='0' - placeholder='单位秒,当运行通道全部测试时,超过此时间将自动禁用通道' + placeholder='单位秒,当运行渠道全部测试时,超过此时间将自动禁用渠道' /> { diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js index 9c4deb6e..0853ddfb 100644 --- a/web/air/src/components/TokensTable.js +++ b/web/air/src/components/TokensTable.js @@ -247,6 +247,8 @@ const TokensTable = () => { const [editingToken, setEditingToken] = useState({ id: undefined }); + const [orderBy, setOrderBy] = useState(''); + const [dropdownVisible, setDropdownVisible] = useState(false); const closeEdit = () => { setShowEdit(false); @@ -269,7 +271,7 @@ const TokensTable = () => { let pageData = tokens.slice((activePage - 1) * pageSize, activePage * pageSize); const loadTokens = async (startIdx) => { setLoading(true); - const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}`); + const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}&order=${orderBy}`); const { success, message, data } = res.data; if (success) { if (startIdx === 0) { @@ -289,7 +291,7 @@ const TokensTable = () => { (async () => { if (activePage === Math.ceil(tokens.length / pageSize) + 1) { // In this case we have to load more data and then append them. - await loadTokens(activePage - 1); + await loadTokens(activePage - 1, orderBy); } setActivePage(activePage); })(); @@ -317,7 +319,7 @@ const TokensTable = () => { if (nextLink) { nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } else { - nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + nextUrl = `https://app.nextchat.dev/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } let url; @@ -392,12 +394,12 @@ const TokensTable = () => { }; useEffect(() => { - loadTokens(0) + loadTokens(0, orderBy) .then() .catch((reason) => { showError(reason); }); - }, [pageSize]); + }, [pageSize, orderBy]); const removeRecord = key => { let newDataSource = [...tokens]; @@ -452,6 +454,7 @@ const TokensTable = () => { // if keyword is blank, load files instead. await loadTokens(0); setActivePage(1); + setOrderBy(''); return; } setSearching(true); @@ -520,6 +523,23 @@ const TokensTable = () => { } }; + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + setDropdownVisible(false); + }; + + const renderSelectedOption = (orderBy) => { + switch (orderBy) { + case 'remain_quota': + return '按剩余额度排序'; + case 'used_quota': + return '按已用额度排序'; + default: + return '默认排序'; + } + }; + return ( <> @@ -579,6 +599,21 @@ const TokensTable = () => { await copyText(keys); } }>复制所选令牌到剪贴板 + setDropdownVisible(visible)} + render={ + + handleOrderByChange('', { value: '' })}>默认排序 + handleOrderByChange('', { value: 'remain_quota' })}>按剩余额度排序 + handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序 + + } + > + + ); }; diff --git a/web/air/src/components/UsersTable.js b/web/air/src/components/UsersTable.js index f3de46d6..4fc16ba5 100644 --- a/web/air/src/components/UsersTable.js +++ b/web/air/src/components/UsersTable.js @@ -1,6 +1,6 @@ import React, { useEffect, useState } from 'react'; import { API, showError, showSuccess } from '../helpers'; -import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip } from '@douyinfe/semi-ui'; +import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip, Dropdown } from '@douyinfe/semi-ui'; import { ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber, renderQuota } from '../helpers/render'; import AddUser from '../pages/User/AddUser'; @@ -139,6 +139,8 @@ const UsersTable = () => { const [editingUser, setEditingUser] = useState({ id: undefined }); + const [orderBy, setOrderBy] = useState(''); + const [dropdownVisible, setDropdownVisible] = useState(false); const setCount = (data) => { if (data.length >= (activePage) * ITEMS_PER_PAGE) { @@ -162,7 +164,7 @@ const UsersTable = () => { }; const loadUsers = async (startIdx) => { - const res = await API.get(`/api/user/?p=${startIdx}`); + const res = await API.get(`/api/user/?p=${startIdx}&order=${orderBy}`); const { success, message, data } = res.data; if (success) { if (startIdx === 0) { @@ -184,19 +186,19 @@ const UsersTable = () => { (async () => { if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { // In this case we have to load more data and then append them. - await loadUsers(activePage - 1); + await loadUsers(activePage - 1, orderBy); } setActivePage(activePage); })(); }; useEffect(() => { - loadUsers(0) + loadUsers(0, orderBy) .then() .catch((reason) => { showError(reason); }); - }, []); + }, [orderBy]); const manageUser = async (username, action, record) => { const res = await API.post('/api/user/manage', { @@ -239,6 +241,7 @@ const UsersTable = () => { // if keyword is blank, load files instead. await loadUsers(0); setActivePage(1); + setOrderBy(''); return; } setSearching(true); @@ -301,6 +304,25 @@ const UsersTable = () => { } }; + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + setDropdownVisible(false); + }; + + const renderSelectedOption = (orderBy) => { + switch (orderBy) { + case 'quota': + return '按剩余额度排序'; + case 'used_quota': + return '按已用额度排序'; + case 'request_count': + return '按请求次数排序'; + default: + return '默认排序'; + } + }; + return ( <> @@ -331,6 +353,22 @@ const UsersTable = () => { setShowAddUser(true); } }>添加用户 + setDropdownVisible(visible)} + render={ + + handleOrderByChange('', { value: '' })}>默认排序 + handleOrderByChange('', { value: 'quota' })}>按剩余额度排序 + handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序 + handleOrderByChange('', { value: 'request_count' })}>按请求次数排序 + + } + > + + ); }; diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js index 2b84011b..efb2cee8 100644 --- a/web/air/src/pages/Channel/EditChannel.js +++ b/web/air/src/pages/Channel/EditChannel.js @@ -230,7 +230,7 @@ const EditChannel = (props) => { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); } if (localInputs.type === 3 && localInputs.other === '') { - localInputs.other = '2023-06-01-preview'; + localInputs.other = '2024-03-01-preview'; } if (localInputs.type === 18 && localInputs.other === '') { localInputs.other = 'v2.1'; @@ -348,7 +348,7 @@ const EditChannel = (props) => { { handleInputChange('other', value) }} diff --git a/web/berry/.prettierrc b/web/berry/.prettierrc new file mode 100644 index 00000000..d5fba07c --- /dev/null +++ b/web/berry/.prettierrc @@ -0,0 +1,8 @@ +{ + "bracketSpacing": true, + "printWidth": 140, + "singleQuote": true, + "trailingComma": "none", + "tabWidth": 2, + "useTabs": false +} diff --git a/web/berry/README.md b/web/berry/README.md index 170feedc..84b2bc2c 100644 --- a/web/berry/README.md +++ b/web/berry/README.md @@ -49,7 +49,7 @@ const typeConfig = { base_url: "请填写AZURE_OPENAI_ENDPOINT", // 注意:通过判断 `other` 是否有值来判断是否需要显示 `other` 输入框, 默认是没有值的 - other: "请输入默认API版本,例如:2023-06-01-preview", + other: "请输入默认API版本,例如:2024-03-01-preview", }, modelGroup: "openai", // 模型组名称,这个值是给 填入渠道支持模型 按钮使用的。 填入渠道支持模型 按钮会根据这个值来获取模型组,如果填写默认是 openai }, diff --git a/web/berry/public/favicon.ico b/web/berry/public/favicon.ico index fbcfb14a..c2c8de0c 100644 Binary files a/web/berry/public/favicon.ico and b/web/berry/public/favicon.ico differ diff --git a/web/berry/public/index.html b/web/berry/public/index.html index 6f232250..abd079e1 100644 --- a/web/berry/public/index.html +++ b/web/berry/public/index.html @@ -11,11 +11,6 @@ name="description" content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用" /> - - diff --git a/web/berry/src/App.js b/web/berry/src/App.js index fc54c632..d6422a0f 100644 --- a/web/berry/src/App.js +++ b/web/berry/src/App.js @@ -1,8 +1,9 @@ -import { useSelector } from 'react-redux'; +import { useEffect } from 'react'; +import { useSelector, useDispatch } from 'react-redux'; import { ThemeProvider } from '@mui/material/styles'; import { CssBaseline, StyledEngineProvider } from '@mui/material'; - +import { SET_THEME } from 'store/actions'; // routing import Routes from 'routes'; @@ -20,8 +21,16 @@ import { SnackbarProvider } from 'notistack'; // ==============================|| APP ||============================== // const App = () => { + const dispatch = useDispatch(); const customization = useSelector((state) => state.customization); + useEffect(() => { + const storedTheme = localStorage.getItem('theme'); + if (storedTheme) { + dispatch({ type: SET_THEME, theme: storedTheme }); + } + }, [dispatch]); + return ( diff --git a/web/berry/src/assets/fonts/roboto-500.woff2 b/web/berry/src/assets/fonts/roboto-500.woff2 new file mode 100644 index 00000000..2360b721 Binary files /dev/null and b/web/berry/src/assets/fonts/roboto-500.woff2 differ diff --git a/web/berry/src/assets/fonts/roboto-700.woff2 b/web/berry/src/assets/fonts/roboto-700.woff2 new file mode 100644 index 00000000..4aeda71b Binary files /dev/null and b/web/berry/src/assets/fonts/roboto-700.woff2 differ diff --git a/web/berry/src/assets/fonts/roboto-regular.woff2 b/web/berry/src/assets/fonts/roboto-regular.woff2 new file mode 100644 index 00000000..b65a361a Binary files /dev/null and b/web/berry/src/assets/fonts/roboto-regular.woff2 differ diff --git a/web/berry/src/assets/images/icons/lark.svg b/web/berry/src/assets/images/icons/lark.svg new file mode 100644 index 00000000..239e1bef --- /dev/null +++ b/web/berry/src/assets/images/icons/lark.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/berry/src/assets/images/logo-white.svg b/web/berry/src/assets/images/logo-white.svg new file mode 100644 index 00000000..d6289b9a --- /dev/null +++ b/web/berry/src/assets/images/logo-white.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/berry/src/assets/scss/_themes-vars.module.scss b/web/berry/src/assets/scss/_themes-vars.module.scss index a470b033..661bb6c6 100644 --- a/web/berry/src/assets/scss/_themes-vars.module.scss +++ b/web/berry/src/assets/scss/_themes-vars.module.scss @@ -46,11 +46,16 @@ $grey600: #4b5565; $grey700: #364152; $grey900: #121926; +$tableBackground: #f4f6f8; +$tableBorderBottom: #f1f3f4; + // ==============================|| DARK THEME VARIANTS ||============================== // // paper & background $darkBackground: #1a223f; // level 3 $darkPaper: #111936; // level 4 +$darkDivider: rgba(227, 232, 239, 0.2); +$darkSelectedBack : rgba(124, 77, 255, 0.15); // dark 800 & 900 $darkLevel1: #29314f; // level 1 @@ -154,4 +159,9 @@ $darkTextSecondary: #8492c4; darkSecondaryDark: $darkSecondaryDark; darkSecondary200: $darkSecondary200; darkSecondary800: $darkSecondary800; + + darkDivider: $darkDivider; + darkSelectedBack: $darkSelectedBack; + tableBackground: $tableBackground; + tableBorderBottom: $tableBorderBottom; } diff --git a/web/berry/src/assets/scss/fonts.scss b/web/berry/src/assets/scss/fonts.scss new file mode 100644 index 00000000..c792aab2 --- /dev/null +++ b/web/berry/src/assets/scss/fonts.scss @@ -0,0 +1,32 @@ + +/* roboto-regular */ +@font-face { + font-family: 'Roboto'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: local('Roboto'), url('../fonts/roboto-regular.woff2') format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + } + + /* roboto-500 */ +@font-face { + font-family: 'Roboto'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: local('Roboto'), url('../fonts/roboto-500.woff2') format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} + + +/* roboto-700 */ +@font-face { + font-family: 'Roboto'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: local('Roboto'), url('../fonts/roboto-700.woff2') format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} + \ No newline at end of file diff --git a/web/berry/src/assets/scss/style.scss b/web/berry/src/assets/scss/style.scss index 17d566e6..5d2d8975 100644 --- a/web/berry/src/assets/scss/style.scss +++ b/web/berry/src/assets/scss/style.scss @@ -1,3 +1,4 @@ +@import 'fonts.scss'; // color variants @import 'themes-vars.module.scss'; diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 2c506881..e6b0aed5 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -3,186 +3,228 @@ export const CHANNEL_OPTIONS = { key: 1, text: 'OpenAI', value: 1, - color: 'primary' + color: 'success' }, 14: { key: 14, text: 'Anthropic Claude', value: 14, - color: 'info' + color: 'primary' }, + // 33: { + // key: 33, + // text: 'AWS Claude', + // value: 33, + // color: 'primary' + // }, 3: { key: 3, text: 'Azure OpenAI', value: 3, - color: 'secondary' + color: 'success' }, 11: { key: 11, text: 'Google PaLM2', value: 11, - color: 'orange' + color: 'warning' }, 24: { key: 24, text: 'Google Gemini', value: 24, - color: 'orange' + color: 'warning' }, 28: { key: 28, text: 'Mistral AI', value: 28, - color: 'orange' + color: 'warning' }, 15: { key: 15, text: '百度文心千帆', value: 15, - color: 'default' + color: 'primary' }, 17: { key: 17, text: '阿里通义千问', value: 17, - color: 'default' + color: 'primary' }, 18: { key: 18, text: '讯飞星火认知', value: 18, - color: 'default' + color: 'primary' }, 16: { key: 16, text: '智谱 ChatGLM', value: 16, - color: 'default' + color: 'primary' }, 19: { key: 19, text: '360 智脑', value: 19, - color: 'default' + color: 'primary' }, 25: { key: 25, text: 'Moonshot AI', value: 25, - color: 'default' + color: 'primary' }, 23: { key: 23, text: '腾讯混元', value: 23, - color: 'default' + color: 'primary' }, 26: { key: 26, text: '百川大模型', value: 26, - color: 'default' + color: 'primary' }, 27: { key: 27, text: 'MiniMax', value: 27, - color: 'default' + color: 'primary' }, 29: { key: 29, text: 'Groq', value: 29, - color: 'default' + color: 'primary' }, 30: { key: 30, text: 'Ollama', value: 30, - color: 'default' + color: 'primary' }, 31: { key: 31, text: '零一万物', value: 31, - color: 'default' + color: 'primary' + }, + 32: { + key: 32, + text: '阶跃星辰', + value: 32, + color: 'primary' + }, + // 34: { + // key: 34, + // text: 'Coze', + // value: 34, + // color: 'primary' + // }, + 35: { + key: 35, + text: 'Cohere', + value: 35, + color: 'primary' + }, + 36: { + key: 36, + text: 'DeepSeek', + value: 36, + color: 'primary' + }, + 38: { + key: 38, + text: 'DeepL', + value: 38, + color: 'primary' + }, + 39: { + key: 39, + text: 'together.ai', + value: 39, + color: 'primary' }, 8: { key: 8, text: '自定义渠道', value: 8, - color: 'primary' + color: 'error' }, 22: { key: 22, text: '知识库:FastGPT', value: 22, - color: 'default' + color: 'success' }, 21: { key: 21, text: '知识库:AI Proxy', value: 21, - color: 'purple' + color: 'success' }, 20: { key: 20, text: '代理:OpenRouter', value: 20, - color: 'primary' + color: 'success' }, 2: { key: 2, text: '代理:API2D', value: 2, - color: 'primary' + color: 'success' }, 5: { key: 5, text: '代理:OpenAI-SB', value: 5, - color: 'primary' + color: 'success' }, 7: { key: 7, text: '代理:OhMyGPT', value: 7, - color: 'primary' + color: 'success' }, 10: { key: 10, text: '代理:AI Proxy', value: 10, - color: 'primary' + color: 'success' }, 4: { key: 4, text: '代理:CloseAI', value: 4, - color: 'primary' + color: 'success' }, 6: { key: 6, text: '代理:OpenAI Max', value: 6, - color: 'primary' + color: 'success' }, 9: { key: 9, text: '代理:AI.LS', value: 9, - color: 'primary' + color: 'success' }, 12: { key: 12, text: '代理:API2GPT', value: 12, - color: 'primary' + color: 'success' }, 13: { key: 13, text: '代理:AIGC2D', value: 13, - color: 'primary' + color: 'success' } }; diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js index a05c6652..19523da1 100644 --- a/web/berry/src/constants/SnackbarConstants.js +++ b/web/berry/src/constants/SnackbarConstants.js @@ -18,7 +18,7 @@ export const snackbarConstants = { }, NOTICE: { variant: 'info', - autoHideDuration: 20000 + autoHideDuration: 7000 } }, Mobile: { diff --git a/web/berry/src/hooks/useLogin.js b/web/berry/src/hooks/useLogin.js index 53626577..39d8b407 100644 --- a/web/berry/src/hooks/useLogin.js +++ b/web/berry/src/hooks/useLogin.js @@ -48,6 +48,28 @@ const useLogin = () => { } }; + const larkLogin = async (code, state) => { + try { + const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/panel'); + } else { + dispatch({ type: LOGIN, payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/panel'); + } + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + }; + const wechatLogin = async (code) => { try { const res = await API.get(`/api/oauth/wechat?code=${code}`); @@ -72,7 +94,7 @@ const useLogin = () => { navigate('/'); }; - return { login, logout, githubLogin, wechatLogin }; + return { login, logout, githubLogin, wechatLogin, larkLogin }; }; export default useLogin; diff --git a/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js b/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js index 3e351254..e1392dc0 100644 --- a/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js +++ b/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js @@ -71,8 +71,8 @@ const ProfileSection = () => { alignItems: 'center', borderRadius: '27px', transition: 'all .2s ease-in-out', - borderColor: theme.palette.primary.light, - backgroundColor: theme.palette.primary.light, + borderColor: theme.typography.menuChip.background, + backgroundColor: theme.typography.menuChip.background, '&[aria-controls="menu-list-grow"], &:hover': { borderColor: theme.palette.primary.main, background: `${theme.palette.primary.main}!important`, diff --git a/web/berry/src/layout/MainLayout/Header/index.js b/web/berry/src/layout/MainLayout/Header/index.js index 51d40c75..8fd9c950 100644 --- a/web/berry/src/layout/MainLayout/Header/index.js +++ b/web/berry/src/layout/MainLayout/Header/index.js @@ -7,6 +7,7 @@ import { Avatar, Box, ButtonBase } from '@mui/material'; // project imports import LogoSection from '../LogoSection'; import ProfileSection from './ProfileSection'; +import ThemeButton from 'ui-component/ThemeButton'; // assets import { IconMenu2 } from '@tabler/icons-react'; @@ -37,9 +38,8 @@ const Header = ({ handleLeftDrawerToggle }) => { sx={{ ...theme.typography.commonAvatar, ...theme.typography.mediumAvatar, + ...theme.typography.menuButton, transition: 'all .2s ease-in-out', - background: theme.palette.secondary.light, - color: theme.palette.secondary.dark, '&:hover': { background: theme.palette.secondary.dark, color: theme.palette.secondary.light @@ -55,7 +55,7 @@ const Header = ({ handleLeftDrawerToggle }) => { - + ); diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js index 16b13231..dadd3eca 100644 --- a/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js +++ b/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js @@ -36,7 +36,7 @@ import { useNavigate } from 'react-router-dom'; // })); const CardStyle = styled(Card)(({ theme }) => ({ - background: theme.palette.primary.light, + background: theme.typography.menuChip.background, marginBottom: '22px', overflow: 'hidden', position: 'relative', @@ -121,7 +121,6 @@ const MenuCard = () => { /> - {/* */} ); diff --git a/web/berry/src/layout/MainLayout/Sidebar/index.js b/web/berry/src/layout/MainLayout/Sidebar/index.js index e3c6d12d..10652ba6 100644 --- a/web/berry/src/layout/MainLayout/Sidebar/index.js +++ b/web/berry/src/layout/MainLayout/Sidebar/index.js @@ -39,7 +39,13 @@ const Sidebar = ({ drawerOpen, drawerToggle, window }) => { - + @@ -48,7 +54,13 @@ const Sidebar = ({ drawerOpen, drawerToggle, window }) => { - + diff --git a/web/berry/src/layout/MinimalLayout/Header/index.js b/web/berry/src/layout/MinimalLayout/Header/index.js index 4f61da60..feaeb603 100644 --- a/web/berry/src/layout/MinimalLayout/Header/index.js +++ b/web/berry/src/layout/MinimalLayout/Header/index.js @@ -1,10 +1,30 @@ // material-ui -import { useTheme } from "@mui/material/styles"; -import { Box, Button, Stack } from "@mui/material"; -import LogoSection from "layout/MainLayout/LogoSection"; -import { Link } from "react-router-dom"; -import { useLocation } from "react-router-dom"; -import { useSelector } from "react-redux"; +import { useState } from 'react'; +import { useTheme } from '@mui/material/styles'; +import { + Box, + Button, + Stack, + Popper, + IconButton, + List, + ListItemButton, + Paper, + ListItemText, + Typography, + Divider, + ClickAwayListener +} from '@mui/material'; +import LogoSection from 'layout/MainLayout/LogoSection'; +import { Link } from 'react-router-dom'; +import { useLocation } from 'react-router-dom'; +import { useSelector } from 'react-redux'; +import ThemeButton from 'ui-component/ThemeButton'; +import ProfileSection from 'layout/MainLayout/Header/ProfileSection'; +import { IconMenu2 } from '@tabler/icons-react'; +import Transitions from 'ui-component/extended/Transitions'; +import MainCard from 'ui-component/cards/MainCard'; +import { useMediaQuery } from '@mui/material'; // ==============================|| MAIN NAVBAR / HEADER ||============================== // @@ -12,16 +32,26 @@ const Header = () => { const theme = useTheme(); const { pathname } = useLocation(); const account = useSelector((state) => state.account); + const [open, setOpen] = useState(null); + const isMobile = useMediaQuery(theme.breakpoints.down('sm')); + + const handleOpenMenu = (event) => { + setOpen(open ? null : event.currentTarget); + }; + + const handleCloseMenu = () => { + setOpen(null); + }; return ( <> @@ -31,43 +61,99 @@ const Header = () => { - - - - {account.user ? ( - + + {isMobile ? ( + <> + + + + + ) : ( - + <> + + + + {account.user ? ( + <> + + + + ) : ( + + )} + )} + + + {({ TransitionProps }) => ( + + + + + + + 首页} /> + + + + 关于} /> + + + {account.user ? ( + + 控制台 + + ) : ( + + 登录 + + )} + + + + + + )} + ); }; diff --git a/web/berry/src/layout/MinimalLayout/index.js b/web/berry/src/layout/MinimalLayout/index.js index c2919c6d..81047fd1 100644 --- a/web/berry/src/layout/MinimalLayout/index.js +++ b/web/berry/src/layout/MinimalLayout/index.js @@ -1,6 +1,6 @@ import { Outlet } from 'react-router-dom'; import { useTheme } from '@mui/material/styles'; -import { AppBar, Box, CssBaseline, Toolbar } from '@mui/material'; +import { AppBar, Box, CssBaseline, Toolbar, Container } from '@mui/material'; import Header from './Header'; import Footer from 'ui-component/Footer'; @@ -22,9 +22,11 @@ const MinimalLayout = () => { flex: 'none' }} > - -
- + + +
+ + diff --git a/web/berry/src/routes/OtherRoutes.js b/web/berry/src/routes/OtherRoutes.js index 085c4add..58c0b660 100644 --- a/web/berry/src/routes/OtherRoutes.js +++ b/web/berry/src/routes/OtherRoutes.js @@ -8,6 +8,7 @@ import MinimalLayout from 'layout/MinimalLayout'; const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login'))); const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register'))); const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth'))); +const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth'))); const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword'))); const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword'))); const Home = Loadable(lazy(() => import('views/Home'))); @@ -48,6 +49,10 @@ const OtherRoutes = { path: '/oauth/github', element: }, + { + path: '/oauth/lark', + element: + }, { path: '/404', element: diff --git a/web/berry/src/store/actions.js b/web/berry/src/store/actions.js index 221e8578..f1592d17 100644 --- a/web/berry/src/store/actions.js +++ b/web/berry/src/store/actions.js @@ -7,3 +7,4 @@ export const SET_BORDER_RADIUS = '@customization/SET_BORDER_RADIUS'; export const SET_SITE_INFO = '@siteInfo/SET_SITE_INFO'; export const LOGIN = '@account/LOGIN'; export const LOGOUT = '@account/LOGOUT'; +export const SET_THEME = '@customization/SET_THEME'; diff --git a/web/berry/src/store/customizationReducer.js b/web/berry/src/store/customizationReducer.js index bd8e5f00..0c104025 100644 --- a/web/berry/src/store/customizationReducer.js +++ b/web/berry/src/store/customizationReducer.js @@ -9,7 +9,8 @@ export const initialState = { defaultId: 'default', fontFamily: config.fontFamily, borderRadius: config.borderRadius, - opened: true + opened: true, + theme: 'light' }; // ==============================|| CUSTOMIZATION REDUCER ||============================== // @@ -38,6 +39,11 @@ const customizationReducer = (state = initialState, action) => { ...state, borderRadius: action.borderRadius }; + case actionTypes.SET_THEME: + return { + ...state, + theme: action.theme + }; default: return state; } diff --git a/web/berry/src/themes/compStyleOverride.js b/web/berry/src/themes/compStyleOverride.js index b6e87e01..67a3dd14 100644 --- a/web/berry/src/themes/compStyleOverride.js +++ b/web/berry/src/themes/compStyleOverride.js @@ -1,5 +1,5 @@ export default function componentStyleOverrides(theme) { - const bgColor = theme.colors?.grey50; + const bgColor = theme.mode === 'dark' ? theme.backgroundDefault : theme.colors?.grey50; return { MuiButton: { styleOverrides: { @@ -12,15 +12,7 @@ export default function componentStyleOverrides(theme) { } } }, - MuiMenuItem: { - styleOverrides: { - root: { - '&:hover': { - backgroundColor: theme.colors?.grey100 - } - } - } - }, //MuiAutocomplete-popper MuiPopover-root + //MuiAutocomplete-popper MuiPopover-root MuiAutocomplete: { styleOverrides: { popper: { @@ -226,12 +218,12 @@ export default function componentStyleOverrides(theme) { MuiTableCell: { styleOverrides: { root: { - borderBottom: '1px solid rgb(241, 243, 244)', + borderBottom: '1px solid ' + theme.tableBorderBottom, textAlign: 'center' }, head: { color: theme.darkTextSecondary, - backgroundColor: 'rgb(244, 246, 248)' + backgroundColor: theme.headBackgroundColor } } }, @@ -239,7 +231,7 @@ export default function componentStyleOverrides(theme) { styleOverrides: { root: { '&:hover': { - backgroundColor: 'rgb(244, 246, 248)' + backgroundColor: theme.headBackgroundColor } } } @@ -247,10 +239,29 @@ export default function componentStyleOverrides(theme) { MuiTooltip: { styleOverrides: { tooltip: { - color: theme.paper, + color: theme.colors.paper, background: theme.colors?.grey700 } } + }, + MuiCssBaseline: { + styleOverrides: ` + .apexcharts-title-text { + fill: ${theme.textDark} !important + } + .apexcharts-text { + fill: ${theme.textDark} !important + } + .apexcharts-legend-text { + color: ${theme.textDark} !important + } + .apexcharts-menu { + background: ${theme.backgroundDefault} !important + } + .apexcharts-gridline, .apexcharts-xaxistooltip-background, .apexcharts-yaxistooltip-background { + stroke: ${theme.divider} !important; + } + ` } }; } diff --git a/web/berry/src/themes/index.js b/web/berry/src/themes/index.js index 6e694aa6..addd61f7 100644 --- a/web/berry/src/themes/index.js +++ b/web/berry/src/themes/index.js @@ -15,19 +15,10 @@ import themeTypography from './typography'; export const theme = (customization) => { const color = colors; - + const options = customization.theme === 'light' ? GetLightOption() : GetDarkOption(); const themeOption = { colors: color, - heading: color.grey900, - paper: color.paper, - backgroundDefault: color.paper, - background: color.primaryLight, - darkTextPrimary: color.grey700, - darkTextSecondary: color.grey500, - textDark: color.grey900, - menuSelected: color.secondaryDark, - menuSelectedBack: color.secondaryLight, - divider: color.grey200, + ...options, customization }; @@ -53,3 +44,49 @@ export const theme = (customization) => { }; export default theme; + +function GetDarkOption() { + const color = colors; + return { + mode: 'dark', + heading: color.darkTextTitle, + paper: color.darkLevel2, + backgroundDefault: color.darkPaper, + background: color.darkBackground, + darkTextPrimary: color.darkTextPrimary, + darkTextSecondary: color.darkTextSecondary, + textDark: color.darkTextTitle, + menuSelected: color.darkSecondaryMain, + menuSelectedBack: color.darkSelectedBack, + divider: color.darkDivider, + borderColor: color.darkBorderColor, + menuButton: color.darkLevel1, + menuButtonColor: color.darkSecondaryMain, + menuChip: color.darkLevel1, + headBackgroundColor: color.darkBackground, + tableBorderBottom: color.darkDivider + }; +} + +function GetLightOption() { + const color = colors; + return { + mode: 'light', + heading: color.grey900, + paper: color.paper, + backgroundDefault: color.paper, + background: color.primaryLight, + darkTextPrimary: color.grey700, + darkTextSecondary: color.grey500, + textDark: color.grey900, + menuSelected: color.secondaryDark, + menuSelectedBack: color.secondaryLight, + divider: color.grey200, + borderColor: color.grey300, + menuButton: color.secondaryLight, + menuButtonColor: color.secondaryDark, + menuChip: color.primaryLight, + headBackgroundColor: color.tableBackground, + tableBorderBottom: color.tableBorderBottom + }; +} diff --git a/web/berry/src/themes/palette.js b/web/berry/src/themes/palette.js index 09768555..70c78782 100644 --- a/web/berry/src/themes/palette.js +++ b/web/berry/src/themes/palette.js @@ -5,7 +5,7 @@ export default function themePalette(theme) { return { - mode: 'light', + mode: theme.mode, common: { black: theme.colors?.darkPaper }, diff --git a/web/berry/src/themes/typography.js b/web/berry/src/themes/typography.js index 24bfabb9..f20d87a5 100644 --- a/web/berry/src/themes/typography.js +++ b/web/berry/src/themes/typography.js @@ -132,6 +132,19 @@ export default function themeTypography(theme) { width: '44px', height: '44px', fontSize: '1.5rem' + }, + menuButton: { + color: theme.menuButtonColor, + background: theme.menuButton + }, + menuChip: { + background: theme.menuChip + }, + CardWrapper: { + backgroundColor: theme.mode === 'dark' ? theme.colors.darkLevel2 : theme.colors.primaryDark + }, + SubCard: { + border: theme.mode === 'dark' ? '1px solid rgba(227, 232, 239, 0.2)' : '1px solid rgb(227, 232, 239)' } }; } diff --git a/web/berry/src/ui-component/Logo.js b/web/berry/src/ui-component/Logo.js index a34fe895..52e61f4f 100644 --- a/web/berry/src/ui-component/Logo.js +++ b/web/berry/src/ui-component/Logo.js @@ -1,6 +1,8 @@ // material-ui -import logo from 'assets/images/logo.svg'; +import logoLight from 'assets/images/logo.svg'; +import logoDark from 'assets/images/logo-white.svg'; import { useSelector } from 'react-redux'; +import { useTheme } from '@mui/material/styles'; /** * if you want to use image instead of uncomment following. @@ -14,6 +16,8 @@ import { useSelector } from 'react-redux'; const Logo = () => { const siteInfo = useSelector((state) => state.siteInfo); + const theme = useTheme(); + const logo = theme.palette.mode === 'light' ? logoLight : logoDark; return {siteInfo.system_name}; }; diff --git a/web/berry/src/ui-component/ThemeButton.js b/web/berry/src/ui-component/ThemeButton.js new file mode 100644 index 00000000..c907c646 --- /dev/null +++ b/web/berry/src/ui-component/ThemeButton.js @@ -0,0 +1,50 @@ +import { useDispatch, useSelector } from 'react-redux'; +import { SET_THEME } from 'store/actions'; +import { useTheme } from '@mui/material/styles'; +import { Avatar, Box, ButtonBase } from '@mui/material'; +import { IconSun, IconMoon } from '@tabler/icons-react'; + +export default function ThemeButton() { + const dispatch = useDispatch(); + + const defaultTheme = useSelector((state) => state.customization.theme); + + const theme = useTheme(); + + return ( + + + { + let theme = defaultTheme === 'light' ? 'dark' : 'light'; + dispatch({ type: SET_THEME, theme: theme }); + localStorage.setItem('theme', theme); + }} + color="inherit" + > + {defaultTheme === 'light' ? : } + + + + ); +} diff --git a/web/berry/src/ui-component/cards/MainCard.js b/web/berry/src/ui-component/cards/MainCard.js index 8735282c..32353027 100644 --- a/web/berry/src/ui-component/cards/MainCard.js +++ b/web/berry/src/ui-component/cards/MainCard.js @@ -15,7 +15,7 @@ const headerSX = { const MainCard = forwardRef( ( { - border = true, + border = false, boxShadow, children, content = true, diff --git a/web/berry/src/ui-component/cards/SubCard.js b/web/berry/src/ui-component/cards/SubCard.js index 05f9abb7..a63819a8 100644 --- a/web/berry/src/ui-component/cards/SubCard.js +++ b/web/berry/src/ui-component/cards/SubCard.js @@ -15,8 +15,7 @@ const SubCard = forwardRef( )} @@ -62,7 +61,8 @@ SubCard.propTypes = { secondary: PropTypes.oneOfType([PropTypes.node, PropTypes.string, PropTypes.object]), sx: PropTypes.object, contentSX: PropTypes.object, - title: PropTypes.oneOfType([PropTypes.node, PropTypes.string, PropTypes.object]) + title: PropTypes.oneOfType([PropTypes.node, PropTypes.string, PropTypes.object]), + subTitle: PropTypes.oneOfType([PropTypes.node, PropTypes.string, PropTypes.object]) }; SubCard.defaultProps = { diff --git a/web/berry/src/utils/chart.js b/web/berry/src/utils/chart.js index 4633fe37..8cf6d847 100644 --- a/web/berry/src/utils/chart.js +++ b/web/berry/src/utils/chart.js @@ -40,7 +40,8 @@ export function generateChartOptions(data, unit) { chart: { sparkline: { enabled: true - } + }, + background: 'transparent' }, dataLabels: { enabled: false diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index 25e5c635..947df3bf 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -51,9 +51,9 @@ export function showError(error) { export function showNotice(message, isHTML = false) { if (isHTML) { - enqueueSnackbar(, getSnackbarOptions('INFO')); + enqueueSnackbar(, getSnackbarOptions('NOTICE')); } else { - enqueueSnackbar(message, getSnackbarOptions('INFO')); + enqueueSnackbar(message, getSnackbarOptions('NOTICE')); } } @@ -91,6 +91,13 @@ export async function onGitHubOAuthClicked(github_client_id, openInNewTab = fals } } +export async function onLarkOAuthClicked(lark_client_id) { + const state = await getOAuthState(); + if (!state) return; + let redirect_uri = `${window.location.origin}/oauth/lark`; + window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); +} + export function isAdmin() { let user = localStorage.getItem('user'); if (!user) return false; diff --git a/web/berry/src/views/Authentication/Auth/LarkOAuth.js b/web/berry/src/views/Authentication/Auth/LarkOAuth.js new file mode 100644 index 00000000..88ced5d8 --- /dev/null +++ b/web/berry/src/views/Authentication/Auth/LarkOAuth.js @@ -0,0 +1,94 @@ +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import React, { useEffect, useState } from 'react'; +import { showError } from 'utils/common'; +import useLogin from 'hooks/useLogin'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material'; + +// project imports +import AuthWrapper from '../AuthWrapper'; +import AuthCardWrapper from '../AuthCardWrapper'; +import Logo from 'ui-component/Logo'; + +// assets + +// ================================|| AUTH3 - LOGIN ||================================ // + +const LarkOAuth = () => { + const theme = useTheme(); + const matchDownSM = useMediaQuery(theme.breakpoints.down('md')); + + const [searchParams] = useSearchParams(); + const [prompt, setPrompt] = useState('处理中...'); + const { larkLogin } = useLogin(); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const { success, message } = await larkLogin(code, state); + if (!success) { + if (message) { + showError(message); + } + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + await new Promise((resolve) => setTimeout(resolve, 2000)); + navigate('/login'); + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + + + + + + + + + + + + + + + + 飞书 登录 + + + + + + + + + {prompt} + + + + + + + + + + ); +}; + +export default LarkOAuth; diff --git a/web/berry/src/views/Authentication/Auth/Register.js b/web/berry/src/views/Authentication/Auth/Register.js index 4489e560..8027649d 100644 --- a/web/berry/src/views/Authentication/Auth/Register.js +++ b/web/berry/src/views/Authentication/Auth/Register.js @@ -51,7 +51,7 @@ const Register = () => { - 已经有帐号了?点击登录 + 已经有帐号了?点击登录 diff --git a/web/berry/src/views/Authentication/AuthForms/AuthLogin.js b/web/berry/src/views/Authentication/AuthForms/AuthLogin.js index 70aa2230..bc7a35c0 100644 --- a/web/berry/src/views/Authentication/AuthForms/AuthLogin.js +++ b/web/berry/src/views/Authentication/AuthForms/AuthLogin.js @@ -35,7 +35,8 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff'; import Github from 'assets/images/icons/github.svg'; import Wechat from 'assets/images/icons/wechat.svg'; -import { onGitHubOAuthClicked } from 'utils/common'; +import Lark from 'assets/images/icons/lark.svg'; +import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common'; // ============================|| FIREBASE - LOGIN ||============================ // @@ -49,7 +50,7 @@ const LoginForm = ({ ...others }) => { // const [checked, setChecked] = useState(true); let tripartiteLogin = false; - if (siteInfo.github_oauth || siteInfo.wechat_login) { + if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) { tripartiteLogin = true; } @@ -121,6 +122,29 @@ const LoginForm = ({ ...others }) => { )} + {siteInfo.lark_client_id && ( + + + + + + )} { {({ errors, handleBlur, handleChange, handleSubmit, isSubmitting, touched, values }) => (
- 用户名 + 用户名 / 邮箱 { diff --git a/web/berry/src/views/Authentication/AuthWrapper.js b/web/berry/src/views/Authentication/AuthWrapper.js index 8cd0ec29..dc875704 100644 --- a/web/berry/src/views/Authentication/AuthWrapper.js +++ b/web/berry/src/views/Authentication/AuthWrapper.js @@ -8,7 +8,7 @@ import { UserContext } from 'contexts/UserContext'; // ==============================|| AUTHENTICATION 1 WRAPPER ||============================== // const AuthStyle = styled('div')(({ theme }) => ({ - backgroundColor: theme.palette.primary.light + backgroundColor: theme.palette.background.default })); // eslint-disable-next-line diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js index 07111c97..03b4df57 100644 --- a/web/berry/src/views/Channel/component/EditModal.js +++ b/web/berry/src/views/Channel/component/EditModal.js @@ -21,15 +21,16 @@ import { Container, Autocomplete, FormHelperText, - Checkbox + Switch, + Checkbox, } from "@mui/material"; import { Formik } from "formik"; import * as Yup from "yup"; import { defaultConfig, typeConfig } from "../type/Config"; //typeConfig import { createFilterOptions } from "@mui/material/Autocomplete"; -import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank'; -import CheckBoxIcon from '@mui/icons-material/CheckBox'; +import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank"; +import CheckBoxIcon from "@mui/icons-material/CheckBox"; const icon = ; const checkedIcon = ; @@ -79,6 +80,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt); const [groupOptions, setGroupOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); + const [batchAdd, setBatchAdd] = useState(false); const initChannel = (typeValue) => { if (typeConfig[typeValue]?.inputLabel) { @@ -151,7 +153,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { try { let res = await API.get(`/api/channel/models`); const { data } = res.data; - data.forEach(item => { + data.forEach((item) => { if (!item.owned_by) { item.owned_by = "未知"; } @@ -166,7 +168,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }); setModelOptions( - data.map((model) => { + data.map((model) => { return { id: model.id, group: model.owned_by, @@ -258,7 +260,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { 2 ); } - data.base_url = data.base_url ?? ''; + data.base_url = data.base_url ?? ""; data.is_edit = true; initChannel(data.type); setInitialInput(data); @@ -273,6 +275,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }, []); useEffect(() => { + setBatchAdd(false); if (channelId) { loadChannel().then(); } else { @@ -340,13 +343,17 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }, }} > - {Object.values(CHANNEL_OPTIONS).map((option) => { - return ( - - {option.text} - - ); - })} + {Object.values(CHANNEL_OPTIONS) + .sort((a, b) => { + return a.text.localeCompare(b.text); + }) + .map((option) => { + return ( + + {option.text} + + ); + })} {touched.type && errors.type ? ( @@ -551,7 +558,12 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }} renderOption={(props, option, { selected }) => (
  • - + {option.id}
  • )} @@ -597,20 +609,38 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { error={Boolean(touched.key && errors.key)} sx={{ ...theme.typography.otherInput }} > - - {inputLabel.key} - - + {!batchAdd ? ( + <> + + {inputLabel.key} + + + + ) : ( + + )} + {touched.key && errors.key ? ( {errors.key} @@ -622,6 +652,19 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { )}
    + {channelId === 0 && ( + + setBatchAdd(e.target.checked)} + /> + 批量添加 + + )} { let groups = []; if (group === "") { @@ -14,7 +27,7 @@ const GroupLabel = ({ group }) => { return ( } spacing={0.5}> {groups.map((group, index) => { - return ; + return ; })} ); diff --git a/web/berry/src/views/Channel/component/TableHead.js b/web/berry/src/views/Channel/component/TableHead.js index 736dd8aa..8c47e440 100644 --- a/web/berry/src/views/Channel/component/TableHead.js +++ b/web/berry/src/views/Channel/component/TableHead.js @@ -10,6 +10,7 @@ const ChannelTableHead = () => { 类型 状态 响应时间 + 已消耗 余额 优先级 操作 diff --git a/web/berry/src/views/Channel/component/TableRow.js b/web/berry/src/views/Channel/component/TableRow.js index baca42cd..2a7b9c7f 100644 --- a/web/berry/src/views/Channel/component/TableRow.js +++ b/web/berry/src/views/Channel/component/TableRow.js @@ -11,10 +11,7 @@ import { MenuItem, TableCell, IconButton, - FormControl, - InputLabel, - InputAdornment, - Input, + TextField, Dialog, DialogActions, DialogContent, @@ -31,12 +28,7 @@ import ResponseTimeLabel from "./ResponseTimeLabel"; import GroupLabel from "./GroupLabel"; import NameLabel from "./NameLabel"; -import { - IconDotsVertical, - IconEdit, - IconTrash, - IconPencil, -} from "@tabler/icons-react"; +import { IconDotsVertical, IconEdit, IconTrash } from "@tabler/icons-react"; export default function ChannelTableRow({ item, @@ -79,11 +71,19 @@ export default function ChannelTableRow({ } }; - const handlePriority = async () => { - if (priorityValve === "" || priorityValve === item.priority) { + const handlePriority = async (event) => { + const currentValue = parseInt(event.target.value); + if (isNaN(currentValue) || currentValue === priorityValve) { return; } - await manageChannel(item.id, "priority", priorityValve); + + if (currentValue < 0) { + showError("优先级不能小于 0"); + return; + } + + await manageChannel(item.id, "priority", currentValue); + setPriority(currentValue); }; const handleResponseTime = async () => { @@ -93,7 +93,7 @@ export default function ChannelTableRow({ test_time: Date.now() / 1000, response_time: time * 1000, }); - showInfo(`通道 ${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`渠道 ${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } }; @@ -170,6 +170,7 @@ export default function ChannelTableRow({ handle_action={handleResponseTime} /> + {renderNumber(item.used_quota)} - - 优先级 - setPriority(e.target.value)} - sx={{ textAlign: "center" }} - endAdornment={ - - - - - - } - /> - + @@ -240,9 +230,9 @@ export default function ChannelTableRow({ - 删除通道 + 删除渠道 - 是否删除通道 {item.name}? + 是否删除渠道 {item.name}? diff --git a/web/berry/src/views/Channel/index.js b/web/berry/src/views/Channel/index.js index 39ab5d82..c12ff3ba 100644 --- a/web/berry/src/views/Channel/index.js +++ b/web/berry/src/views/Channel/index.js @@ -135,7 +135,7 @@ export default function ChannelPage() { const res = await API.get(`/api/channel/test`); const { success, message } = res.data; if (success) { - showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); + showInfo('已成功开始测试所有渠道,请刷新页面查看结果。'); } else { showError(message); } @@ -159,7 +159,7 @@ export default function ChannelPage() { const res = await API.get(`/api/channel/update_balance`); const { success, message } = res.data; if (success) { - showInfo('已更新完毕所有已启用通道余额!'); + showInfo('已更新完毕所有已启用渠道余额!'); } else { showError(message); } @@ -193,20 +193,14 @@ export default function ChannelPage() { return ( <> - + 渠道 - - - - OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 - - - + {matchUpMd ? ( - + diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 8dfe77a4..7e42ca8d 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -41,7 +41,7 @@ const typeConfig = { }, prompt: { base_url: "请填写AZURE_OPENAI_ENDPOINT", - other: "请输入默认API版本,例如:2023-06-01-preview", + other: "请输入默认API版本,例如:2024-03-01-preview", }, }, 11: { diff --git a/web/berry/src/views/Dashboard/component/StatisticalLineChartCard.js b/web/berry/src/views/Dashboard/component/StatisticalLineChartCard.js index 53cd46b0..e6b46e25 100644 --- a/web/berry/src/views/Dashboard/component/StatisticalLineChartCard.js +++ b/web/berry/src/views/Dashboard/component/StatisticalLineChartCard.js @@ -12,7 +12,7 @@ import MainCard from 'ui-component/cards/MainCard'; import SkeletonTotalOrderCard from 'ui-component/cards/Skeleton/EarningCard'; const CardWrapper = styled(MainCard)(({ theme }) => ({ - backgroundColor: theme.palette.primary.dark, + ...theme.typography.CardWrapper, color: '#fff', overflow: 'hidden', position: 'relative', @@ -65,7 +65,7 @@ const StatisticalLineChartCard = ({ isLoading, title, chartData, todayValue }) = ) : ( - + diff --git a/web/berry/src/views/Log/index.js b/web/berry/src/views/Log/index.js index da24b4fd..f8cef0e8 100644 --- a/web/berry/src/views/Log/index.js +++ b/web/berry/src/views/Log/index.js @@ -102,11 +102,11 @@ export default function Log() { return ( <> - + 日志 - + - + diff --git a/web/berry/src/views/Profile/index.js b/web/berry/src/views/Profile/index.js index e0683228..483e3141 100644 --- a/web/berry/src/views/Profile/index.js +++ b/web/berry/src/views/Profile/index.js @@ -12,7 +12,8 @@ import { DialogTitle, DialogContent, DialogActions, - Divider + Divider, + SvgIcon } from '@mui/material'; import Grid from '@mui/material/Unstable_Grid2'; import SubCard from 'ui-component/cards/SubCard'; @@ -20,12 +21,13 @@ import { IconBrandWechat, IconBrandGithub, IconMail } from '@tabler/icons-react' import Label from 'ui-component/Label'; import { API } from 'utils/api'; import { showError, showSuccess } from 'utils/common'; -import { onGitHubOAuthClicked } from 'utils/common'; +import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common'; import * as Yup from 'yup'; import WechatModal from 'views/Authentication/AuthForms/WechatModal'; import { useSelector } from 'react-redux'; import EmailModal from './component/EmailModal'; import Turnstile from 'react-turnstile'; +import { ReactComponent as Lark } from 'assets/images/icons/lark.svg'; const validationSchema = Yup.object().shape({ username: Yup.string().required('用户名 不能为空').min(3, '用户名 不能小于 3 个字符'), @@ -137,6 +139,9 @@ export default function Profile() { + @@ -205,6 +210,13 @@ export default function Profile() { )} + {status.lark_client_id && !inputs.lark_id && ( + + + + )} - + - + diff --git a/web/berry/src/views/Setting/component/OperationSetting.js b/web/berry/src/views/Setting/component/OperationSetting.js index d91298b2..2bed715b 100644 --- a/web/berry/src/views/Setting/component/OperationSetting.js +++ b/web/berry/src/views/Setting/component/OperationSetting.js @@ -371,7 +371,7 @@ const OperationSetting = () => { value={inputs.ChannelDisableThreshold} onChange={handleInputChange} label="最长响应时间" - placeholder="单位秒,当运行通道全部测试时,超过此时间将自动禁用通道" + placeholder="单位秒,当运行渠道全部测试时,超过此时间将自动禁用渠道" disabled={loading} /> @@ -392,7 +392,7 @@ const OperationSetting = () => { { } /> { GitHubOAuthEnabled: '', GitHubClientId: '', GitHubClientSecret: '', + LarkClientId: '', + LarkClientSecret: '', Notice: '', SMTPServer: '', SMTPPort: '', @@ -48,7 +50,9 @@ const SystemSetting = () => { TurnstileSecretKey: '', RegisterEnabled: '', EmailDomainRestrictionEnabled: '', - EmailDomainWhitelist: [] + EmailDomainWhitelist: [], + MessagePusherAddress: '', + MessagePusherToken: '' }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -134,7 +138,11 @@ const SystemSetting = () => { name === 'WeChatAccountQRCodeImageURL' || name === 'TurnstileSiteKey' || name === 'TurnstileSecretKey' || - name === 'EmailDomainWhitelist' + name === 'EmailDomainWhitelist' || + name === 'MessagePusherAddress' || + name === 'MessagePusherToken' || + name === 'LarkClientId' || + name === 'LarkClientSecret' ) { setInputs((inputs) => ({ ...inputs, [name]: value })); } else { @@ -199,6 +207,24 @@ const SystemSetting = () => { } }; + const submitMessagePusher = async () => { + if (originInputs['MessagePusherAddress'] !== inputs.MessagePusherAddress) { + await updateOption('MessagePusherAddress', removeTrailingSlash(inputs.MessagePusherAddress)); + } + if (originInputs['MessagePusherToken'] !== inputs.MessagePusherToken && inputs.MessagePusherToken !== '') { + await updateOption('MessagePusherToken', inputs.MessagePusherToken); + } + }; + + const submitLarkOAuth = async () => { + if (originInputs['LarkClientId'] !== inputs.LarkClientId) { + await updateOption('LarkClientId', inputs.LarkClientId); + } + if (originInputs['LarkClientSecret'] !== inputs.LarkClientSecret && inputs.LarkClientSecret !== '') { + await updateOption('LarkClientSecret', inputs.LarkClientSecret); + } + }; + return ( <> @@ -473,6 +499,61 @@ const SystemSetting = () => { + + {' '} + 用以支持通过飞书进行登录注册, + + 点击此处 + + 管理你的飞书应用 + + } + > + + + + 主页链接填 {inputs.ServerAddress} + ,重定向 URL 填 {`${inputs.ServerAddress}/oauth/lark`} + + + + + App ID + + + + + + App Secret + + + + + + + + { + + 用以推送报警信息, + + 点击此处 + + 了解 Message Pusher + + } + > + + + + Message Pusher 推送地址 + + + + + + Message Pusher 访问凭证 + + + + + + + + ; +const checkedIcon = ; +const filter = createFilterOptions(); const validationSchema = Yup.object().shape({ is_edit: Yup.boolean(), - name: Yup.string().required("名称 不能为空"), - remain_quota: Yup.number().min(0, "必须大于等于0"), + name: Yup.string().required('名称 不能为空'), + remain_quota: Yup.number().min(0, '必须大于等于0'), expired_time: Yup.number(), - unlimited_quota: Yup.boolean(), + unlimited_quota: Yup.boolean() }); const originInputs = { is_edit: false, - name: "", + name: '', remain_quota: 0, expired_time: -1, unlimited_quota: false, + subnet: '', + models: [] }; const EditModal = ({ open, tokenId, onCancel, onOk }) => { const theme = useTheme(); const [inputs, setInputs] = useState(originInputs); + const [modelOptions, setModelOptions] = useState([]); const submit = async (values, { setErrors, setStatus, setSubmitting }) => { setSubmitting(true); values.remain_quota = parseInt(values.remain_quota); let res; + let models = values.models.join(','); if (values.is_edit) { - res = await API.put(`/api/token/`, { ...values, id: parseInt(tokenId) }); + res = await API.put(`/api/token/`, { ...values, id: parseInt(tokenId), models: models }); } else { - res = await API.post(`/api/token/`, values); + res = await API.post(`/api/token/`, { ...values, models: models }); } const { success, message } = res.data; if (success) { if (values.is_edit) { - showSuccess("令牌更新成功!"); + showSuccess('令牌更新成功!'); } else { - showSuccess("令牌创建成功,请在列表页面点击复制获取令牌!"); + showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); } setSubmitting(false); setStatus({ success: true }); @@ -78,61 +91,55 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { const { success, message, data } = res.data; if (success) { data.is_edit = true; + if (data.models === '') { + data.models = []; + } else { + data.models = data.models.split(','); + } setInputs(data); } else { showError(message); } }; + const loadAvailableModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + setModelOptions(data); + } else { + showError(message); + } + }; useEffect(() => { if (tokenId) { loadToken().then(); } else { - setInputs({...originInputs}); + setInputs({ ...originInputs }); } + loadAvailableModels().then(); }, [tokenId]); return ( - + - {tokenId ? "编辑Token" : "新建Token"} + {tokenId ? '编辑令牌' : '新建令牌'} - - 注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。 - - - {({ - errors, - handleBlur, - handleChange, - handleSubmit, - touched, - values, - setFieldError, - setFieldValue, - isSubmitting, - }) => ( + 注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。 + + {({ errors, handleBlur, handleChange, handleSubmit, touched, values, setFieldError, setFieldValue, isSubmitting }) => ( - + 名称 { name="name" onBlur={handleBlur} onChange={handleChange} - inputProps={{ autoComplete: "name" }} + inputProps={{ autoComplete: 'name' }} aria-describedby="helper-text-channel-name-label" /> {touched.name && errors.name && ( @@ -151,42 +158,99 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { )} + + { + const event = { + target: { + name: 'models', + value: value + } + }; + handleChange(event); + }} + onBlur={handleBlur} + // filterSelectedOptions + disableCloseOnSelect + renderInput={(params) => } + filterOptions={(options, params) => { + const filtered = filter(options, params); + const { inputValue } = params; + const isExisting = options.some((option) => inputValue === option); + if (inputValue !== '' && !isExisting) { + filtered.push(inputValue); + } + return filtered; + }} + renderOption={(props, option, { selected }) => ( +
  • + + {option} +
  • + )} + /> + {errors.models ? ( + + {errors.models} + + ) : ( + 请选择允许使用的模型,留空则不进行限制 + )} +
    + + IP 限制 + + {touched.subnet && errors.subnet ? ( + + {errors.subnet} + + ) : ( + + 请输入允许访问的网段,例如:192.168.0.0/24,请使用英文逗号分隔多个网段 + + )} + {values.expired_time !== -1 && ( - - + + { if (newError === null) { - setFieldError("expired_time", null); + setFieldError('expired_time', null); } else { - setFieldError("expired_time", "无效的日期"); + setFieldError('expired_time', '无效的日期'); } }} onChange={(newValue) => { - setFieldValue("expired_time", newValue.unix()); + setFieldValue('expired_time', newValue.unix()); }} slotProps={{ actionBar: { - actions: ["today", "accept"], - }, + actions: ['today', 'accept'] + } }} /> {errors.expired_time && ( - + {errors.expired_time} )} @@ -196,35 +260,22 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { checked={values.expired_time === -1} onClick={() => { if (values.expired_time === -1) { - setFieldValue( - "expired_time", - Math.floor(Date.now() / 1000) - ); + setFieldValue('expired_time', Math.floor(Date.now() / 1000)); } else { - setFieldValue("expired_time", -1); + setFieldValue('expired_time', -1); } }} - />{" "} + />{' '} 永不过期 - - - 额度 - + + 额度 - {renderQuotaWithPrompt(values.remain_quota)} - - } + endAdornment={{renderQuotaWithPrompt(values.remain_quota)}} onBlur={handleBlur} onChange={handleChange} aria-describedby="helper-text-channel-remain_quota-label" @@ -232,10 +283,7 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { /> {touched.remain_quota && errors.remain_quota && ( - + {errors.remain_quota} )} @@ -243,19 +291,13 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { { - setFieldValue("unlimited_quota", !values.unlimited_quota); + setFieldValue('unlimited_quota', !values.unlimited_quota); }} - />{" "} + />{' '} 无限额度 - @@ -273,5 +315,5 @@ EditModal.propTypes = { open: PropTypes.bool, tokenId: PropTypes.number, onCancel: PropTypes.func, - onOk: PropTypes.func, + onOk: PropTypes.func }; diff --git a/web/berry/src/views/Token/component/TableRow.js b/web/berry/src/views/Token/component/TableRow.js index 2753764c..51ab0d4b 100644 --- a/web/berry/src/views/Token/component/TableRow.js +++ b/web/berry/src/views/Token/component/TableRow.js @@ -28,7 +28,7 @@ const COPY_OPTIONS = [ { key: 'next', text: 'ChatGPT Next', - url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}', + url: 'https://app.nextchat.dev/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}', encode: false }, { key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true }, diff --git a/web/berry/src/views/Token/index.js b/web/berry/src/views/Token/index.js index 97ece35f..b3315eb9 100644 --- a/web/berry/src/views/Token/index.js +++ b/web/berry/src/views/Token/index.js @@ -141,9 +141,8 @@ export default function Token() { return ( <> - + 令牌 - diff --git a/web/berry/src/views/User/index.js b/web/berry/src/views/User/index.js index 463f525a..e53e5bbb 100644 --- a/web/berry/src/views/User/index.js +++ b/web/berry/src/views/User/index.js @@ -139,7 +139,7 @@ export default function Users() { return ( <> - + 用户 - + - + diff --git a/web/default/package.json b/web/default/package.json index 5290f744..ba45011f 100644 --- a/web/default/package.json +++ b/web/default/package.json @@ -18,7 +18,7 @@ }, "scripts": { "start": "react-scripts start", - "build": "react-scripts build && mv -f build ../build/default", + "build": "react-scripts build && rm -rf ../build/default && mv -f build ../build/default", "test": "react-scripts test", "eject": "react-scripts eject" }, diff --git a/web/default/src/App.js b/web/default/src/App.js index 13c884dc..4ece4eeb 100644 --- a/web/default/src/App.js +++ b/web/default/src/App.js @@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; +import LarkOAuth from './components/LarkOAuth'; const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); @@ -239,6 +240,14 @@ function App() { } /> + }> + + + } + /> {type2label[type]?.text}; + return ; } function renderBalance(type, balance) { @@ -234,7 +234,7 @@ const ChannelsTable = () => { newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].test_time = Date.now() / 1000; setChannels(newChannels); - showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); } @@ -244,7 +244,7 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/test?scope=${scope}`); const { success, message } = res.data; if (success) { - showInfo('已成功开始测试通道,请刷新页面查看结果。'); + showInfo('已成功开始测试渠道,请刷新页面查看结果。'); } else { showError(message); } @@ -270,7 +270,7 @@ const ChannelsTable = () => { newChannels[realIdx].balance = balance; newChannels[realIdx].balance_updated_time = Date.now() / 1000; setChannels(newChannels); - showInfo(`通道 ${name} 余额更新成功!`); + showInfo(`渠道 ${name} 余额更新成功!`); } else { showError(message); } @@ -281,7 +281,7 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/update_balance`); const { success, message } = res.data; if (success) { - showInfo('已更新完毕所有已启用通道余额!'); + showInfo('已更新完毕所有已启用渠道余额!'); } else { showError(message); } @@ -333,6 +333,8 @@ const ChannelsTable = () => { setPromptShown("channel-test"); }}> OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 +
    + 渠道测试仅支持 chat 模型,优先使用 gpt-3.5-turbo,如果该模型不可用则使用你所配置的模型列表中的第一个模型。 ) } diff --git a/web/default/src/components/LarkOAuth.js b/web/default/src/components/LarkOAuth.js new file mode 100644 index 00000000..bc2fb682 --- /dev/null +++ b/web/default/src/components/LarkOAuth.js @@ -0,0 +1,58 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; +import { useNavigate, useSearchParams } from 'react-router-dom'; +import { API, showError, showSuccess } from '../helpers'; +import { UserContext } from '../context/User'; + +const LarkOAuth = () => { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind lark + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default LarkOAuth; diff --git a/web/default/src/components/LoginForm.js b/web/default/src/components/LoginForm.js index a3913220..71566ef8 100644 --- a/web/default/src/components/LoginForm.js +++ b/web/default/src/components/LoginForm.js @@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; -import { onGitHubOAuthClicked } from './utils'; +import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils'; +import larkIcon from '../images/lark.svg'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -94,7 +95,7 @@ const LoginForm = () => { fluid icon='user' iconPosition='left' - placeholder='用户名' + placeholder='用户名 / 邮箱地址' name='username' value={username} onChange={handleChange} @@ -124,29 +125,52 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login ? ( + {status.github_oauth || status.wechat_login || status.lark_client_id ? ( <> Or - {status.github_oauth ? ( - ) } + { + status.lark_client_id && ( + + ) + } + { const [activePage, setActivePage] = useState(1); const [searchKeyword, setSearchKeyword] = useState(''); const [searching, setSearching] = useState(false); + const [orderBy, setOrderBy] = useState(''); const loadUsers = async (startIdx) => { - const res = await API.get(`/api/user/?p=${startIdx}`); + const res = await API.get(`/api/user/?p=${startIdx}&order=${orderBy}`); const { success, message, data } = res.data; if (success) { if (startIdx === 0) { @@ -47,19 +48,19 @@ const UsersTable = () => { (async () => { if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { // In this case we have to load more data and then append them. - await loadUsers(activePage - 1); + await loadUsers(activePage - 1, orderBy); } setActivePage(activePage); })(); }; useEffect(() => { - loadUsers(0) + loadUsers(0, orderBy) .then() .catch((reason) => { showError(reason); }); - }, []); + }, [orderBy]); const manageUser = (username, action, idx) => { (async () => { @@ -110,6 +111,7 @@ const UsersTable = () => { // if keyword is blank, load files instead. await loadUsers(0); setActivePage(1); + setOrderBy(''); return; } setSearching(true); @@ -148,6 +150,11 @@ const UsersTable = () => { setLoading(false); }; + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + }; + return ( <> @@ -322,6 +329,19 @@ const UsersTable = () => { + \ No newline at end of file diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 4de8e87a..5c7f13ff 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -54,6 +54,12 @@ const EditChannel = () => { const [basicModels, setBasicModels] = useState([]); const [fullModels, setFullModels] = useState([]); const [customModel, setCustomModel] = useState(''); + const [config, setConfig] = useState({ + region: '', + sk: '', + ak: '', + user_id: '' + }); const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); if (name === 'type') { @@ -65,6 +71,10 @@ const EditChannel = () => { } }; + const handleConfigChange = (e, { name, value }) => { + setConfig((inputs) => ({ ...inputs, [name]: value })); + }; + const loadChannel = async () => { let res = await API.get(`/api/channel/${channelId}`); const { success, message, data } = res.data; @@ -83,6 +93,10 @@ const EditChannel = () => { data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); } setInputs(data); + if (data.config !== '') { + setConfig(JSON.parse(data.config)); + } + setBasicModels(getChannelModels(data.type)); } else { showError(message); } @@ -99,9 +113,6 @@ const EditChannel = () => { })); setOriginModelOptions(localModelOptions); setFullModels(res.data.data.map((model) => model.id)); - setBasicModels(res.data.data.filter((model) => { - return model.id.startsWith('gpt-3') || model.id.startsWith('text-'); - }).map((model) => model.id)); } catch (error) { showError(error.message); } @@ -137,12 +148,20 @@ const EditChannel = () => { useEffect(() => { if (isEdit) { loadChannel().then(); + } else { + let localModels = getChannelModels(inputs.type); + setBasicModels(localModels); } fetchModels().then(); fetchGroups().then(); }, []); const submit = async () => { + if (inputs.key === '') { + if (config.ak !== '' && config.sk !== '' && config.region !== '') { + inputs.key = `${config.ak}|${config.sk}|${config.region}`; + } + } if (!isEdit && (inputs.name === '' || inputs.key === '')) { showInfo('请填写渠道名称和渠道密钥!'); return; @@ -155,12 +174,12 @@ const EditChannel = () => { showInfo('模型映射必须是合法的 JSON 格式!'); return; } - let localInputs = inputs; + let localInputs = {...inputs}; if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); } if (localInputs.type === 3 && localInputs.other === '') { - localInputs.other = '2023-06-01-preview'; + localInputs.other = '2024-03-01-preview'; } if (localInputs.type === 18 && localInputs.other === '') { localInputs.other = 'v2.1'; @@ -168,6 +187,7 @@ const EditChannel = () => { let res; localInputs.models = localInputs.models.join(','); localInputs.group = localInputs.groups.join(','); + localInputs.config = JSON.stringify(config); if (isEdit) { res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); } else { @@ -242,7 +262,7 @@ const EditChannel = () => { { ) } + { + inputs.type === 34 && ( + + 对于 Coze 而言,模型名称即 Bot ID,你可以添加一个前缀 `bot-`,例如:`bot-123456`。 + + ) + } { fluid multiple search - onLabelClick={(e, { value }) => {copy(value).then()}} + onLabelClick={(e, { value }) => { + copy(value).then(); + }} selection onChange={handleInputChange} value={inputs.models} @@ -355,7 +384,7 @@ const EditChannel = () => {
    + }}>填入相关模型 @@ -391,7 +420,52 @@ const EditChannel = () => { /> { - batch ? + inputs.type === 33 && ( + + + + + + ) + } + { + inputs.type === 34 && ( + ) + } + { + inputs.type !== 33 && (batch ? { value={inputs.key} autoComplete='new-password' /> - + ) } { - !isEdit && ( + inputs.type === 37 && ( + + + + ) + } + { + inputs.type !== 33 && !isEdit && ( { ) } { - inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( + inputs.type !== 3 && inputs.type !== 33 && inputs.type !== 8 && inputs.type !== 22 && ( { const params = useParams(); const tokenId = params.id; const isEdit = tokenId !== undefined; const [loading, setLoading] = useState(isEdit); + const [modelOptions, setModelOptions] = useState([]); const originInputs = { name: '', remain_quota: isEdit ? 0 : 500000, expired_time: -1, - unlimited_quota: false + unlimited_quota: false, + models: [], + subnet: "", }; const [inputs, setInputs] = useState(originInputs); const { name, remain_quota, expired_time, unlimited_quota } = inputs; @@ -22,8 +25,8 @@ const EditToken = () => { setInputs((inputs) => ({ ...inputs, [name]: value })); }; const handleCancel = () => { - navigate("/token"); - } + navigate('/token'); + }; const setExpiredTime = (month, day, hour, minute) => { let now = new Date(); let timestamp = now.getTime() / 1000; @@ -50,6 +53,11 @@ const EditToken = () => { if (data.expired_time !== -1) { data.expired_time = timestamp2string(data.expired_time); } + if (data.models === '') { + data.models = []; + } else { + data.models = data.models.split(','); + } setInputs(data); } else { showError(message); @@ -60,8 +68,26 @@ const EditToken = () => { if (isEdit) { loadToken().then(); } + loadAvailableModels().then(); }, []); + const loadAvailableModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + let options = data.map((model) => { + return { + key: model, + text: model, + value: model + }; + }); + setModelOptions(options); + } else { + showError(message); + } + }; + const submit = async () => { if (!isEdit && inputs.name === '') return; let localInputs = inputs; @@ -74,6 +100,7 @@ const EditToken = () => { } localInputs.expired_time = Math.ceil(time / 1000); } + localInputs.models = localInputs.models.join(','); let res; if (isEdit) { res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); @@ -109,6 +136,34 @@ const EditToken = () => { required={!isEdit} /> + + { + copy(value).then(); + }} + selection + onChange={handleInputChange} + value={inputs.models} + autoComplete='new-password' + options={modelOptions} + /> + + + + { const [topUpLink, setTopUpLink] = useState(''); const [userQuota, setUserQuota] = useState(0); const [isSubmitting, setIsSubmitting] = useState(false); + const [user, setUser] = useState({}); const topUp = async () => { if (redemptionCode === '') { @@ -41,7 +42,14 @@ const TopUp = () => { showError('超级管理员未设置充值链接!'); return; } - window.open(topUpLink, '_blank'); + let url = new URL(topUpLink); + let username = user.username; + let user_id = user.id; + // add username and user_id to the topup link + url.searchParams.append('username', username); + url.searchParams.append('user_id', user_id); + url.searchParams.append('transaction_id', crypto.randomUUID()); + window.open(url.toString(), '_blank'); }; const getUserQuota = async ()=>{ @@ -49,6 +57,7 @@ const TopUp = () => { const {success, message, data} = res.data; if (success) { setUserQuota(data.quota); + setUser(data); } else { showError(message); } @@ -80,7 +89,7 @@ const TopUp = () => { }} />