Compare commits

..

48 Commits

Author SHA1 Message Date
抒情熊
fdd7bf41c0
feat: support multipart/form-data format request (#1690)
* "add parser multipart/form-data"

* chore: fix impl

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-22 17:32:47 +08:00
徐瑞东
29389ed44f
fix: modify the type of token models to be text (#1761)
* fix: modify the type of token models to be text

* chore: update receiver name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-22 16:51:16 +08:00
byte911
88acc5a614 fix: return the usage info if not null (#1792)
Usage is missing.
2024-09-22 16:41:10 +08:00
TimeTrapzz
a21681096a
feat: add siliconflow usage (#1798) 2024-09-22 16:31:26 +08:00
lihangfu
32f90a79a8
feat: support SparkDesk-v3.1-128K (#1732)
* feat: 支持SparkDesk-v3.1-128K以及hunyuan-vision

* feat: 支持SparkDesk-v3.1-128K以及hunyuan-vision

---------

Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-09-22 16:29:09 +08:00
OnEvent
99c8c77504
feat: add oidc support (#1725)
* feat: add the ui for configuring the third-party standard OAuth2.0/OIDC.

- update SystemSetting.js
- add setup ui
- add configuration

* feat: add the ui for "allow the OAuth 2.0 to login"

- update SystemSetting.js

* feat: add OAuth 2.0 web ui and its process functions

- update common.js
- update AuthLogin.js
- update config.js

* fix: missing "Userinfo" endpoint configuration entry, used by OAuth clients to request user information from the IdP.

- update config.js
- update SystemSetting.js

* feat: updated the icons for Lark and OIDC to match the style of the icons for WeChat, EMail, GitHub.

- update lark.svg
- new oidc.svg

* refactor: Changing OAuth 2.0 to OIDC

* feat: add OIDC login method

* feat: Add support for OIDC login to the backend

* fix: Change the AppId and AppSecret on the Web UI to the standard usage: ClientId, ClientSecret.

* feat: Support quick configuration of OIDC through Well-Known Discovery Endpoint

* feat: Standardize terminology, add well-known configuration

- Change the AppId and AppSecret on the Server End to the standard usage: ClientId, ClientSecret.
- add Well-Known configuration to store in database, no actual use in server end but store and display in web ui only
2024-09-21 23:03:20 +08:00
TAKO
649ecbf29c
feat: support new openai models (4o 0806, chatgpt-4o-latest) (#1721)
* feat: support new model gpt-4o-2024-08-06

* feat: support new model chatgpt-4o-latest
2024-09-21 23:01:19 +08:00
qinguoyi
3a27c90910
fix: getTokenById return token nil, make panic (#1728)
* fix:getTokenById return token nil, make panic

* chore: remove useless err check

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 23:00:29 +08:00
千寻简
cba82404ae
feat: add lobechat open link options (#1741)
Co-authored-by: Star <iii9777@163.com>
2024-09-21 22:49:31 +08:00
forrestlinfeng
c9ac670ba1
feat: update stepfun models (#1740)
Co-authored-by: chenlinfeng <chenlinfeng@step.ai>
2024-09-21 22:48:46 +08:00
leavegee
15f815c23c
fix: fix ali embedding model always use v1 (#1747)
* fix:ali embedding model: v2 and v3

* chore: use ctxkey.RequestModel to eliminate hardcoding

---------

Co-authored-by: xuejia <gexuejia@djbx.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 22:40:06 +08:00
majian
89b63ca96f
feat: ResponseFormat support json_schema (#1759)
* feat: responseFormat support json_schema

* chore: rename struct name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 22:35:24 +08:00
Ghostz
8cc54489b9
feat: update disabled channel (#1780)
* Update disabled channel

* Update manage.go

* Update manage.go

* chore: add missing space

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-09-21 22:31:53 +08:00
guogeer
58bf60805e
fix: postgres use COALESCE replace null (#1793)
Co-authored-by: jinqi.guo <jinqi.guo@ubtrobot.com>
2024-09-21 22:13:31 +08:00
AJ's Life Journey
6714cf96d6
fix: Groq organization not auto-disabled when blocked (#1822) 2024-09-21 22:12:09 +08:00
longkeyy
f9774698e9
feat: synchronize with the official release of the groq model (#1677)
update groq add gemma2-9b-it llama3.1 family fixup price k/token -> m/token
2024-08-06 23:51:08 +08:00
TAKO
2af6f6a166 feat: add Cloudflare New Free Model Llama 3.1 8b (#1703) 2024-08-06 23:49:48 +08:00
MotorBottle
04bb3ef392
feat: add Max Tokens and Context Window Setting Options for Ollama Channel (#1694)
* Update main.go with max_tokens param

* Update model.go with max_tokens param

* Update model.go

* Update main.go

* Update main.go

* Adds num_ctx param for Ollama Channel

* Added num_ctx param for ollama adapter

* Added num_ctx param for ollama adapter

* Improved data process logic
2024-08-06 23:44:37 +08:00
longkeyy
b4bfa418a8
feat: update gemini model and price (#1705) 2024-08-06 23:43:33 +08:00
SLKun
e7e99e558a
feat: update Ollama embedding API to latest version with multi-text embedding support (#1715) 2024-08-06 23:43:20 +08:00
Shenghang Tsai
402fcf7f79
feat: add SiliconFlow (#1717)
* Add SiliconFlow

* Update README.md

* Update README.md

* Update channel.constants.js

* Update ChannelConstants.js

* Update channel.constants.js

* Update ChannelConstants.js

* Update compatible.go

* Update README.md
2024-08-06 23:42:25 +08:00
Junyan Qin
36039e329e
docs: update introduction for QChatGPT (#1707) 2024-08-06 23:33:43 +08:00
Laisky.Cai
c936198ac8
feat: add Proxy channel type and relay mode (#1678)
Add the Proxy channel type and relay mode to support proxying requests to custom upstream services.
2024-07-22 22:51:19 +08:00
TAKO
296ab013b8
feat: support gpt-4o mini (#1665)
* feat: support gpt-4o mini

* feat: fix gpt-4o mini image price
2024-07-22 22:44:08 +08:00
zijiren
5f03c856b4
feat: fast build linux/arm64 frontend (#1663)
* feat: fast build linux/arm64 frontend

* fix: dockerfile as replace to AS

* fix: trim space
2024-07-22 22:39:22 +08:00
igophper
39383e5532
fix: support embedding models for doubao (#1662)
Fixes #1594
2024-07-22 22:38:50 +08:00
JustSong
2a892c1937 revert: feat: fast build linux/arm64 frontend (#1645)
This reverts commit 1c44d7e1cd.
2024-07-17 22:50:52 +08:00
Laisky.Cai
adba54acd3
fix: implement improved headers for anthropic to support 8k outputs (#1654) 2024-07-16 23:48:54 +08:00
zijiren
6209ff9ea9
feat: vertexai support proxy url(example: cloudflare ai gateway) and fix some vertexai bug (#1642)
* feat: vertexai support proxy url(example: cloudflare ai gateway)

* fix: do resp model mapping

* fix: missing system

* fix: stream need query alt=sse
2024-07-16 01:02:06 +08:00
zijiren
1c44d7e1cd
feat: fast build linux/arm64 frontend (#1645) 2024-07-14 18:06:11 +08:00
zijiren
a3eefb7af0
fix: rate limit can be zero (#1643) 2024-07-14 18:03:23 +08:00
dependabot[bot]
b65bee46fb
chore(deps): bump google.golang.org/grpc from 1.64.0 to 1.64.1 (#1641)
Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.64.0 to 1.64.1.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.64.0...v1.64.1)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-14 17:58:38 +08:00
F1ynn Zhan9
422a4e8ee5
feat: add field stop in GeneralOpenAIRequest (#1637) 2024-07-14 17:57:16 +08:00
LiuVaayne
cf9b5f0b92
feat: support claude and gemini in vertex ai (#1621)
* feat: support claude and gemini in vertex ai

* fix: do not show api key field in channel page when the type is VertexAI

* fix: update getToken function to include channelId in cache key
2024-07-13 14:59:28 +08:00
Ghostz
65acb94f45
fix: text filed check for 4v request (#1634) 2024-07-13 14:57:08 +08:00
zijiren
6ad169975f
fix: impl cloudflare worker ai gateway (#1617) 2024-07-09 22:57:06 +08:00
Qiying Wang
f636c50c84
fix: duplicate [DONE] (#1629) 2024-07-09 22:43:59 +08:00
Qiying Wang
720fe2dfeb
feat: refactor AwsClaude to Aws to support both llama3 and claude (#1601)
* feat: refactor AwsClaude to Aws to support both llama3 and claude

* fix: aws llama3 ratio
2024-07-06 13:19:41 +08:00
Jason
e090e76c86
feat: add Novita AI as model provider (#1609) 2024-07-06 13:16:46 +08:00
open source
6a941748f8
feat: add initial root access token (#1598)
Signed-off-by: xiaobo <peterwillcn@gmail.com>
2024-07-06 13:15:17 +08:00
open source
46a0773580
fix: update readme docs (#1599)
Signed-off-by: xiaobo <peterwillcn@gmail.com>
2024-07-06 13:14:32 +08:00
zijiren
ffdb0b0c81
fix: use musl libc (#1597) 2024-07-06 13:14:07 +08:00
zijiren
efd30a40b3
feat: cloudflare support native openai api (#1596) 2024-07-06 13:12:30 +08:00
Qiying Wang
d7a78f3397
feat: support test specific model (#1600) 2024-07-05 18:05:16 +08:00
Leo Q
273be55797
feat(ui): show available models for air theme (#1595)
* feat(ui): air 主题显示可用模型

* chore: 改为全角括号
2024-07-04 08:35:41 +08:00
Leo Q
ec6ad24810
feat: support smtp without auth (#1101) 2024-07-03 22:23:49 +08:00
LinZeliang
c4fe57c165
feat: support one or more log file (#1400)
Co-authored-by: Laisky.Cai <github@laisky.com>
2024-07-03 20:53:29 +08:00
igophper
274fcf3d76
refactor: init db (#1590)
Co-authored-by: 江杭辉 <jianghanghui@k.app>
2024-07-03 20:50:40 +08:00
112 changed files with 3045 additions and 717 deletions

View File

@ -1,61 +0,0 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@ -1,4 +1,4 @@
name: Publish Docker image (amd64, English) name: Publish Docker image (English)
on: on:
push: push:
@ -34,6 +34,13 @@ jobs:
- name: Translate - name: Translate
run: | run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Log in to Docker Hub - name: Log in to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
with: with:
@ -51,6 +58,7 @@ jobs:
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}

View File

@ -1,10 +1,9 @@
name: Publish Docker image (arm64) name: Publish Docker image
on: on:
push: push:
tags: tags:
- 'v*.*.*' - 'v*.*.*'
- '!*-alpha*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:

View File

@ -1,4 +1,4 @@
FROM node:16 as builder FROM --platform=$BUILDPLATFORM node:16 AS builder
WORKDIR /web WORKDIR /web
COPY ./VERSION . COPY ./VERSION .
@ -16,7 +16,9 @@ WORKDIR /web/air
RUN npm install RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2 FROM golang:alpine AS builder2
RUN apk add --no-cache g++
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=1 \ CGO_ENABLED=1 \
@ -27,7 +29,7 @@ ADD go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
COPY --from=builder /web/build ./web/build COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine FROM alpine

View File

@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. 6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'.
7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60` + Example: `SYNC_FREQUENCY=60`
7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. 8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave` + Example: `NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. 9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440` + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. 10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440` + Example: `CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. 11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5` + Example: `POLLING_INTERVAL=5`
12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'.
+Example: ` BATCH_UPDATE_ENABLED=true`
+If you encounter an issue with too many database connections, you can try enabling this option.
13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'.
+Example: ` BATCH_UPDATE_INTERVAL=5`
14. Request frequency limit:
+ `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180.
+ `GLOBAL_WEL_RATE_LIMIT`: Global web speed limit, the maximum number of requests within three minutes per IP, default to 60.
15. Encoder cache settings:
+`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the encoding of some common word elements online, such as' gpt-3.5 turbo '. In some unstable network environments or offline situations, it may cause startup problems. This directory can be configured to cache data and can be migrated to an offline environment.
+`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as' TIKTOKEN-CACHE-DIR ', but its priority is not as high as it.
16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout time set.
17. `RELAY_PROXY`: After setting up, use this proxy to request APIs.
18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds.
19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this agent to request content uploaded by users, such as images.
20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'.
21. `GEMINI_SAFETY_SETTING`: Gemini's security settings are set to 'BLOCK-NONE' by default.
22. `GEMINI_VERSION`: The Gemini version used by the One API, which defaults to 'v1'.
23. `THE`: The system's theme setting, default to 'default', specific optional values refer to [here] (./web/README. md).
24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'.
25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'.
26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'.
27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time.
28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time.
### Command Line Parameters ### Command Line Parameters
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.

View File

@ -88,6 +88,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+ [x] [DeepL](https://www.deepl.com/) + [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/) + [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。 3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@ -250,9 +252,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
#### QChatGPT - QQ机器人 #### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT 项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 运行期间可以通过`!model`命令查看、切换可用模型。
### 部署到第三方平台 ### 部署到第三方平台
<details> <details>
@ -370,32 +372,33 @@ graph LR
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440` + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
11. 例子:`CHANNEL_TEST_FREQUENCY=1440` +例子:`CHANNEL_TEST_FREQUENCY=1440`
12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5` + 例子:`POLLING_INTERVAL=5`
13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true``false`,未设置则默认为 `false` 12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`BATCH_UPDATE_ENABLED=true` + 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5` 13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`
+ 例子:`BATCH_UPDATE_INTERVAL=5` + 例子:`BATCH_UPDATE_INTERVAL=5`
15. 请求频率限制: 14. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180` + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60` + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`
16. 编码器缓存设置: 15. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
18. `RELAY_PROXY`:设置后使用该代理来请求 API。 17. `RELAY_PROXY`:设置后使用该代理来请求 API。
19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
21. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000` 20. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`
22. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE` 21. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`
23. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1` 22. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`
24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true``false` 24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true``false`
26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10` 25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`
27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8` 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`
28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000` 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`

View File

@ -35,6 +35,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
@ -70,6 +71,13 @@ var GitHubClientSecret = ""
var LarkClientId = "" var LarkClientId = ""
var LarkClientSecret = "" var LarkClientSecret = ""
var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = "" var WeChatAccountQRCodeImageURL = ""
@ -143,8 +151,12 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1") var GeminiVersion = env.String("GEMINI_VERSION", "v1")
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "") var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

View File

@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v) err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
} else { } else {
// skip for now c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
// TODO: someday non json request have variant model, we will need to implementation this err = c.ShouldBind(&v)
} }
if err != nil { if err != nil {
return err return err
} }
// Reset request body // Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil return nil
} }

View File

@ -27,7 +27,12 @@ var setupLogOnce sync.Once
func SetupLogger() { func SetupLogger() {
setupLogOnce.Do(func() { setupLogOnce.Do(func() {
if LogDir != "" { if LogDir != "" {
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) var logPath string
if config.OnlyOneLogFile {
logPath = filepath.Join(LogDir, "oneapi.log")
} else {
logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
}
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")

View File

@ -6,11 +6,16 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"net"
"net/smtp" "net/smtp"
"strings" "strings"
"time" "time"
) )
func shouldAuth() bool {
return config.SMTPAccount != "" || config.SMTPToken != ""
}
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if receiver == "" { if receiver == "" {
return fmt.Errorf("receiver is empty") return fmt.Errorf("receiver is empty")
@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error {
"Date: %s\r\n"+ "Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
if config.SMTPPort == 465 || !shouldAuth() {
// need advanced client
var conn net.Conn
var err error
if config.SMTPPort == 465 { if config.SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: config.SMTPServer, ServerName: config.SMTPServer,
} }
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
} else {
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
}
if err != nil { if err != nil {
return err return err
} }
@ -59,9 +72,11 @@ func SendEmail(subject string, receiver string, content string) error {
return err return err
} }
defer client.Close() defer client.Close()
if shouldAuth() {
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
return err return err
} }
}
if err = client.Mail(config.SMTPFrom); err != nil { if err = client.Mail(config.SMTPFrom); err != nil {
return err return err
} }

225
controller/auth/oidc.go Normal file
View File

@ -0,0 +1,225 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.OidcClientId,
"client_secret": config.OidcClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
if config.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt(ctxkey.TokenId) tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
if err == nil {
expiredTime = token.ExpiredTime expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota remainQuota = token.RemainQuota
usedQuota = token.UsedQuota usedQuota = token.UsedQuota
}
} else { } else {
userId := c.GetInt(ctxkey.Id) userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId) remainQuota, err = model.GetUserQuota(userId)

View File

@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
TotalUsed float64 `json:"total_used"` TotalUsed float64 `json:"total_used"`
} }
type SiliconFlowUsageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}
// GetAuthHeader get auth header // GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header { func GetAuthHeader(token string) http.Header {
h := http.Header{} h := http.Header{}
@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil return response.TotalAvailable, nil
} }
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
url := "https://api.siliconflow.cn/v1/user/info"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := SiliconFlowUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
}
balance, err := strconv.ParseFloat(response.Data.Balance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := channeltype.ChannelBaseURLs[channel.Type] baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" { if channel.GetBaseURL() == "" {
@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAPI2GPTBalance(channel) return updateChannelAPI2GPTBalance(channel)
case channeltype.AIGC2D: case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel) return updateChannelAIGC2DBalance(channel)
case channeltype.SiliconFlow:
return updateChannelSiliconFlowBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
} }

View File

@ -14,6 +14,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -27,15 +28,15 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
) )
func buildTestRequest() *relaymodel.GeneralOpenAIRequest { func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
if model == "" {
model = "gpt-3.5-turbo"
}
testRequest := &relaymodel.GeneralOpenAIRequest{ testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2, MaxTokens: 2,
Stream: false, Model: model,
Model: "gpt-3.5-turbo",
} }
testMessage := relaymodel.Message{ testMessage := relaymodel.Message{
Role: "user", Role: "user",
@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
return testRequest return testRequest
} }
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
adaptor.Init(meta) adaptor.Init(meta)
var modelName string modelName := request.Model
modelList := adaptor.GetModelList()
modelMap := channel.GetModelMapping() modelMap := channel.GetModelMapping()
if len(modelList) != 0 {
modelName = modelList[0]
}
if modelName == "" || !strings.Contains(channel.Models, modelName) { if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",") modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 { if len(modelNames) > 0 {
@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
modelName = modelMap[modelName] modelName = modelMap[modelName]
} }
} }
request := buildTestRequest() meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
model := c.Query("model")
testRequest := buildTestRequest(model)
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel) err, _ = testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
milliseconds = 0
}
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if err != nil { if err != nil {
@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) {
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
"time": consumedTime, "time": consumedTime,
"model": model,
}) })
return return
} }
@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"time": consumedTime, "time": consumedTime,
"model": model,
}) })
return return
} }
@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel) testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold { if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
if config.AutomaticDisableChannelEnabled { if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error()) monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else { } else {

View File

@ -36,6 +36,12 @@ func GetStatus(c *gin.Context) {
"chat_link": config.ChatLink, "chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit, "quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled, "display_in_currency": config.DisplayInCurrencyEnabled,
"oidc": config.OidcEnabled,
"oidc_client_id": config.OidcClientId,
"oidc_well_known": config.OidcWellKnown,
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
"oidc_token_endpoint": config.OidcTokenEndpoint,
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
}, },
}) })
return return

View File

@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
fallthrough fallthrough
case relaymode.AudioTranscription: case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode) err = controller.RelayAudioHelper(c, relayMode)
case relaymode.Proxy:
err = controller.RelayProxyHelper(c, relayMode)
default: default:
err = controller.RelayTextHelper(c) err = controller.RelayTextHelper(c)
} }
@ -85,12 +87,15 @@ func Relay(c *gin.Context) {
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName) channelName := c.GetString(ctxkey.ChannelName)
// BUG: bizErr is in race condition
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
} }
if bizErr != nil { if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests { if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
} }
// BUG: bizErr is in race condition
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{ c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error, "error": bizErr.Error,

35
go.mod
View File

@ -4,6 +4,7 @@ module github.com/songquanpeng/one-api
go 1.20 go 1.20
require ( require (
cloud.google.com/go/iam v1.1.10
github.com/aws/aws-sdk-go-v2 v1.27.0 github.com/aws/aws-sdk-go-v2 v1.27.0
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 github.com/aws/aws-sdk-go-v2/credentials v1.17.15
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
@ -19,12 +20,14 @@ require (
github.com/gorilla/websocket v1.5.1 github.com/gorilla/websocket v1.5.1
github.com/jinzhu/copier v0.4.0 github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go v0.1.7
github.com/smartystreets/goconvey v1.8.1 github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.23.0 golang.org/x/crypto v0.24.0
golang.org/x/image v0.18.0 golang.org/x/image v0.18.0
google.golang.org/api v0.187.0
gorm.io/driver/mysql v1.5.6 gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlite v1.5.5
@ -32,6 +35,9 @@ require (
) )
require ( require (
cloud.google.com/go/auth v0.6.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
@ -45,13 +51,21 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.3 // indirect github.com/goccy/go-json v0.10.3 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/context v1.1.2 // indirect github.com/gorilla/context v1.1.2 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect
@ -68,7 +82,7 @@ require (
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
@ -76,11 +90,22 @@ require (
github.com/smarty/assertions v1.15.0 // indirect github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.25.0 // indirect golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect golang.org/x/text v0.16.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
google.golang.org/grpc v1.64.1 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

152
go.sum
View File

@ -1,5 +1,15 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38=
cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4=
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI=
cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
@ -18,12 +28,15 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -32,6 +45,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@ -48,6 +67,11 @@ github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0Nglqm
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
@ -64,11 +88,40 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA=
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
@ -110,8 +163,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -120,6 +173,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@ -128,6 +183,7 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
@ -149,26 +205,96 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI=
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo=
google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc=
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
@ -185,5 +311,7 @@ gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATa
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

22
main.go
View File

@ -27,27 +27,19 @@ func main() {
common.Init() common.Init()
logger.SetupLogger() logger.SetupLogger()
logger.SysLogf("One API %s started", common.Version) logger.SysLogf("One API %s started", common.Version)
if os.Getenv("GIN_MODE") != "debug" {
if os.Getenv("GIN_MODE") != gin.DebugMode {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if config.DebugEnabled { if config.DebugEnabled {
logger.SysLog("running in debug mode") logger.SysLog("running in debug mode")
} }
var err error
// Initialize SQL Database // Initialize SQL Database
model.DB, err = model.InitDB("SQL_DSN") model.InitDB()
if err != nil { model.InitLogDB()
logger.FatalLog("failed to initialize database: " + err.Error())
} var err error
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
err = model.CreateRootAccountIfNeed() err = model.CreateRootAccountIfNeed()
if err != nil { if err != nil {
logger.FatalLog("database init error: " + err.Error()) logger.FatalLog("database init error: " + err.Error())

View File

@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) {
return return
} }
} }
// set channel id for proxy relay
if channelId := c.Param("channelid"); channelId != "" {
c.Set(ctxkey.SpecificChannelId, channelId)
}
c.Next() c.Next()
} }
} }

View File

@ -12,7 +12,7 @@ import (
) )
type ModelRequest struct { type ModelRequest struct {
Model string `json:"model"` Model string `json:"model" form:"model"`
} }
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {

View File

@ -3,11 +3,12 @@ package middleware
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"net/http"
"time"
) )
var timeFormat = "2006-01-02T15:04:05.000Z" var timeFormat = "2006-01-02T15:04:05.000Z"
@ -70,6 +71,11 @@ func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark s
} }
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
if maxRequestNum == 0 {
return func(c *gin.Context) {
c.Next()
}
}
if common.RedisEnabled { if common.RedisEnabled {
return func(c *gin.Context) { return func(c *gin.Context) {
redisRateLimiter(c, maxRequestNum, duration, mark) redisRateLimiter(c, maxRequestNum, duration, mark)

View File

@ -3,6 +3,7 @@ package model
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -46,6 +47,8 @@ type ChannelConfig struct {
APIVersion string `json:"api_version,omitempty"` APIVersion string `json:"api_version,omitempty"`
LibraryID string `json:"library_id,omitempty"` LibraryID string `json:"library_id,omitempty"`
Plugin string `json:"plugin,omitempty"` Plugin string `json:"plugin,omitempty"`
VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
VertexAIADC string `json:"vertex_ai_adc,omitempty"`
} }
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {

View File

@ -3,6 +3,7 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }

View File

@ -1,6 +1,7 @@
package model package model
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error {
if err != nil { if err != nil {
return err return err
} }
accessToken := random.GetUUID()
if config.InitialRootAccessToken != "" {
accessToken = config.InitialRootAccessToken
}
rootUser := User{ rootUser := User{
Username: "root", Username: "root",
Password: hashedPassword, Password: hashedPassword,
Role: RoleRootUser, Role: RoleRootUser,
Status: UserStatusEnabled, Status: UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: random.GetUUID(), AccessToken: accessToken,
Quota: 500000000000000, Quota: 500000000000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@ -60,10 +65,22 @@ func CreateRootAccountIfNeed() error {
} }
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName) dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
switch {
case strings.HasPrefix(dsn, "postgres://"):
// Use PostgreSQL // Use PostgreSQL
return openPostgreSQL(dsn)
case dsn != "":
// Use MySQL
return openMySQL(dsn)
default:
// Use SQLite
return openSQLite()
}
}
func openPostgreSQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using PostgreSQL as database") logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{ return gorm.Open(postgres.New(postgres.Config{
@ -72,78 +89,132 @@ func chooseDB(envName string) (*gorm.DB, error) {
}), &gorm.Config{ }), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
// Use MySQL
func openMySQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using MySQL as database") logger.SysLog("using MySQL as database")
common.UsingMySQL = true common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{ return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
// Use SQLite
func openSQLite() (*gorm.DB, error) {
logger.SysLog("SQL_DSN not set, using SQLite as database") logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ return gorm.Open(sqlite.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
func InitDB(envName string) (db *gorm.DB, err error) { func InitDB() {
db, err = chooseDB(envName) var err error
if err == nil { DB, err = chooseDB("SQL_DSN")
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil { if err != nil {
return nil, err logger.FatalLog("failed to initialize database: " + err.Error())
return
} }
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) sqlDB := setDBConns(DB)
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode { if !config.IsMasterNode {
return db, err return
} }
if common.UsingMySQL { if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
} }
logger.SysLog("database migration started") logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{}) if err = migrateDB(); err != nil {
if err != nil { logger.FatalLog("failed to migrate database: " + err.Error())
return nil, err return
}
err = db.AutoMigrate(&Token{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return nil, err
} }
logger.SysLog("database migrated") logger.SysLog("database migrated")
return db, err }
} else {
logger.FatalLog(err) func migrateDB() error {
var err error
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
} }
return db, err if err = DB.AutoMigrate(&Token{}); err != nil {
return err
}
if err = DB.AutoMigrate(&User{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Option{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Redemption{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Ability{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Log{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
return nil
}
func InitLogDB() {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
logger.SysLog("using secondary database for table logs")
var err error
LOG_DB, err = chooseDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
return
}
setDBConns(LOG_DB)
if !config.IsMasterNode {
return
}
logger.SysLog("secondary database migration started")
err = migrateLOGDB()
if err != nil {
logger.FatalLog("failed to migrate secondary database: " + err.Error())
return
}
logger.SysLog("secondary database migrated")
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
logger.FatalLog("failed to connect database: " + err.Error())
return nil
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
return sqlDB
} }
func closeDB(db *gorm.DB) error { func closeDB(db *gorm.DB) error {

View File

@ -28,6 +28,7 @@ func InitOptionMap() {
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) {
config.EmailVerificationEnabled = boolValue config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue config.GitHubOAuthEnabled = boolValue
case "OidcEnabled":
config.OidcEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) {
config.LarkClientId = value config.LarkClientId = value
case "LarkClientSecret": case "LarkClientSecret":
config.LarkClientSecret = value config.LarkClientSecret = value
case "OidcClientId":
config.OidcClientId = value
case "OidcClientSecret":
config.OidcClientSecret = value
case "OidcWellKnown":
config.OidcWellKnown = value
case "OidcAuthorizationEndpoint":
config.OidcAuthorizationEndpoint = value
case "OidcTokenEndpoint":
config.OidcTokenEndpoint = value
case "OidcUserinfoEndpoint":
config.OidcUserinfoEndpoint = value
case "Footer": case "Footer":
config.Footer = value config.Footer = value
case "SystemName": case "SystemName":

View File

@ -30,7 +30,7 @@ type Token struct {
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"default:''"` // allowed models Models *string `json:"models" gorm:"type:text"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
} }
@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
return &token, err return &token, err
} }
func (token *Token) Insert() error { func (t *Token) Insert() error {
var err error var err error
err = DB.Create(token).Error err = DB.Create(t).Error
return err return err
} }
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (t *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
return err return err
} }
func (token *Token) SelectUpdate() error { func (t *Token) SelectUpdate() error {
// This can update zero values // This can update zero values
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
} }
func (token *Token) Delete() error { func (t *Token) Delete() error {
var err error var err error
err = DB.Delete(token).Error err = DB.Delete(t).Error
return err return err
} }
func (t *Token) GetModels() string {
if t == nil {
return ""
}
if t.Models == nil {
return ""
}
return *t.Models
}
func DeleteTokenById(id int, userId int) (err error) { func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token. // Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 { if id == 0 || userId == 0 {
@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId) token, err := GetTokenById(tokenId)
if err != nil {
return err
}
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
} else { } else {
err = IncreaseUserQuota(token.UserId, -quota) err = IncreaseUserQuota(token.UserId, -quota)
} }
if err != nil {
return err
}
if !token.UnlimitedQuota { if !token.UnlimitedQuota {
if quota > 0 { if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota) err = DecreaseTokenQuota(tokenId, quota)

View File

@ -39,6 +39,7 @@ type User struct {
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"` Quota int64 `json:"quota" gorm:"bigint;default:0"`
@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error {
return nil return nil
} }
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
} }
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool { func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }

View File

@ -1,10 +1,11 @@
package monitor package monitor
import ( import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
) )
func ShouldDisableChannel(err *model.Error, statusCode int) bool { func ShouldDisableChannel(err *model.Error, statusCode int) bool {
@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
return true return true
} }
switch err.Type { switch err.Type {
case "insufficient_quota": case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true return true
} }
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true return true
} }
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true lowerMessage := strings.ToLower(err.Message)
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { if strings.Contains(lowerMessage, "your access was terminated") ||
return true strings.Contains(lowerMessage, "violation of our policies") ||
} strings.Contains(lowerMessage, "your credit balance is too low") ||
//if strings.Contains(err.Message, "quota") { strings.Contains(lowerMessage, "organization has been disabled") ||
// return true strings.Contains(lowerMessage, "credit") ||
//} strings.Contains(lowerMessage, "balance") ||
if strings.Contains(err.Message, "credit") { strings.Contains(lowerMessage, "permission denied") ||
return true strings.Contains(lowerMessage, "organization has been restricted") || // groq
} strings.Contains(lowerMessage, "已欠费") {
if strings.Contains(err.Message, "balance") {
return true return true
} }
return false return false

View File

@ -15,7 +15,9 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/ollama" "github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
"github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei" "github.com/songquanpeng/one-api/relay/adaptor/xunfei"
"github.com/songquanpeng/one-api/relay/adaptor/zhipu" "github.com/songquanpeng/one-api/relay/adaptor/zhipu"
"github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/apitype"
@ -55,6 +57,10 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &cloudflare.Adaptor{} return &cloudflare.Adaptor{}
case apitype.DeepL: case apitype.DeepL:
return &deepl.Adaptor{} return &deepl.Adaptor{}
case apitype.VertexAI:
return &vertexai.Adaptor{}
case apitype.Proxy:
return &proxy.Adaptor{}
} }
return nil return nil
} }

View File

@ -3,6 +3,7 @@ package ali
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
@ -59,7 +60,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{ return &EmbeddingRequest{
Model: "text-embedding-v1", Model: request.Model,
Input: struct { Input: struct {
Texts []string `json:"texts"` Texts []string `json:"texts"`
}{ }{
@ -102,8 +103,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil }, nil
} }
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@ -3,12 +3,14 @@ package anthropic
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
type Adaptor struct { type Adaptor struct {
@ -31,6 +33,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
} }
req.Header.Set("anthropic-version", anthropicVersion) req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15") req.Header.Set("anthropic-beta", "messages-2023-12-15")
// https://x.com/alexalbert__/status/1812921642143900036
// claude-3-5-sonnet can support 8k context
if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
}
return nil return nil
} }

View File

@ -1,17 +1,16 @@
package aws package aws
import ( import (
"github.com/aws/aws-sdk-go-v2/aws" "errors"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"io" "io"
"net/http" "net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
@ -19,18 +18,52 @@ import (
var _ adaptor.Adaptor = new(Adaptor) var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct { type Adaptor struct {
meta *meta.Meta awsAdapter utils.AwsAdapter
awsClient *bedrockruntime.Client
Meta *meta.Meta
AwsClient *bedrockruntime.Client
} }
func (a *Adaptor) Init(meta *meta.Meta) { func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta a.Meta = meta
a.awsClient = bedrockruntime.New(bedrockruntime.Options{ a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region, Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
}) })
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
a.awsAdapter = adaptor
return adaptor.ConvertRequest(c, relayMode, request)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if a.awsAdapter == nil {
return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
}
return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}
func (a *Adaptor) GetModelList() (models []string) {
for model := range adaptors {
models = append(models, model)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil return "", nil
} }
@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil return nil, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, a.awsClient)
} else {
err, usage = Handler(c, a.awsClient, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}

View File

@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@ -5,8 +5,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"io" "io"
"net/http" "net/http"
@ -17,23 +15,17 @@ import (
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{ var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1", "claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2", "claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1", "claude-2.1": "anthropic.claude-v2:1",
@ -44,7 +36,7 @@ var awsModelIDMap = map[string]string{
} }
func awsModelID(requestModel string) (string, error) { func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok { if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil return awsModelID, nil
} }
@ -54,7 +46,7 @@ func awsModelID(requestModel string) (string, error) {
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsReq := &bedrockruntime.InvokeModelInput{ awsReq := &bedrockruntime.InvokeModelInput{
@ -65,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok { if !ok {
return wrapErr(errors.New("request not found")), nil return utils.WrapErr(errors.New("request not found")), nil
} }
claudeReq := claudeReq_.(*anthropic.Request) claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{ awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31", AnthropicVersion: "bedrock-2023-05-31",
} }
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil return utils.WrapErr(errors.Wrap(err, "copy request")), nil
} }
awsReq.Body, err = json.Marshal(awsClaudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
} }
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
} }
claudeResponse := new(anthropic.Response) claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse) err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
} }
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
@ -108,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
@ -119,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok { if !ok {
return wrapErr(errors.New("request not found")), nil return utils.WrapErr(errors.New("request not found")), nil
} }
claudeReq := claudeReq_.(*anthropic.Request) claudeReq := claudeReq_.(*anthropic.Request)
@ -127,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
AnthropicVersion: "bedrock-2023-05-31", AnthropicVersion: "bedrock-2023-05-31",
} }
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil return utils.WrapErr(errors.Wrap(err, "copy request")), nil
} }
awsReq.Body, err = json.Marshal(awsClaudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
} }
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
} }
stream := awsResp.GetStream() stream := awsResp.GetStream()
defer stream.Close() defer stream.Close()

View File

@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
llamaReq := ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, llamaReq)
return llamaReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@ -0,0 +1,231 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"text/template"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/random"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// Only support llama-3-8b and llama-3-70b instruction models
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
"llama3-8b-8192": "meta.llama3-8b-instruct-v1:0",
"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
// promptTemplate with range
const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|>
`
var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate))
func RenderPrompt(messages []relaymodel.Message) string {
var buf bytes.Buffer
err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages})
if err != nil {
logger.SysError("error rendering prompt messages: " + err.Error())
}
return buf.String()
}
func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
llamaRequest := Request{
MaxGenLen: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
}
if llamaRequest.MaxGenLen == 0 {
llamaRequest.MaxGenLen = 2048
}
prompt := RenderPrompt(textRequest.Messages)
llamaRequest.Prompt = prompt
return &llamaRequest
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
var llamaResponse Response
err = json.Unmarshal(awsResp.Body, &llamaResponse)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := ResponseLlama2OpenAI(&llamaResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: llamaResponse.PromptTokenCount,
CompletionTokens: llamaResponse.GenerationTokenCount,
TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
var responseText string
if len(llamaResponse.Generation) > 0 {
responseText = llamaResponse.Generation
}
choice := openai.TextResponseChoice{
Index: 0,
Message: relaymodel.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: llamaResponse.StopReason,
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
var llamaResp StreamResponse
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
if llamaResp.PromptTokenCount > 0 {
usage.PromptTokens = llamaResp.PromptTokenCount
}
if llamaResp.StopReason == "stop" {
usage.CompletionTokens = llamaResp.GenerationTokenCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
response := StreamResponseLlama2OpenAI(&llamaResp)
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}
func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = llamaResponse.Generation
choice.Delta.Role = "assistant"
finishReason := llamaResponse.StopReason
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse
}

View File

@ -0,0 +1,45 @@
package aws_test
import (
"testing"
aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/stretchr/testify/assert"
)
func TestRenderPrompt(t *testing.T) {
messages := []relaymodel.Message{
{
Role: "user",
Content: "What's your name?",
},
}
prompt := aws.RenderPrompt(messages)
expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
messages = []relaymodel.Message{
{
Role: "system",
Content: "Your name is Kat. You are a detective.",
},
{
Role: "user",
Content: "What's your name?",
},
{
Role: "assistant",
Content: "Kat",
},
{
Role: "user",
Content: "What's your job?",
},
}
prompt = aws.RenderPrompt(messages)
expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
}

View File

@ -0,0 +1,29 @@
package aws
// Request is the request to AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
}
// Response is the response from AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Response struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}
// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
type StreamResponse struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}

View File

@ -0,0 +1,39 @@
package aws
import (
claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)
type AwsModelType int
const (
AwsClaude AwsModelType = iota + 1
AwsLlama3
)
var (
adaptors = map[string]AwsModelType{}
)
func init() {
for model := range claude.AwsModelIDMap {
adaptors[model] = AwsClaude
}
for model := range llama3.AwsModelIDMap {
adaptors[model] = AwsLlama3
}
}
func GetAdaptor(model string) utils.AwsAdapter {
adaptorType := adaptors[model]
switch adaptorType {
case AwsClaude:
return &claude.Adaptor{}
case AwsLlama3:
return &llama3.Adaptor{}
default:
return nil
}
}

View File

@ -0,0 +1,51 @@
package utils
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type AwsAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
type Adaptor struct {
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}

View File

@ -0,0 +1,16 @@
package utils
import (
"net/http"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: err.Error(),
},
}
}

View File

@ -5,11 +5,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
) )
type Adaptor struct { type Adaptor struct {
@ -27,8 +29,33 @@ func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta a.meta = meta
} }
// WorkerAI cannot be used across accounts with AIGateWay
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
func (a *Adaptor) isAIGateWay(baseURL string) bool {
return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil isAIGateWay := a.isAIGateWay(meta.BaseURL)
var urlPrefix string
if isAIGateWay {
urlPrefix = meta.BaseURL
} else {
urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
}
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
default:
if isAIGateWay {
return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
}
return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
}
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@ -41,7 +68,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return ConvertRequest(*request), nil switch relayMode {
case relaymode.Completions:
return ConvertCompletionsRequest(*request), nil
case relaymode.ChatCompletions, relaymode.Embeddings:
return request, nil
default:
return nil, errors.New("not implemented")
}
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {

View File

@ -1,6 +1,7 @@
package cloudflare package cloudflare
var ModelList = []string{ var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16", "@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8", "@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1", "@cf/mistral/mistral-7b-instruct-v0.1",

View File

@ -3,11 +3,13 @@ package cloudflare
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -16,57 +18,23 @@ import (
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
var promptBuilder strings.Builder p, _ := textRequest.Prompt.(string)
for _, message := range textRequest.Messages {
promptBuilder.WriteString(message.StringContent())
promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
}
return &Request{ return &Request{
Prompt: p,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
Prompt: promptBuilder.String(),
Stream: textRequest.Stream, Stream: textRequest.Stream,
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
} }
} }
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cloudflareResponse.Result.Response,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = cloudflareResponse.Response
choice.Delta.Role = "assistant"
openaiResponse := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
Created: helper.GetTimestamp(),
}
return &openaiResponse
}
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c) id := helper.GetResponseID(c)
responseModel := c.GetString("original_model") responseModel := c.GetString(ctxkey.OriginalModel)
var responseText string var responseText string
for scanner.Scan() { for scanner.Scan() {
@ -77,22 +45,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
data = strings.TrimPrefix(data, "data: ") data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var cloudflareResponse StreamResponse if data == "[DONE]" {
err := json.Unmarshal([]byte(data), &cloudflareResponse) break
}
var response openai.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
continue continue
} }
for _, v := range response.Choices {
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) v.Delta.Role = "assistant"
if response == nil { responseText += v.Delta.StringContent()
continue
} }
responseText += cloudflareResponse.Response
response.Id = id response.Id = id
response.Model = responseModel response.Model = modelName
err = render.ObjectData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
@ -123,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
var cloudflareResponse Response var response openai.TextResponse
err = json.Unmarshal(responseBody, &cloudflareResponse) err = json.Unmarshal(responseBody, &response)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) response.Model = modelName
fullTextResponse.Model = modelName var responseText string
usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) for _, v := range response.Choices {
fullTextResponse.Usage = *usage responseText += v.Message.Content.(string)
fullTextResponse.Id = helper.GetResponseID(c) }
jsonResponse, err := json.Marshal(fullTextResponse) usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, _ = c.Writer.Write(jsonResponse)
return nil, usage return nil, usage
} }

View File

@ -1,6 +1,9 @@
package cloudflare package cloudflare
import "github.com/songquanpeng/one-api/relay/model"
type Request struct { type Request struct {
Messages []model.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"` Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@ -8,18 +11,3 @@ type Request struct {
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
} }
type Result struct {
Response string `json:"response"`
}
type Response struct {
Result Result `json:"result"`
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
}
type StreamResponse struct {
Response string `json:"response"`
}

View File

@ -7,8 +7,12 @@ import (
) )
func GetRequestURL(meta *meta.Meta) (string, error) { func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions { switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:
} }
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
} }

View File

@ -3,6 +3,5 @@ package gemini
// https://ai.google.dev/models/gemini // https://ai.google.dev/models/gemini
var ModelList = []string{ var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", "gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa",
"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
} }

View File

@ -4,9 +4,14 @@ package groq
var ModelList = []string{ var ModelList = []string{
"gemma-7b-it", "gemma-7b-it",
"llama2-7b-2048",
"llama2-70b-4096",
"mixtral-8x7b-32768", "mixtral-8x7b-32768",
"llama3-8b-8192", "llama3-8b-8192",
"llama3-70b-8192", "llama3-70b-8192",
"gemma2-9b-it",
"llama-3.1-405b-reasoning",
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"whisper-large-v3",
} }

View File

@ -0,0 +1,19 @@
package novita
// https://novita.ai/llm-api
var ModelList = []string{
"meta-llama/llama-3-8b-instruct",
"meta-llama/llama-3-70b-instruct",
"nousresearch/hermes-2-pro-llama-3-8b",
"nousresearch/nous-hermes-llama2-13b",
"mistralai/mistral-7b-instruct",
"cognitivecomputations/dolphin-mixtral-8x22b",
"sao10k/l3-70b-euryale-v2.1",
"sophosympatheia/midnight-rose-70b",
"gryphe/mythomax-l2-13b",
"Nous-Hermes-2-Mixtral-8x7B-DPO",
"lzlv_70b",
"teknium/openhermes-2.5-mistral-7b",
"microsoft/wizardlm-2-8x22b",
}

View File

@ -0,0 +1,15 @@
package novita
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
}

View File

@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md // https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == relaymode.Embeddings { if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
} }
return fullRequestURL, nil return fullRequestURL, nil
} }

View File

@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
TopP: request.TopP, TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty, FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty, PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,
}, },
Stream: request.Stream, Stream: request.Stream,
} }
@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
for scanner.Scan() { for scanner.Scan() {
data := strings.TrimPrefix(scanner.Text(), "}") data := scanner.Text()
data = data + "}" if strings.HasPrefix(data, "}") {
data = strings.TrimPrefix(data, "}") + "}"
}
var ollamaResponse ChatResponse var ollamaResponse ChatResponse
err := json.Unmarshal([]byte(data), &ollamaResponse) err := json.Unmarshal([]byte(data), &ollamaResponse)
@ -158,7 +162,14 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{ return &EmbeddingRequest{
Model: request.Model, Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "), Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
} }
} }
@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
openAIEmbeddingResponse := openai.EmbeddingResponse{ openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list", Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, 1), Data: make([]openai.EmbeddingResponseItem, 0, 1),
Model: "text-embedding-v1", Model: response.Model,
Usage: model.Usage{TotalTokens: 0}, Usage: model.Usage{TotalTokens: 0},
} }
for i, embedding := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`, Object: `embedding`,
Index: 0, Index: i,
Embedding: response.Embedding, Embedding: embedding,
}) })
}
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }

View File

@ -7,6 +7,8 @@ type Options struct {
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
} }
type Message struct { type Message struct {
@ -38,10 +40,14 @@ type ChatResponse struct {
type EmbeddingRequest struct { type EmbeddingRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Input []string `json:"input"`
// Truncate bool `json:"truncate,omitempty"`
Options *Options `json:"options,omitempty"`
// KeepAlive string `json:"keep_alive,omitempty"`
} }
type EmbeddingResponse struct { type EmbeddingResponse struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
Embedding []float64 `json:"embedding,omitempty"` Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings"`
} }

View File

@ -3,17 +3,19 @@ package openai
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
) )
type Adaptor struct { type Adaptor struct {
@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return minimax.GetRequestURL(meta) return minimax.GetRequestURL(meta)
case channeltype.Doubao: case channeltype.Doubao:
return doubao.GetRequestURL(meta) return doubao.GetRequestURL(meta)
case channeltype.Novita:
return novita.GetRequestURL(meta)
default: default:
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
} }

View File

@ -10,8 +10,10 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot" "github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai" "github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
) )
@ -28,6 +30,8 @@ var CompatibleChannels = []int{
channeltype.StepFun, channeltype.StepFun,
channeltype.DeepSeek, channeltype.DeepSeek,
channeltype.TogetherAI, channeltype.TogetherAI,
channeltype.Novita,
channeltype.SiliconFlow,
} }
func GetCompatibleChannelMeta(channelType int) (string, []string) { func GetCompatibleChannelMeta(channelType int) (string, []string) {
@ -56,6 +60,10 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "together.ai", togetherai.ModelList return "together.ai", togetherai.ModelList
case channeltype.Doubao: case channeltype.Doubao:
return "doubao", doubao.ModelList return "doubao", doubao.ModelList
case channeltype.Novita:
return "novita", novita.ModelList
case channeltype.SiliconFlow:
return "siliconflow", siliconflow.ModelList
default: default:
return "openai", ModelList return "openai", ModelList
} }

View File

@ -8,6 +8,9 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

View File

@ -4,11 +4,12 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@ -31,6 +32,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format if len(data) < dataPrefixLength { // ignore blank line or wrong format
@ -41,6 +43,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
} }
if strings.HasPrefix(data[dataPrefixLength:], done) { if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data) render.StringData(c, data)
doneRendered = true
continue continue
} }
switch relayMode { switch relayMode {
@ -52,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
render.StringData(c, data) // if error happened, pass the data to client render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error continue // just ignore the error
} }
if len(streamResponse.Choices) == 0 { if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
// but for empty choice, we should not pass it to client, this is for azure // but for empty choice and no usage, we should not pass it to client, this is for azure
continue // just ignore empty choice continue // just ignore empty choice
} }
render.StringData(c, data) render.StringData(c, data)
@ -81,7 +84,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
if !doneRendered {
render.Done(c) render.Done(c)
}
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -97,7 +97,11 @@ func CountTokenMessages(messages []model.Message, model string) int {
m := it.(map[string]any) m := it.(map[string]any)
switch m["type"] { switch m["type"] {
case "text": case "text":
tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) if textValue, ok := m["text"]; ok {
if textString, ok := textValue.(string); ok {
tokenNum += getTokenNum(tokenEncoder, textString)
}
}
case "image_url": case "image_url":
imageUrl, ok := m["image_url"].(map[string]any) imageUrl, ok := m["image_url"].(map[string]any)
if ok { if ok {
@ -106,7 +110,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
if imageUrl["detail"] != nil { if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string) detail = imageUrl["detail"].(string)
} }
imageTokens, err := countImageTokens(url, detail) imageTokens, err := countImageTokens(url, detail, model)
if err != nil { if err != nil {
logger.SysError("error counting image tokens: " + err.Error()) logger.SysError("error counting image tokens: " + err.Error())
} else { } else {
@ -130,11 +134,15 @@ const (
lowDetailCost = 85 lowDetailCost = 85
highDetailCostPerTile = 170 highDetailCostPerTile = 170
additionalCost = 85 additionalCost = 85
// gpt-4o-mini cost higher than other model
gpt4oMiniLowDetailCost = 2833
gpt4oMiniHighDetailCost = 5667
gpt4oMiniAdditionalCost = 2833
) )
// https://platform.openai.com/docs/guides/vision/calculating-costs // https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb // https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string) (_ int, err error) { func countImageTokens(url string, detail string, model string) (_ int, err error) {
var fetchSize = true var fetchSize = true
var width, height int var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
@ -168,6 +176,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
} }
switch detail { switch detail {
case "low": case "low":
if strings.HasPrefix(model, "gpt-4o-mini") {
return gpt4oMiniLowDetailCost, nil
}
return lowDetailCost, nil return lowDetailCost, nil
case "high": case "high":
if fetchSize { if fetchSize {
@ -187,6 +198,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
height = int(float64(height) * ratio) height = int(float64(height) * ratio)
} }
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
if strings.HasPrefix(model, "gpt-4o-mini") {
return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
}
result := numSquares*highDetailCostPerTile + additionalCost result := numSquares*highDetailCostPerTile + additionalCost
return result, nil return result, nil
default: default:

View File

@ -0,0 +1,89 @@
package proxy
import (
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
const channelName = "proxy"
type Adaptor struct{}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("notimplement")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
for k, v := range resp.Header {
for _, vv := range v {
c.Writer.Header().Set(k, vv)
}
}
c.Writer.WriteHeader(resp.StatusCode)
if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil {
return nil, &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: gerr.Error(),
},
}
}
return nil, nil
}
func (a *Adaptor) GetModelList() (models []string) {
return nil
}
func (a *Adaptor) GetChannelName() string {
return channelName
}
// GetRequestURL remove static prefix, and return the real request url to the upstream service
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
for k, v := range c.Request.Header {
req.Header.Set(k, v[0])
}
// remove unnecessary headers
req.Header.Del("Host")
req.Header.Del("Content-Length")
req.Header.Del("Accept-Encoding")
req.Header.Del("Connection")
// set authorization header
req.Header.Set("Authorization", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.Errorf("not implement")
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

View File

@ -0,0 +1,36 @@
package siliconflow
// https://docs.siliconflow.cn/docs/getting-started
var ModelList = []string{
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"THUDM/glm-4-9b-chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}

View File

@ -1,7 +1,13 @@
package stepfun package stepfun
var ModelList = []string{ var ModelList = []string{
"step-1-8k",
"step-1-32k", "step-1-32k",
"step-1-128k",
"step-1-256k",
"step-1-flash",
"step-2-16k",
"step-1v-8k",
"step-1v-32k", "step-1v-32k",
"step-1-200k", "step-1x-medium",
} }

View File

@ -5,4 +5,5 @@ var ModelList = []string{
"hunyuan-standard", "hunyuan-standard",
"hunyuan-standard-256K", "hunyuan-standard-256K",
"hunyuan-pro", "hunyuan-pro",
"hunyuan-vision",
} }

View File

@ -0,0 +1,117 @@
package vertexai
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
const channelName = "vertexai"
type Adaptor struct{}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
return adaptor.ConvertRequest(c, relayMode, request)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil {
return nil, &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: "adaptor not found",
},
}
}
return adaptor.DoResponse(c, resp, meta)
}
func (a *Adaptor) GetModelList() (models []string) {
models = modelList
return
}
func (a *Adaptor) GetChannelName() string {
return channelName
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
suffix := ""
if strings.HasPrefix(meta.ActualModelName, "gemini") {
if meta.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
} else {
if meta.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
suffix = "rawPredict"
}
}
if meta.BaseURL != "" {
return fmt.Sprintf(
"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
meta.BaseURL,
meta.Config.VertexAIProjectID,
meta.Config.Region,
meta.ActualModelName,
suffix,
), nil
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
meta.Config.Region,
meta.Config.VertexAIProjectID,
meta.Config.Region,
meta.ActualModelName,
suffix,
), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
token, err := getToken(c, meta.ChannelId, meta.Config.VertexAIADC)
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+token)
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

View File

@ -0,0 +1,55 @@
package vertexai
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var ModelList = []string{
"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
}
const anthropicVersion = "vertex-2023-10-16"
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
req := Request{
AnthropicVersion: anthropicVersion,
// Model: claudeReq.Model,
Messages: claudeReq.Messages,
System: claudeReq.System,
MaxTokens: claudeReq.MaxTokens,
Temperature: claudeReq.Temperature,
TopP: claudeReq.TopP,
TopK: claudeReq.TopK,
Stream: claudeReq.Stream,
Tools: claudeReq.Tools,
}
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, req)
return req, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = anthropic.StreamHandler(c, resp)
} else {
err, usage = anthropic.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}

View File

@ -0,0 +1,19 @@
package vertexai
import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
type Request struct {
// AnthropicVersion must be "vertex-2023-10-16"
AnthropicVersion string `json:"anthropic_version"`
// Model string `json:"model"`
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@ -0,0 +1,49 @@
package vertexai
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var ModelList = []string{
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
}
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
geminiRequest := gemini.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, geminiRequest)
return geminiRequest, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = gemini.StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
switch meta.Mode {
case relaymode.Embeddings:
err, usage = gemini.EmbeddingHandler(c, resp)
default:
err, usage = gemini.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
}
return
}

View File

@ -0,0 +1,50 @@
package vertexai
import (
"net/http"
"github.com/gin-gonic/gin"
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type VertexAIModelType int
const (
VerterAIClaude VertexAIModelType = iota + 1
VerterAIGemini
)
var modelMapping = map[string]VertexAIModelType{}
var modelList = []string{}
func init() {
modelList = append(modelList, claude.ModelList...)
for _, model := range claude.ModelList {
modelMapping[model] = VerterAIClaude
}
modelList = append(modelList, gemini.ModelList...)
for _, model := range gemini.ModelList {
modelMapping[model] = VerterAIGemini
}
}
type innerAIAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
func GetAdaptor(model string) innerAIAdapter {
adaptorType := modelMapping[model]
switch adaptorType {
case VerterAIClaude:
return &claude.Adaptor{}
case VerterAIGemini:
return &gemini.Adaptor{}
default:
return nil
}
}

View File

@ -0,0 +1,62 @@
package vertexai
import (
"context"
"encoding/json"
"fmt"
"time"
credentials "cloud.google.com/go/iam/credentials/apiv1"
"cloud.google.com/go/iam/credentials/apiv1/credentialspb"
"github.com/patrickmn/go-cache"
"google.golang.org/api/option"
)
type ApplicationDefaultCredentials struct {
Type string `json:"type"`
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
AuthURI string `json:"auth_uri"`
TokenURI string `json:"token_uri"`
AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
ClientX509CertURL string `json:"client_x509_cert_url"`
UniverseDomain string `json:"universe_domain"`
}
var Cache = cache.New(50*time.Minute, 55*time.Minute)
const defaultScope = "https://www.googleapis.com/auth/cloud-platform"
func getToken(ctx context.Context, channelId int, adcJson string) (string, error) {
cacheKey := fmt.Sprintf("vertexai-token-%d", channelId)
if token, found := Cache.Get(cacheKey); found {
return token.(string), nil
}
adc := &ApplicationDefaultCredentials{}
if err := json.Unmarshal([]byte(adcJson), adc); err != nil {
return "", fmt.Errorf("Failed to decode credentials file: %w", err)
}
c, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(adcJson)))
if err != nil {
return "", fmt.Errorf("Failed to create client: %w", err)
}
defer c.Close()
req := &credentialspb.GenerateAccessTokenRequest{
// See https://pkg.go.dev/cloud.google.com/go/iam/credentials/apiv1/credentialspb#GenerateAccessTokenRequest.
Name: fmt.Sprintf("projects/-/serviceAccounts/%s", adc.ClientEmail),
Scope: []string{defaultScope},
}
resp, err := c.GenerateAccessToken(ctx, req)
if err != nil {
return "", fmt.Errorf("Failed to generate access token: %w", err)
}
_ = resp
Cache.Set(cacheKey, resp.AccessToken, cache.DefaultExpiration)
return resp.AccessToken, nil
}

View File

@ -5,6 +5,7 @@ var ModelList = []string{
"SparkDesk-v1.1", "SparkDesk-v1.1",
"SparkDesk-v2.1", "SparkDesk-v2.1",
"SparkDesk-v3.1", "SparkDesk-v3.1",
"SparkDesk-v3.1-128K",
"SparkDesk-v3.5", "SparkDesk-v3.5",
"SparkDesk-v4.0", "SparkDesk-v4.0",
} }

View File

@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
} }
func parseAPIVersionByModelName(modelName string) string { func parseAPIVersionByModelName(modelName string) string {
parts := strings.Split(modelName, "-") index := strings.IndexAny(modelName, "-")
if len(parts) == 2 { if index != -1 {
return parts[1] return modelName[index+1:]
} }
return "" return ""
} }
@ -288,6 +288,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv2" return "generalv2"
case "v3.1": case "v3.1":
return "generalv3" return "generalv3"
case "v3.1-128K":
return "pro-128k"
case "v3.5": case "v3.5":
return "generalv3.5" return "generalv3.5"
case "v4.0": case "v4.0":
@ -297,7 +299,14 @@ func apiVersion2domain(apiVersion string) string {
} }
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
var authUrl string
domain := apiVersion2domain(apiVersion) domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) switch apiVersion {
case "v3.1-128K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/pro-128k", apiVersion), apiKey, apiSecret)
break
default:
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
}
return domain, authUrl return domain, authUrl
} }

View File

@ -17,6 +17,8 @@ const (
Cohere Cohere
Cloudflare Cloudflare
DeepL DeepL
VertexAI
Proxy
Dummy // this one is only for count, do not add any channel after this Dummy // this one is only for count, do not add any channel after this
) )

View File

@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{
"720x1280": 1, "720x1280": 1,
"1280x720": 1, "1280x720": 1,
}, },
"step-1x-medium": {
"256x256": 1,
"512x512": 1,
"768x768": 1,
"1024x1024": 1,
"1280x800": 1,
"800x1280": 1,
},
} }
var ImageGenerationAmounts = map[string][2]int{ var ImageGenerationAmounts = map[string][2]int{
@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{
"ali-stable-diffusion-v1.5": {1, 4}, // Ali "ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali "wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1}, "cogview-3": {1, 1},
"step-1x-medium": {1, 1},
} }
var ImagePromptLengthLimitations = map[string]int{ var ImagePromptLengthLimitations = map[string]int{
@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{
"ali-stable-diffusion-v1.5": 4000, "ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000, "wanx-v1": 4000,
"cogview-3": 833, "cogview-3": 833,
"step-1x-medium": 4000,
} }
var ImageOriginModelName = map[string]string{ var ImageOriginModelName = map[string]string{

View File

@ -2,6 +2,7 @@ package ratio
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -33,7 +34,11 @@ var ModelRatio = map[string]float64{
"gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens "gpt-4o": 2.5, // $0.005 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0301": 0.75,
@ -95,12 +100,11 @@ var ModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB, "bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB, "tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing // https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-1.0-pro": 1,
"gemini-1.0-pro-vision-001": 1, "gemini-1.5-flash": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1, "gemini-1.5-pro": 1,
"aqa": 1,
// https://open.bigmodel.cn/pricing // https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB, "glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB, "glm-4v": 0.1 * RMB,
@ -124,6 +128,7 @@ var ModelRatio = map[string]float64{
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
@ -155,20 +160,32 @@ var ModelRatio = map[string]float64{
"mistral-large-latest": 8.0 / 1000 * USD, "mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD, "mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed // https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
"llama3-70b-8192": 0.59 / 1000 * USD, "gemma-7b-it": 0.07 / 1000000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD, "mixtral-8x7b-32768": 0.24 / 1000000 * USD,
"llama3-8b-8192": 0.05 / 1000 * USD, "llama3-8b-8192": 0.05 / 1000000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD, "llama3-70b-8192": 0.59 / 1000000 * USD,
"llama2-70b-4096": 0.64 / 1000 * USD, "gemma2-9b-it": 0.20 / 1000000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD, "llama-3.1-405b-reasoning": 0.89 / 1000000 * USD,
"llama-3.1-70b-versatile": 0.59 / 1000000 * USD,
"llama-3.1-8b-instant": 0.05 / 1000000 * USD,
"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元 // https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB, "yi-vl-plus": 6.0 / 1000 * RMB,
// stepfun todo // https://platform.stepfun.com/docs/pricing/details
"step-1v-32k": 0.024 * RMB, "step-1-8k": 0.005 / 1000 * RMB,
"step-1-32k": 0.024 * RMB, "step-1-32k": 0.015 / 1000 * RMB,
"step-1-200k": 0.15 * RMB, "step-1-128k": 0.040 / 1000 * RMB,
"step-1-256k": 0.095 / 1000 * RMB,
"step-1-flash": 0.001 / 1000 * RMB,
"step-2-16k": 0.038 / 1000 * RMB,
"step-1v-8k": 0.005 / 1000 * RMB,
"step-1v-32k": 0.015 / 1000 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
// https://cohere.com/pricing // https://cohere.com/pricing
"command": 0.5, "command": 0.5,
"command-nightly": 0.5, "command-nightly": 0.5,
@ -185,10 +202,16 @@ var ModelRatio = map[string]float64{
"deepl-ja": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD,
} }
var CompletionRatio = map[string]float64{} var CompletionRatio = map[string]float64{
// aws llama3
"llama3-8b-8192(33)": 0.0006 / 0.0003,
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var DefaultModelRatio map[string]float64 var (
var DefaultCompletionRatio map[string]float64 DefaultModelRatio map[string]float64
DefaultCompletionRatio map[string]float64
)
func init() { func init() {
DefaultModelRatio = make(map[string]float64) DefaultModelRatio = make(map[string]float64)
@ -234,22 +257,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &ModelRatio) return json.Unmarshal([]byte(jsonStr), &ModelRatio)
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet") name = strings.TrimSuffix(name, "-internet")
} }
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet") name = strings.TrimSuffix(name, "-internet")
} }
ratio, ok := ModelRatio[name] model := fmt.Sprintf("%s(%d)", name, channelType)
if !ok { if ratio, ok := ModelRatio[model]; ok {
ratio, ok = DefaultModelRatio[name] return ratio
}
if ratio, ok := DefaultModelRatio[model]; ok {
return ratio
}
if ratio, ok := ModelRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[name]; ok {
return ratio
} }
if !ok {
logger.SysError("model ratio not found: " + name) logger.SysError("model ratio not found: " + name)
return 30 return 30
}
return ratio
} }
func CompletionRatio2JSONString() string { func CompletionRatio2JSONString() string {
@ -265,7 +294,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &CompletionRatio) return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
} }
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := CompletionRatio[model]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[model]; ok {
return ratio
}
if ratio, ok := CompletionRatio[name]; ok { if ratio, ok := CompletionRatio[name]; ok {
return ratio return ratio
} }
@ -284,6 +323,9 @@ func GetCompletionRatio(name string) float64 {
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") { if strings.HasPrefix(name, "gpt-4") {
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
return 4
}
if strings.HasPrefix(name, "gpt-4-turbo") || if strings.HasPrefix(name, "gpt-4-turbo") ||
strings.HasPrefix(name, "gpt-4o") || strings.HasPrefix(name, "gpt-4o") ||
strings.HasSuffix(name, "preview") { strings.HasSuffix(name, "preview") {
@ -291,6 +333,9 @@ func GetCompletionRatio(name string) float64 {
} }
return 2 return 2
} }
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-3") { if strings.HasPrefix(name, "claude-3") {
return 5 return 5
} }

View File

@ -42,5 +42,9 @@ const (
DeepL DeepL
TogetherAI TogetherAI
Doubao Doubao
Novita
VertextAI
Proxy
SiliconFlow
Dummy Dummy
) )

View File

@ -35,6 +35,10 @@ func ToAPIType(channelType int) int {
apiType = apitype.Cloudflare apiType = apitype.Cloudflare
case DeepL: case DeepL:
apiType = apitype.DeepL apiType = apitype.DeepL
case VertextAI:
apiType = apitype.VertexAI
case Proxy:
apiType = apitype.Proxy
} }
return apiType return apiType

View File

@ -42,6 +42,10 @@ var ChannelBaseURLs = []string{
"https://api-free.deepl.com", // 38 "https://api-free.deepl.com", // 38
"https://api.together.xyz", // 39 "https://api.together.xyz", // 39
"https://ark.cn-beijing.volces.com", // 40 "https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
"", // 42
"", // 43
"https://api.siliconflow.cn", // 44
} }
func init() { func init() {

View File

@ -7,6 +7,10 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
@ -21,9 +25,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
) )
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
} }
modelRatio := billingratio.GetModelRatio(audioModel) modelRatio := billingratio.GetModelRatio(audioModel, channelType)
groupRatio := billingratio.GetGroupRatio(group) groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
var quota int64 var quota int64

View File

@ -4,6 +4,10 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
@ -16,9 +20,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
"strings"
) )
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
@ -95,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
return return
} }
var quota int64 var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model) completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))

View File

@ -6,6 +6,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
@ -17,8 +20,6 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
@ -166,7 +167,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
} }
modelRatio := billingratio.GetModelRatio(imageModel) modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

41
relay/controller/proxy.go Normal file
View File

@ -0,0 +1,41 @@
// Package controller is a package for handling the relay controller
package controller
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// RelayProxyHelper is a helper function to proxy the request to the upstream service
func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
resp, err := adaptor.DoRequest(c, meta, c.Request.Body)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
return nil
}

View File

@ -4,9 +4,13 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/billing"
@ -14,8 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
@ -30,12 +32,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream meta.IsStream = textRequest.Stream
// map model name // map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
// pre-consume quota // pre-consume quota
@ -54,31 +55,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta) adaptor.Init(meta)
// get request body // get request body
var requestBody io.Reader requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
} }
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
}
// do request // do request
resp, err := adaptor.DoRequest(c, meta, requestBody) resp, err := adaptor.DoRequest(c, meta, requestBody)
@ -102,3 +82,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil return nil
} }
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}

View File

@ -18,12 +18,15 @@ type Meta struct {
UserId int UserId int
Group string Group string
ModelMapping map[string]string ModelMapping map[string]string
// BaseURL is the proxy url set in the channel config
BaseURL string BaseURL string
APIKey string APIKey string
APIType int APIType int
Config model.ChannelConfig Config model.ChannelConfig
IsStream bool IsStream bool
// OriginModelName is the model name from the raw user request
OriginModelName string OriginModelName string
// ActualModelName is the model name after mapping
ActualModelName string ActualModelName string
RequestURLPath string RequestURLPath string
PromptTokens int // only for DoResponse PromptTokens int // only for DoResponse

View File

@ -2,6 +2,14 @@ package model
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
JsonSchema *JSONSchema `json:"json_schema,omitempty"`
}
type JSONSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema map[string]interface{} `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
} }
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
@ -13,6 +21,7 @@ type GeneralOpenAIRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
@ -28,6 +37,7 @@ type GeneralOpenAIRequest struct {
Dimensions int `json:"dimensions,omitempty"` Dimensions int `json:"dimensions,omitempty"`
Instruction string `json:"instruction,omitempty"` Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
} }
func (r GeneralOpenAIRequest) ParseInput() []string { func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@ -11,4 +11,6 @@ const (
AudioSpeech AudioSpeech
AudioTranscription AudioTranscription
AudioTranslation AudioTranslation
// Proxy is a special relay mode for proxying requests to custom upstream
Proxy
) )

View File

@ -24,6 +24,8 @@ func GetByPath(path string) int {
relayMode = AudioTranscription relayMode = AudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") { } else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = AudioTranslation relayMode = AudioTranslation
} else if strings.HasPrefix(path, "/v1/oneapi/proxy") {
relayMode = Proxy
} }
return relayMode return relayMode
} }

View File

@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)

View File

@ -19,6 +19,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{ {
relayV1Router.Any("/oneapi/proxy/:channelid/*target", controller.Relay)
relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/edits", controller.Relay)

View File

@ -47,7 +47,7 @@ const PersonalSetting = () => {
const [countdown, setCountdown] = useState(30); const [countdown, setCountdown] = useState(30);
const [affLink, setAffLink] = useState(''); const [affLink, setAffLink] = useState('');
const [systemToken, setSystemToken] = useState(''); const [systemToken, setSystemToken] = useState('');
// const [models, setModels] = useState([]); const [models, setModels] = useState([]);
const [openTransfer, setOpenTransfer] = useState(false); const [openTransfer, setOpenTransfer] = useState(false);
const [transferAmount, setTransferAmount] = useState(0); const [transferAmount, setTransferAmount] = useState(0);
@ -72,7 +72,7 @@ const PersonalSetting = () => {
console.log(userState); console.log(userState);
} }
); );
// loadModels().then(); loadModels().then();
getAffLink().then(); getAffLink().then();
setTransferAmount(getQuotaPerUnit()); setTransferAmount(getQuotaPerUnit());
}, []); }, []);
@ -127,16 +127,16 @@ const PersonalSetting = () => {
} }
}; };
// const loadModels = async () => { const loadModels = async () => {
// let res = await API.get(`/api/user/models`); let res = await API.get(`/api/user/available_models`);
// const { success, message, data } = res.data; const { success, message, data } = res.data;
// if (success) { if (success) {
// setModels(data); setModels(data);
// console.log(data); console.log(data);
// } else { } else {
// showError(message); showError(message);
// } }
// }; };
const handleAffLinkClick = async (e) => { const handleAffLinkClick = async (e) => {
e.target.select(); e.target.select();
@ -344,7 +344,7 @@ const PersonalSetting = () => {
} }
> >
<Typography.Title heading={6}>调用信息</Typography.Title> <Typography.Title heading={6}>调用信息</Typography.Title>
{/* <Typography.Title heading={6}></Typography.Title> <p>可用模型可点击复制</p>
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Space wrap> <Space wrap>
{models.map((model) => ( {models.map((model) => (
@ -355,7 +355,7 @@ const PersonalSetting = () => {
</Tag> </Tag>
))} ))}
</Space> </Space>
</div> */} </div>
</Card> </Card>
{/* <Card {/* <Card
footer={ footer={

View File

@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken';
const COPY_OPTIONS = [ const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' }, { key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' } { key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
]; ];
const OPEN_LINK_OPTIONS = [ const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' } { key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
]; ];
function renderTimestamp(timestamp) { function renderTimestamp(timestamp) {
@ -60,7 +62,12 @@ const TokensTable = () => {
onOpenLink('next-mj'); onOpenLink('next-mj');
} }
}, },
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' } { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
]; ];
const columns = [ const columns = [
@ -177,6 +184,11 @@ const TokensTable = () => {
node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
onOpenLink('opencat', record.key); onOpenLink('opencat', record.key);
} }
},
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
} }
] ]
} }
@ -382,6 +394,9 @@ const TokensTable = () => {
case 'next-mj': case 'next-mj':
url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
break; break;
case 'lobechat':
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default: default:
if (!chatLink) { if (!chatLink) {
showError('管理员未设置聊天链接'); showError('管理员未设置聊天链接');

View File

@ -1,10 +1,13 @@
export const CHANNEL_OPTIONS = [ export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' }, { key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
{ key: 33, text: 'AWS', value: 33, color: 'black' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
{ key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, { key: 28, text: 'Mistral AI', value: 28, color: 'orange' },
{ key: 41, text: 'Novita', value: 41, color: 'purple' },
{ key: 40, text: '字节跳动豆包', value: 40, color: 'blue' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
@ -17,6 +20,16 @@ export const CHANNEL_OPTIONS = [
{ key: 29, text: 'Groq', value: 29, color: 'orange' }, { key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' }, { key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' }, { key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 32, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 34, text: 'Coze', value: 34, color: 'blue' },
{ key: 35, text: 'Cohere', value: 35, color: 'blue' },
{ key: 36, text: 'DeepSeek', value: 36, color: 'black' },
{ key: 37, text: 'Cloudflare', value: 37, color: 'orange' },
{ key: 38, text: 'DeepL', value: 38, color: 'black' },
{ key: 39, text: 'together.ai', value: 39, color: 'blue' },
{ key: 42, text: 'VertexAI', value: 42, color: 'blue' },
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' }, { key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' }, { key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@ -78,7 +78,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break; break;
case 18: case 18:
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
break; break;
case 19: case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 5.4 KiB

After

Width:  |  Height:  |  Size: 4.3 KiB

View File

@ -0,0 +1,7 @@
<svg t="1723135116886" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"
p-id="10969" width="200" height="200">
<path d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z"
p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path>
<path d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z"
p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -22,7 +22,12 @@ const config = {
turnstile_site_key: '', turnstile_site_key: '',
version: '', version: '',
wechat_login: false, wechat_login: false,
wechat_qrcode: '' wechat_qrcode: '',
oidc: false,
oidc_client_id: '',
oidc_authorization_endpoint: '',
oidc_token_endpoint: '',
oidc_userinfo_endpoint: '',
} }
}; };

View File

@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = {
}, },
33: { 33: {
key: 33, key: 33,
text: 'AWS Claude', text: 'AWS',
value: 33, value: 33,
color: 'primary' color: 'primary'
}, },
@ -161,6 +161,30 @@ export const CHANNEL_OPTIONS = {
value: 39, value: 39,
color: 'primary' color: 'primary'
}, },
42: {
key: 42,
text: 'VertexAI',
value: 42,
color: 'primary'
},
43: {
key: 43,
text: 'Proxy',
value: 43,
color: 'primary'
},
44: {
key: 44,
text: 'SiliconFlow',
value: 44,
color: 'primary'
},
41: {
key: 41,
text: 'Novita',
value: 41,
color: 'purple'
},
8: { 8: {
key: 8, key: 8,
text: '自定义渠道', text: '自定义渠道',

View File

@ -70,6 +70,28 @@ const useLogin = () => {
} }
}; };
const oidcLogin = async (code, state) => {
try {
const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/panel');
} else {
dispatch({ type: LOGIN, payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/panel');
}
}
return { success, message };
} catch (err) {
// 请求失败,设置错误信息
return { success: false, message: '' };
}
}
const wechatLogin = async (code) => { const wechatLogin = async (code) => {
try { try {
const res = await API.get(`/api/oauth/wechat?code=${code}`); const res = await API.get(`/api/oauth/wechat?code=${code}`);
@ -94,7 +116,7 @@ const useLogin = () => {
navigate('/'); navigate('/');
}; };
return { login, logout, githubLogin, wechatLogin, larkLogin }; return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin };
}; };
export default useLogin; export default useLogin;

View File

@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login'))
const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register'))); const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register')));
const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth'))); const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth')));
const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth'))); const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth')));
const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth')));
const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword'))); const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword')));
const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword'))); const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword')));
const Home = Loadable(lazy(() => import('views/Home'))); const Home = Loadable(lazy(() => import('views/Home')));
@ -53,6 +54,10 @@ const OtherRoutes = {
path: '/oauth/lark', path: '/oauth/lark',
element: <LarkOAuth /> element: <LarkOAuth />
}, },
{
path: 'oauth/oidc',
element: <OidcOAuth />
},
{ {
path: '/404', path: '/404',
element: <NotFoundView /> element: <NotFoundView />

View File

@ -98,6 +98,21 @@ export async function onLarkOAuthClicked(lark_client_id) {
window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
} }
export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
const state = await getOAuthState();
if (!state) return;
const redirect_uri = `${window.location.origin}/oauth/oidc`;
const response_type = "code";
const scope = "openid profile email";
const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
if (openInNewTab) {
window.open(url);
} else
{
window.location.href = url;
}
}
export function isAdmin() { export function isAdmin() {
let user = localStorage.getItem('user'); let user = localStorage.getItem('user');
if (!user) return false; if (!user) return false;

Some files were not shown because too many files have changed in this diff Show More