Compare commits
320 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
fdd7bf41c0 | ||
|
29389ed44f | ||
|
88acc5a614 | ||
|
a21681096a | ||
|
32f90a79a8 | ||
|
99c8c77504 | ||
|
649ecbf29c | ||
|
3a27c90910 | ||
|
cba82404ae | ||
|
c9ac670ba1 | ||
|
15f815c23c | ||
|
89b63ca96f | ||
|
8cc54489b9 | ||
|
58bf60805e | ||
|
6714cf96d6 | ||
|
f9774698e9 | ||
|
2af6f6a166 | ||
|
04bb3ef392 | ||
|
b4bfa418a8 | ||
|
e7e99e558a | ||
|
402fcf7f79 | ||
|
36039e329e | ||
|
c936198ac8 | ||
|
296ab013b8 | ||
|
5f03c856b4 | ||
|
39383e5532 | ||
|
2a892c1937 | ||
|
adba54acd3 | ||
|
6209ff9ea9 | ||
|
1c44d7e1cd | ||
|
a3eefb7af0 | ||
|
b65bee46fb | ||
|
422a4e8ee5 | ||
|
cf9b5f0b92 | ||
|
65acb94f45 | ||
|
6ad169975f | ||
|
f636c50c84 | ||
|
720fe2dfeb | ||
|
e090e76c86 | ||
|
6a941748f8 | ||
|
46a0773580 | ||
|
ffdb0b0c81 | ||
|
efd30a40b3 | ||
|
d7a78f3397 | ||
|
273be55797 | ||
|
ec6ad24810 | ||
|
c4fe57c165 | ||
|
274fcf3d76 | ||
|
0fc07ea558 | ||
|
1ce1e529ee | ||
|
d936817de9 | ||
|
fecaece71b | ||
|
c135d74f13 | ||
|
d0369b114f | ||
|
b21b3b5b46 | ||
|
ae1cd29f94 | ||
|
f25aaf7752 | ||
|
b70a07e814 | ||
|
34cb147a74 | ||
|
8cc1ee6360 | ||
|
5a58426859 | ||
|
254b9777c0 | ||
|
114c44c6e7 | ||
|
a3c7e15aed | ||
|
3777517f64 | ||
|
9fc5f427dc | ||
|
864a467886 | ||
|
ed78b5340b | ||
|
fee69e7c20 | ||
|
9d23a44dbf | ||
|
6e4cfb20d5 | ||
|
ff196b75a7 | ||
|
279caf82dc | ||
|
b1520b308b | ||
|
ed717211aa | ||
|
6ccf3f3cfc | ||
|
f74577141c | ||
|
6aafb7a99e | ||
|
c1971870fa | ||
|
f83894c83f | ||
|
e9981fff36 | ||
|
98669d5d48 | ||
|
9321427c6e | ||
|
ceea4c6d4a | ||
|
b53e00a9b3 | ||
|
332c8db0b3 | ||
|
3be28da57b | ||
|
fa74ba0eaa | ||
|
a9211d66f6 | ||
|
07b2fd58d6 | ||
|
0acee9a065 | ||
|
f965469e8a | ||
|
03ea60532a | ||
|
2457d00afb | ||
|
91b80ae879 | ||
|
2720e1a358 | ||
|
71f4403fd5 | ||
|
1f76c80553 | ||
|
7e027d2bd0 | ||
|
30f373b623 | ||
|
1c2654320e | ||
|
6cffb116b7 | ||
|
a84c7b38b7 | ||
|
1bd14af47b | ||
|
6170b91d1c | ||
|
04b49aa0ec | ||
|
ef88497f25 | ||
|
007906216d | ||
|
e64e7707a0 | ||
|
ea210b6ed7 | ||
|
9026ec7510 | ||
|
c317872097 | ||
|
da0842272c | ||
|
0a650b85b4 | ||
|
24f026d18e | ||
|
cb33e8aad5 | ||
|
779b747e9e | ||
|
3d149fedf4 | ||
|
83517f687c | ||
|
e30ebda0fe | ||
|
d87c55f542 | ||
|
e5b3e37c46 | ||
|
8de489cf06 | ||
|
d14e4aa01b | ||
|
541182102e | ||
|
b2679cca65 | ||
|
8572fac7a2 | ||
|
a2a00dfbc3 | ||
|
129282f4a9 | ||
|
a873cbd392 | ||
|
35ba1da984 | ||
|
2369025842 | ||
|
f452bd481e | ||
|
ddee58df36 | ||
|
520a62e704 | ||
|
fc9a784950 | ||
|
1a0b039bcf | ||
|
7bf61f9165 | ||
|
a10232f43a | ||
|
af543ab8ec | ||
|
e086da05b1 | ||
|
3af4649b52 | ||
|
52c32c0b4a | ||
|
3fe2863ff7 | ||
|
acf8cb6248 | ||
|
572fc9ffb8 | ||
|
569c04acb0 | ||
|
961b4108e6 | ||
|
0b8ccb94eb | ||
|
f586ae0ad8 | ||
|
24ed170e7b | ||
|
f70506eac1 | ||
|
8f4d78e24d | ||
|
cd2707692f | ||
|
2ab7d25a80 | ||
|
f9d914873f | ||
|
880e12c855 | ||
|
0cb224e62e | ||
|
a44fb5d482 | ||
|
eec41849ec | ||
|
d4347e7a35 | ||
|
b50b43eb65 | ||
|
348adc2b02 | ||
|
dcf24b98dc | ||
|
af679e04f4 | ||
|
93cbca6a9f | ||
|
840ef80d94 | ||
|
9a2662af0d | ||
|
77f9e75654 | ||
|
5b41f57423 | ||
|
0bb7db0b44 | ||
|
4d61b9937b | ||
|
68605800af | ||
|
c49778c254 | ||
|
f02c7138ea | ||
|
ca3228855a | ||
|
f8cc63f00b | ||
|
0a37aa4cbd | ||
|
054b00b725 | ||
|
76569bb0b6 | ||
|
1994256bac | ||
|
1f80b0a39f | ||
|
f73f2e51df | ||
|
6f036bd0c9 | ||
|
fb90747c23 | ||
|
ed70881a58 | ||
|
8b9fa3d6e4 | ||
|
8b9813d63b | ||
|
dc7aaf2de5 | ||
|
065da8ef8c | ||
|
e3cfb1fa52 | ||
|
f89ae5ad58 | ||
|
06a3fc5421 | ||
|
a9c464ec5a | ||
|
3f3c13c98c | ||
|
2ba28c72cb | ||
|
5e81e19bc8 | ||
|
96d7a99312 | ||
|
24be9de098 | ||
|
5b349efff9 | ||
|
f76c46d648 | ||
|
cdfdeea3b4 | ||
|
56ddbb842a | ||
|
99f81a267c | ||
|
c243cd5535 | ||
|
e96b173abe | ||
|
4ae311e964 | ||
|
b14cb748d8 | ||
|
ade19ba4a2 | ||
|
4d86d021c4 | ||
|
7a44adb5a7 | ||
|
9821bc7281 | ||
|
08831881f1 | ||
|
0eb2272bb7 | ||
|
704ec1a827 | ||
|
1d7470d6ad | ||
|
1185303346 | ||
|
c212fcf8d7 | ||
|
c285e000cc | ||
|
d25ed4c009 | ||
|
7400885fbb | ||
|
11af81eb39 | ||
|
205aba694f | ||
|
8dac3afebc | ||
|
a07791bf93 | ||
|
4bb662c0e4 | ||
|
4998d58319 | ||
|
190203cf8f | ||
|
6325c8e0b4 | ||
|
b204f6d82b | ||
|
752639560f | ||
|
996f4d99dd | ||
|
ebfee3b46c | ||
|
3e2e805d61 | ||
|
3edf7247c4 | ||
|
0926b6206b | ||
|
7cd57f3125 | ||
|
66efabd5ae | ||
|
8ede66a896 | ||
|
b169173860 | ||
|
f33555ae78 | ||
|
c28ec10795 | ||
|
e3767cbb07 | ||
|
be9eb59fbb | ||
|
89e111ac69 | ||
|
2dcef85285 | ||
|
79d0cd378a | ||
|
e99150bdb9 | ||
|
a72e5fcc9e | ||
|
0710f8cd66 | ||
|
49cad7d4a5 | ||
|
a90161cf00 | ||
|
a45fc7d736 | ||
|
45940dcb12 | ||
|
969042b001 | ||
|
7e7369dbc4 | ||
|
e54e647170 | ||
|
358920c858 | ||
|
1ea598c773 | ||
|
796be42487 | ||
|
5b50eb94e5 | ||
|
71c61365eb | ||
|
b09f979b80 | ||
|
12440874b0 | ||
|
6ebc99460e | ||
|
27ad8bfb98 | ||
|
8388aa537f | ||
|
2346bf70af | ||
|
f05b403ca5 | ||
|
b33616df44 | ||
|
cf16f44970 | ||
|
bf2e26a48f | ||
|
4fb22ad4ce | ||
|
95cfb8e8c9 | ||
|
c6ace985c2 | ||
|
10a926b8f3 | ||
|
2df877a352 | ||
|
9d8967f7d3 | ||
|
b35f3523d3 | ||
|
82e916b5ff | ||
|
de18d6fe16 | ||
|
1d0b7fb5ae | ||
|
f9490bb72e | ||
|
76467285e8 | ||
|
df1fd9aa81 | ||
|
614c2e0442 | ||
|
eac6a0b9aa | ||
|
b747cdbc6f | ||
|
6b27d6659a | ||
|
dc5b781191 | ||
|
c880b4a9a3 | ||
|
565ea58e68 | ||
|
f141a37a9e | ||
|
5b78886ad3 | ||
|
87c7c4f0e6 | ||
|
4c4a873890 | ||
|
0664bdfda1 | ||
|
32387d9c20 | ||
|
bd888f2eb7 | ||
|
cece77e533 | ||
|
2a5468e23c | ||
|
d0e415893b | ||
|
6cf5ce9a7a | ||
|
f598b9df87 | ||
|
532c50d212 | ||
|
2acc2f5017 | ||
|
604ac56305 | ||
|
9383b638a6 | ||
|
28d512a675 | ||
|
de9a58ca0b | ||
|
1aa374ccfb | ||
|
d548a01c59 | ||
|
2cd1a78203 | ||
|
b9d3cb0c45 | ||
|
ea407f0054 | ||
|
26e2e646cb | ||
|
4f214c48c6 | ||
|
2d760d4a01 | ||
|
e2ed0399f0 | ||
|
eed9f5fdf0 |
3
.env.example
Normal file
3
.env.example
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
PORT=3000
|
||||||
|
DEBUG=false
|
||||||
|
HTTPS_PROXY=http://localhost:7890
|
47
.github/workflows/ci.yml
vendored
Normal file
47
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
# This setup assumes that you run the unit tests with code coverage in the same
|
||||||
|
# workflow that will also print the coverage report as comment to the pull request.
|
||||||
|
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
|
||||||
|
# when new code is pushed to the branch of the pull request. In addition, you also
|
||||||
|
# need to trigger this workflow when new code is pushed to the main branch because
|
||||||
|
# we need to upload the code coverage results as artifact for the main branch as
|
||||||
|
# well since it will be the baseline code coverage.
|
||||||
|
#
|
||||||
|
# We do not want to trigger the workflow for pushes to *any* branch because this
|
||||||
|
# would trigger our jobs twice on pull requests (once from "push" event and once
|
||||||
|
# from "pull_request->synchronize")
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, reopened, synchronize]
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- 'main'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
unit_tests:
|
||||||
|
name: "Unit tests"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: ^1.22
|
||||||
|
|
||||||
|
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
|
||||||
|
# coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
|
||||||
|
# in the next step as well as the next job.
|
||||||
|
- name: Test
|
||||||
|
run: go test -cover -coverprofile=coverage.txt ./...
|
||||||
|
- uses: codecov/codecov-action@v4
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|
||||||
|
commit_lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- uses: wagoid/commitlint-github-action@v6
|
54
.github/workflows/docker-image-amd64.yml
vendored
54
.github/workflows/docker-image-amd64.yml
vendored
@ -1,54 +0,0 @@
|
|||||||
name: Publish Docker image (amd64)
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- '*'
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
name:
|
|
||||||
description: 'reason'
|
|
||||||
required: false
|
|
||||||
jobs:
|
|
||||||
push_to_registries:
|
|
||||||
name: Push Docker image to multiple registries
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
packages: write
|
|
||||||
contents: read
|
|
||||||
steps:
|
|
||||||
- name: Check out the repo
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Save version info
|
|
||||||
run: |
|
|
||||||
git describe --tags > VERSION
|
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v2
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Log in to the Container registry
|
|
||||||
uses: docker/login-action@v2
|
|
||||||
with:
|
|
||||||
registry: ghcr.io
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v4
|
|
||||||
with:
|
|
||||||
images: |
|
|
||||||
justsong/one-api
|
|
||||||
ghcr.io/${{ github.repository }}
|
|
||||||
|
|
||||||
- name: Build and push Docker images
|
|
||||||
uses: docker/build-push-action@v3
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
@ -1,9 +1,9 @@
|
|||||||
name: Publish Docker image (amd64, English)
|
name: Publish Docker image (English)
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*.*.*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
name:
|
name:
|
||||||
@ -20,6 +20,13 @@ jobs:
|
|||||||
- name: Check out the repo
|
- name: Check out the repo
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Save version info
|
- name: Save version info
|
||||||
run: |
|
run: |
|
||||||
git describe --tags > VERSION
|
git describe --tags > VERSION
|
||||||
@ -27,6 +34,13 @@ jobs:
|
|||||||
- name: Translate
|
- name: Translate
|
||||||
run: |
|
run: |
|
||||||
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
|
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v2
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v2
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
@ -44,6 +58,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v3
|
uses: docker/build-push-action@v3
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
platforms: linux/amd64,linux/arm64
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
@ -1,10 +1,9 @@
|
|||||||
name: Publish Docker image (arm64)
|
name: Publish Docker image
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*.*.*'
|
||||||
- '!*-alpha*'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
name:
|
name:
|
||||||
@ -21,6 +20,13 @@ jobs:
|
|||||||
- name: Check out the repo
|
- name: Check out the repo
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Save version info
|
- name: Save version info
|
||||||
run: |
|
run: |
|
||||||
git describe --tags > VERSION
|
git describe --tags > VERSION
|
12
.github/workflows/linux-release.yml
vendored
12
.github/workflows/linux-release.yml
vendored
@ -5,7 +5,7 @@ permissions:
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*.*.*'
|
||||||
- '!*-alpha*'
|
- '!*-alpha*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
@ -20,10 +20,16 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 16
|
node-version: 16
|
||||||
- name: Build Frontend (theme default)
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
run: |
|
run: |
|
||||||
@ -38,7 +44,7 @@ jobs:
|
|||||||
- name: Build Backend (amd64)
|
- name: Build Backend (amd64)
|
||||||
run: |
|
run: |
|
||||||
go mod download
|
go mod download
|
||||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
|
go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
|
||||||
|
|
||||||
- name: Build Backend (arm64)
|
- name: Build Backend (arm64)
|
||||||
run: |
|
run: |
|
||||||
|
12
.github/workflows/macos-release.yml
vendored
12
.github/workflows/macos-release.yml
vendored
@ -5,7 +5,7 @@ permissions:
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*.*.*'
|
||||||
- '!*-alpha*'
|
- '!*-alpha*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
@ -20,10 +20,16 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 16
|
node-version: 16
|
||||||
- name: Build Frontend (theme default)
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
run: |
|
run: |
|
||||||
@ -38,7 +44,7 @@ jobs:
|
|||||||
- name: Build Backend
|
- name: Build Backend
|
||||||
run: |
|
run: |
|
||||||
go mod download
|
go mod download
|
||||||
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
|
go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: softprops/action-gh-release@v1
|
uses: softprops/action-gh-release@v1
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
|
12
.github/workflows/windows-release.yml
vendored
12
.github/workflows/windows-release.yml
vendored
@ -5,7 +5,7 @@ permissions:
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*.*.*'
|
||||||
- '!*-alpha*'
|
- '!*-alpha*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
@ -23,10 +23,16 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 16
|
node-version: 16
|
||||||
- name: Build Frontend (theme default)
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
run: |
|
run: |
|
||||||
@ -41,7 +47,7 @@ jobs:
|
|||||||
- name: Build Backend
|
- name: Build Backend
|
||||||
run: |
|
run: |
|
||||||
go mod download
|
go mod download
|
||||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
|
go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: softprops/action-gh-release@v1
|
uses: softprops/action-gh-release@v1
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -7,3 +7,6 @@ build
|
|||||||
*.db-journal
|
*.db-journal
|
||||||
logs
|
logs
|
||||||
data
|
data
|
||||||
|
/web/node_modules
|
||||||
|
cmd.md
|
||||||
|
.env
|
12
Dockerfile
12
Dockerfile
@ -1,4 +1,4 @@
|
|||||||
FROM node:16 as builder
|
FROM --platform=$BUILDPLATFORM node:16 AS builder
|
||||||
|
|
||||||
WORKDIR /web
|
WORKDIR /web
|
||||||
COPY ./VERSION .
|
COPY ./VERSION .
|
||||||
@ -12,7 +12,13 @@ WORKDIR /web/berry
|
|||||||
RUN npm install
|
RUN npm install
|
||||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||||
|
|
||||||
FROM golang AS builder2
|
WORKDIR /web/air
|
||||||
|
RUN npm install
|
||||||
|
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||||
|
|
||||||
|
FROM golang:alpine AS builder2
|
||||||
|
|
||||||
|
RUN apk add --no-cache g++
|
||||||
|
|
||||||
ENV GO111MODULE=on \
|
ENV GO111MODULE=on \
|
||||||
CGO_ENABLED=1 \
|
CGO_ENABLED=1 \
|
||||||
@ -23,7 +29,7 @@ ADD go.mod go.sum ./
|
|||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
COPY --from=builder /web/build ./web/build
|
COPY --from=builder /web/build ./web/build
|
||||||
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
|
RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
|
||||||
|
|
||||||
FROM alpine
|
FROM alpine
|
||||||
|
|
||||||
|
47
README.en.md
47
README.en.md
@ -134,12 +134,12 @@ The initial account username is `root` and password is `123456`.
|
|||||||
git clone https://github.com/songquanpeng/one-api.git
|
git clone https://github.com/songquanpeng/one-api.git
|
||||||
|
|
||||||
# Build the frontend
|
# Build the frontend
|
||||||
cd one-api/web
|
cd one-api/web/default
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
|
||||||
# Build the backend
|
# Build the backend
|
||||||
cd ..
|
cd ../..
|
||||||
go mod download
|
go mod download
|
||||||
go build -ldflags "-s -w" -o one-api
|
go build -ldflags "-s -w" -o one-api
|
||||||
```
|
```
|
||||||
@ -241,18 +241,45 @@ If the channel ID is not provided, load balancing will be used to distribute the
|
|||||||
+ Example: `SESSION_SECRET=random_string`
|
+ Example: `SESSION_SECRET=random_string`
|
||||||
3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0.
|
3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0.
|
||||||
+ Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
|
+ Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
|
||||||
4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
|
4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL.
|
||||||
|
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
|
||||||
|
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
|
||||||
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
|
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
|
||||||
5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
|
6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'.
|
||||||
|
7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
|
||||||
+ Example: `SYNC_FREQUENCY=60`
|
+ Example: `SYNC_FREQUENCY=60`
|
||||||
6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
|
8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
|
||||||
+ Example: `NODE_TYPE=slave`
|
+ Example: `NODE_TYPE=slave`
|
||||||
7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
|
9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
|
||||||
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
|
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
|
||||||
8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
|
10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
|
||||||
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
|
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
|
||||||
9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
|
11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
|
||||||
+ Example: `POLLING_INTERVAL=5`
|
+ Example: `POLLING_INTERVAL=5`
|
||||||
|
12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'.
|
||||||
|
+Example: ` BATCH_UPDATE_ENABLED=true`
|
||||||
|
+If you encounter an issue with too many database connections, you can try enabling this option.
|
||||||
|
13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'.
|
||||||
|
+Example: ` BATCH_UPDATE_INTERVAL=5`
|
||||||
|
14. Request frequency limit:
|
||||||
|
+ `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180.
|
||||||
|
+ `GLOBAL_WEL_RATE_LIMIT`: Global web speed limit, the maximum number of requests within three minutes per IP, default to 60.
|
||||||
|
15. Encoder cache settings:
|
||||||
|
+`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the encoding of some common word elements online, such as' gpt-3.5 turbo '. In some unstable network environments or offline situations, it may cause startup problems. This directory can be configured to cache data and can be migrated to an offline environment.
|
||||||
|
+`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as' TIKTOKEN-CACHE-DIR ', but its priority is not as high as it.
|
||||||
|
16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout time set.
|
||||||
|
17. `RELAY_PROXY`: After setting up, use this proxy to request APIs.
|
||||||
|
18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds.
|
||||||
|
19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this agent to request content uploaded by users, such as images.
|
||||||
|
20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'.
|
||||||
|
21. `GEMINI_SAFETY_SETTING`: Gemini's security settings are set to 'BLOCK-NONE' by default.
|
||||||
|
22. `GEMINI_VERSION`: The Gemini version used by the One API, which defaults to 'v1'.
|
||||||
|
23. `THE`: The system's theme setting, default to 'default', specific optional values refer to [here] (./web/README. md).
|
||||||
|
24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'.
|
||||||
|
25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'.
|
||||||
|
26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'.
|
||||||
|
27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time.
|
||||||
|
28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time.
|
||||||
|
|
||||||
### Command Line Parameters
|
### Command Line Parameters
|
||||||
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.
|
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.
|
||||||
@ -285,7 +312,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
|
|||||||
+ Double-check that your interface address and API Key are correct.
|
+ Double-check that your interface address and API Key are correct.
|
||||||
|
|
||||||
## Related Projects
|
## Related Projects
|
||||||
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
|
* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
|
||||||
|
* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller.
|
||||||
|
* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization.
|
||||||
|
|
||||||
## Note
|
## Note
|
||||||
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
|
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
|
||||||
|
17
README.ja.md
17
README.ja.md
@ -135,12 +135,12 @@ sudo service nginx restart
|
|||||||
git clone https://github.com/songquanpeng/one-api.git
|
git clone https://github.com/songquanpeng/one-api.git
|
||||||
|
|
||||||
# フロントエンドのビルド
|
# フロントエンドのビルド
|
||||||
cd one-api/web
|
cd one-api/web/default
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
|
||||||
# バックエンドのビルド
|
# バックエンドのビルド
|
||||||
cd ..
|
cd ../..
|
||||||
go mod download
|
go mod download
|
||||||
go build -ldflags "-s -w" -o one-api
|
go build -ldflags "-s -w" -o one-api
|
||||||
```
|
```
|
||||||
@ -242,17 +242,18 @@ graph LR
|
|||||||
+ 例: `SESSION_SECRET=random_string`
|
+ 例: `SESSION_SECRET=random_string`
|
||||||
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
|
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
|
||||||
+ 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
|
+ 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
|
||||||
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
|
4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。
|
||||||
|
5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
|
||||||
+ 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
|
+ 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
|
||||||
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
|
6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
|
||||||
+ 例: `SYNC_FREQUENCY=60`
|
+ 例: `SYNC_FREQUENCY=60`
|
||||||
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。
|
7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。
|
||||||
+ 例: `NODE_TYPE=slave`
|
+ 例: `NODE_TYPE=slave`
|
||||||
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
|
8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
|
||||||
+ 例: `CHANNEL_UPDATE_FREQUENCY=1440`
|
+ 例: `CHANNEL_UPDATE_FREQUENCY=1440`
|
||||||
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
|
9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
|
||||||
+ 例: `CHANNEL_TEST_FREQUENCY=1440`
|
+ 例: `CHANNEL_TEST_FREQUENCY=1440`
|
||||||
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
|
10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
|
||||||
+ 例: `POLLING_INTERVAL=5`
|
+ 例: `POLLING_INTERVAL=5`
|
||||||
|
|
||||||
### コマンドラインパラメータ
|
### コマンドラインパラメータ
|
||||||
|
84
README.md
84
README.md
@ -65,21 +65,38 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
## 功能
|
## 功能
|
||||||
1. 支持多种大模型:
|
1. 支持多种大模型:
|
||||||
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||||
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude)
|
||||||
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
|
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
|
||||||
|
+ [x] [Mistral 系列模型](https://mistral.ai/)
|
||||||
|
+ [x] [字节跳动豆包大模型](https://console.volcengine.com/ark/region:ark+cn-beijing/model)
|
||||||
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||||
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
||||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||||
+ [x] [360 智脑](https://ai.360.cn)
|
+ [x] [360 智脑](https://ai.360.cn)
|
||||||
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
|
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
|
||||||
|
+ [x] [Moonshot AI](https://platform.moonshot.cn/)
|
||||||
|
+ [x] [百川大模型](https://platform.baichuan-ai.com)
|
||||||
|
+ [x] [MINIMAX](https://api.minimax.chat/)
|
||||||
|
+ [x] [Groq](https://wow.groq.com/)
|
||||||
|
+ [x] [Ollama](https://github.com/ollama/ollama)
|
||||||
|
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
|
||||||
|
+ [x] [阶跃星辰](https://platform.stepfun.com/)
|
||||||
|
+ [x] [Coze](https://www.coze.com/)
|
||||||
|
+ [x] [Cohere](https://cohere.com/)
|
||||||
|
+ [x] [DeepSeek](https://www.deepseek.com/)
|
||||||
|
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
|
||||||
|
+ [x] [DeepL](https://www.deepl.com/)
|
||||||
|
+ [x] [together.ai](https://www.together.ai/)
|
||||||
|
+ [x] [novita.ai](https://www.novita.ai/)
|
||||||
|
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
|
||||||
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
|
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
|
||||||
3. 支持通过**负载均衡**的方式访问多个渠道。
|
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||||
5. 支持**多机部署**,[详见此处](#多机部署)。
|
5. 支持**多机部署**,[详见此处](#多机部署)。
|
||||||
6. 支持**令牌管理**,设置令牌的过期时间和额度。
|
6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。
|
||||||
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
|
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
|
||||||
8. 支持**通道管理**,批量创建通道。
|
8. 支持**渠道管理**,批量创建渠道。
|
||||||
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
|
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
|
||||||
10. 支持渠道**设置模型列表**。
|
10. 支持渠道**设置模型列表**。
|
||||||
11. 支持**查看额度明细**。
|
11. 支持**查看额度明细**。
|
||||||
@ -93,13 +110,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
19. 支持丰富的**自定义**设置,
|
19. 支持丰富的**自定义**设置,
|
||||||
1. 支持自定义系统名称,logo 以及页脚。
|
1. 支持自定义系统名称,logo 以及页脚。
|
||||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||||
20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
|
20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。
|
||||||
21. 支持 Cloudflare Turnstile 用户校验。
|
21. 支持 Cloudflare Turnstile 用户校验。
|
||||||
22. 支持用户管理,支持**多种用户登录注册方式**:
|
22. 支持用户管理,支持**多种用户登录注册方式**:
|
||||||
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
||||||
|
+ 支持使用飞书进行授权登录。
|
||||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||||
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
|
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
|
||||||
|
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
@ -174,12 +193,12 @@ docker-compose ps
|
|||||||
git clone https://github.com/songquanpeng/one-api.git
|
git clone https://github.com/songquanpeng/one-api.git
|
||||||
|
|
||||||
# 构建前端
|
# 构建前端
|
||||||
cd one-api/web
|
cd one-api/web/default
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
|
||||||
# 构建后端
|
# 构建后端
|
||||||
cd ..
|
cd ../..
|
||||||
go mod download
|
go mod download
|
||||||
go build -ldflags "-s -w" -o one-api
|
go build -ldflags "-s -w" -o one-api
|
||||||
````
|
````
|
||||||
@ -233,9 +252,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
#### QChatGPT - QQ机器人
|
#### QChatGPT - QQ机器人
|
||||||
项目主页:https://github.com/RockChinQ/QChatGPT
|
项目主页:https://github.com/RockChinQ/QChatGPT
|
||||||
|
|
||||||
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
|
根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。
|
||||||
|
|
||||||
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
|
运行期间可以通过`!model`命令查看、切换可用模型。
|
||||||
|
|
||||||
### 部署到第三方平台
|
### 部署到第三方平台
|
||||||
<details>
|
<details>
|
||||||
@ -323,6 +342,7 @@ graph LR
|
|||||||
不加的话将会使用负载均衡的方式使用多个渠道。
|
不加的话将会使用负载均衡的方式使用多个渠道。
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
|
> One API 支持从 `.env` 文件中读取环境变量,请参照 `.env.example` 文件,使用时请将其重命名为 `.env`。
|
||||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||||
+ 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
|
+ 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
|
||||||
@ -340,35 +360,45 @@ graph LR
|
|||||||
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
|
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
|
||||||
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。
|
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。
|
||||||
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
|
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
|
||||||
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
|
4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。
|
||||||
|
5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
|
||||||
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
|
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
|
||||||
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||||
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
|
7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
|
||||||
+ 例子:`SYNC_FREQUENCY=60`
|
+ 例子:`SYNC_FREQUENCY=60`
|
||||||
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
|
8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
|
||||||
+ 例子:`NODE_TYPE=slave`
|
+ 例子:`NODE_TYPE=slave`
|
||||||
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
|
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
|
||||||
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
|
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
|
||||||
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
|
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
|
||||||
+例子:`CHANNEL_TEST_FREQUENCY=1440`
|
+例子:`CHANNEL_TEST_FREQUENCY=1440`
|
||||||
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
||||||
+ 例子:`POLLING_INTERVAL=5`
|
+ 例子:`POLLING_INTERVAL=5`
|
||||||
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
+ 例子:`BATCH_UPDATE_ENABLED=true`
|
+ 例子:`BATCH_UPDATE_ENABLED=true`
|
||||||
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
|
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
|
||||||
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
|
13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
|
||||||
+ 例子:`BATCH_UPDATE_INTERVAL=5`
|
+ 例子:`BATCH_UPDATE_INTERVAL=5`
|
||||||
13. 请求频率限制:
|
14. 请求频率限制:
|
||||||
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
||||||
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
||||||
14. 编码器缓存设置:
|
15. 编码器缓存设置:
|
||||||
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
|
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
|
||||||
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
|
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
|
||||||
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
|
16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
|
||||||
16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
|
17. `RELAY_PROXY`:设置后使用该代理来请求 API。
|
||||||
17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
|
18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
|
||||||
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
|
19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
|
||||||
|
20. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
|
||||||
|
21. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
|
||||||
|
22. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。
|
||||||
|
23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
|
||||||
|
24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
|
||||||
|
25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
|
||||||
|
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
|
||||||
|
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
|
||||||
|
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
@ -407,7 +437,7 @@ https://openai.justsong.cn
|
|||||||
+ 检查你的接口地址和 API Key 有没有填对。
|
+ 检查你的接口地址和 API Key 有没有填对。
|
||||||
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
||||||
6. 报错:`当前分组负载已饱和,请稍后再试`
|
6. 报错:`当前分组负载已饱和,请稍后再试`
|
||||||
+ 上游通道 429 了。
|
+ 上游渠道 429 了。
|
||||||
7. 升级之后我的数据会丢失吗?
|
7. 升级之后我的数据会丢失吗?
|
||||||
+ 如果使用 MySQL,不会。
|
+ 如果使用 MySQL,不会。
|
||||||
+ 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
|
+ 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
|
||||||
@ -415,12 +445,14 @@ https://openai.justsong.cn
|
|||||||
+ 一般情况下不需要,系统将在初始化的时候自动调整。
|
+ 一般情况下不需要,系统将在初始化的时候自动调整。
|
||||||
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
|
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
|
||||||
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`?
|
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`?
|
||||||
+ 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。
|
+ 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。
|
||||||
+ 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。
|
+ 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。
|
||||||
|
|
||||||
## 相关项目
|
## 相关项目
|
||||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
||||||
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
|
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
|
||||||
|
* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。
|
||||||
|
* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。
|
||||||
|
|
||||||
## 注意
|
## 注意
|
||||||
|
|
||||||
|
29
common/blacklist/main.go
Normal file
29
common/blacklist/main.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package blacklist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var blackList sync.Map
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
blackList = sync.Map{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func userId2Key(id int) string {
|
||||||
|
return fmt.Sprintf("userid_%d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BanUser(id int) {
|
||||||
|
blackList.Store(userId2Key(id), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnbanUser(id int) {
|
||||||
|
blackList.Delete(userId2Key(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsUserBanned(id int) bool {
|
||||||
|
_, ok := blackList.Load(userId2Key(id))
|
||||||
|
return ok
|
||||||
|
}
|
60
common/client/init.go
Normal file
60
common/client/init.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var HTTPClient *http.Client
|
||||||
|
var ImpatientHTTPClient *http.Client
|
||||||
|
var UserContentRequestHTTPClient *http.Client
|
||||||
|
|
||||||
|
func Init() {
|
||||||
|
if config.UserContentRequestProxy != "" {
|
||||||
|
logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
|
||||||
|
proxyURL, err := url.Parse(config.UserContentRequestProxy)
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
|
||||||
|
}
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
}
|
||||||
|
UserContentRequestHTTPClient = &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: time.Second * time.Duration(config.UserContentRequestTimeout),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
UserContentRequestHTTPClient = &http.Client{}
|
||||||
|
}
|
||||||
|
var transport http.RoundTripper
|
||||||
|
if config.RelayProxy != "" {
|
||||||
|
logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
|
||||||
|
proxyURL, err := url.Parse(config.RelayProxy)
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
|
||||||
|
}
|
||||||
|
transport = &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.RelayTimeout == 0 {
|
||||||
|
HTTPClient = &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
HTTPClient = &http.Client{
|
||||||
|
Timeout: time.Duration(config.RelayTimeout) * time.Second,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ImpatientHTTPClient = &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
162
common/config/config.go
Normal file
162
common/config/config.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SystemName = "One API"
|
||||||
|
var ServerAddress = "http://localhost:3000"
|
||||||
|
var Footer = ""
|
||||||
|
var Logo = ""
|
||||||
|
var TopUpLink = ""
|
||||||
|
var ChatLink = ""
|
||||||
|
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||||
|
var DisplayInCurrencyEnabled = true
|
||||||
|
var DisplayTokenStatEnabled = true
|
||||||
|
|
||||||
|
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||||
|
|
||||||
|
var SessionSecret = uuid.New().String()
|
||||||
|
|
||||||
|
var OptionMap map[string]string
|
||||||
|
var OptionMapRWMutex sync.RWMutex
|
||||||
|
|
||||||
|
var ItemsPerPage = 10
|
||||||
|
var MaxRecentItems = 100
|
||||||
|
|
||||||
|
var PasswordLoginEnabled = true
|
||||||
|
var PasswordRegisterEnabled = true
|
||||||
|
var EmailVerificationEnabled = false
|
||||||
|
var GitHubOAuthEnabled = false
|
||||||
|
var OidcEnabled = false
|
||||||
|
var WeChatAuthEnabled = false
|
||||||
|
var TurnstileCheckEnabled = false
|
||||||
|
var RegisterEnabled = true
|
||||||
|
|
||||||
|
var EmailDomainRestrictionEnabled = false
|
||||||
|
var EmailDomainWhitelist = []string{
|
||||||
|
"gmail.com",
|
||||||
|
"163.com",
|
||||||
|
"126.com",
|
||||||
|
"qq.com",
|
||||||
|
"outlook.com",
|
||||||
|
"hotmail.com",
|
||||||
|
"icloud.com",
|
||||||
|
"yahoo.com",
|
||||||
|
"foxmail.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true"
|
||||||
|
var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true"
|
||||||
|
var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true"
|
||||||
|
|
||||||
|
var LogConsumeEnabled = true
|
||||||
|
|
||||||
|
var SMTPServer = ""
|
||||||
|
var SMTPPort = 587
|
||||||
|
var SMTPAccount = ""
|
||||||
|
var SMTPFrom = ""
|
||||||
|
var SMTPToken = ""
|
||||||
|
|
||||||
|
var GitHubClientId = ""
|
||||||
|
var GitHubClientSecret = ""
|
||||||
|
|
||||||
|
var LarkClientId = ""
|
||||||
|
var LarkClientSecret = ""
|
||||||
|
|
||||||
|
var OidcClientId = ""
|
||||||
|
var OidcClientSecret = ""
|
||||||
|
var OidcWellKnown = ""
|
||||||
|
var OidcAuthorizationEndpoint = ""
|
||||||
|
var OidcTokenEndpoint = ""
|
||||||
|
var OidcUserinfoEndpoint = ""
|
||||||
|
|
||||||
|
var WeChatServerAddress = ""
|
||||||
|
var WeChatServerToken = ""
|
||||||
|
var WeChatAccountQRCodeImageURL = ""
|
||||||
|
|
||||||
|
var MessagePusherAddress = ""
|
||||||
|
var MessagePusherToken = ""
|
||||||
|
|
||||||
|
var TurnstileSiteKey = ""
|
||||||
|
var TurnstileSecretKey = ""
|
||||||
|
|
||||||
|
var QuotaForNewUser int64 = 0
|
||||||
|
var QuotaForInviter int64 = 0
|
||||||
|
var QuotaForInvitee int64 = 0
|
||||||
|
var ChannelDisableThreshold = 5.0
|
||||||
|
var AutomaticDisableChannelEnabled = false
|
||||||
|
var AutomaticEnableChannelEnabled = false
|
||||||
|
var QuotaRemindThreshold int64 = 1000
|
||||||
|
var PreConsumedQuota int64 = 500
|
||||||
|
var ApproximateTokenEnabled = false
|
||||||
|
var RetryTimes = 0
|
||||||
|
|
||||||
|
var RootUserEmail = ""
|
||||||
|
|
||||||
|
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||||
|
|
||||||
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
|
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
|
||||||
|
|
||||||
|
var BatchUpdateEnabled = false
|
||||||
|
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
|
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
|
||||||
|
|
||||||
|
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||||
|
|
||||||
|
var Theme = env.String("THEME", "default")
|
||||||
|
var ValidThemes = map[string]bool{
|
||||||
|
"default": true,
|
||||||
|
"berry": true,
|
||||||
|
"air": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// All duration's unit is seconds
|
||||||
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
|
var (
|
||||||
|
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 240)
|
||||||
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
|
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 120)
|
||||||
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
|
UploadRateLimitNum = 10
|
||||||
|
UploadRateLimitDuration int64 = 60
|
||||||
|
|
||||||
|
DownloadRateLimitNum = 10
|
||||||
|
DownloadRateLimitDuration int64 = 60
|
||||||
|
|
||||||
|
CriticalRateLimitNum = 20
|
||||||
|
CriticalRateLimitDuration int64 = 20 * 60
|
||||||
|
)
|
||||||
|
|
||||||
|
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||||
|
|
||||||
|
var EnableMetric = env.Bool("ENABLE_METRIC", false)
|
||||||
|
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
|
||||||
|
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
||||||
|
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
||||||
|
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
|
||||||
|
|
||||||
|
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
|
||||||
|
|
||||||
|
var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
|
||||||
|
|
||||||
|
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
|
||||||
|
|
||||||
|
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
|
||||||
|
|
||||||
|
var RelayProxy = env.String("RELAY_PROXY", "")
|
||||||
|
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
|
||||||
|
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
|
@ -1,227 +1,6 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import "time"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
var StartTime = time.Now().Unix() // unit: second
|
var StartTime = time.Now().Unix() // unit: second
|
||||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||||
var SystemName = "One API"
|
|
||||||
var ServerAddress = "http://localhost:3000"
|
|
||||||
var Footer = ""
|
|
||||||
var Logo = ""
|
|
||||||
var TopUpLink = ""
|
|
||||||
var ChatLink = ""
|
|
||||||
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
|
||||||
var DisplayInCurrencyEnabled = true
|
|
||||||
var DisplayTokenStatEnabled = true
|
|
||||||
|
|
||||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
|
||||||
|
|
||||||
var SessionSecret = uuid.New().String()
|
|
||||||
|
|
||||||
var OptionMap map[string]string
|
|
||||||
var OptionMapRWMutex sync.RWMutex
|
|
||||||
|
|
||||||
var ItemsPerPage = 10
|
|
||||||
var MaxRecentItems = 100
|
|
||||||
|
|
||||||
var PasswordLoginEnabled = true
|
|
||||||
var PasswordRegisterEnabled = true
|
|
||||||
var EmailVerificationEnabled = false
|
|
||||||
var GitHubOAuthEnabled = false
|
|
||||||
var WeChatAuthEnabled = false
|
|
||||||
var TurnstileCheckEnabled = false
|
|
||||||
var RegisterEnabled = true
|
|
||||||
|
|
||||||
var EmailDomainRestrictionEnabled = false
|
|
||||||
var EmailDomainWhitelist = []string{
|
|
||||||
"gmail.com",
|
|
||||||
"163.com",
|
|
||||||
"126.com",
|
|
||||||
"qq.com",
|
|
||||||
"outlook.com",
|
|
||||||
"hotmail.com",
|
|
||||||
"icloud.com",
|
|
||||||
"yahoo.com",
|
|
||||||
"foxmail.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
|
|
||||||
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
|
||||||
|
|
||||||
var LogConsumeEnabled = true
|
|
||||||
|
|
||||||
var SMTPServer = ""
|
|
||||||
var SMTPPort = 587
|
|
||||||
var SMTPAccount = ""
|
|
||||||
var SMTPFrom = ""
|
|
||||||
var SMTPToken = ""
|
|
||||||
|
|
||||||
var GitHubClientId = ""
|
|
||||||
var GitHubClientSecret = ""
|
|
||||||
|
|
||||||
var WeChatServerAddress = ""
|
|
||||||
var WeChatServerToken = ""
|
|
||||||
var WeChatAccountQRCodeImageURL = ""
|
|
||||||
|
|
||||||
var TurnstileSiteKey = ""
|
|
||||||
var TurnstileSecretKey = ""
|
|
||||||
|
|
||||||
var QuotaForNewUser = 0
|
|
||||||
var QuotaForInviter = 0
|
|
||||||
var QuotaForInvitee = 0
|
|
||||||
var ChannelDisableThreshold = 5.0
|
|
||||||
var AutomaticDisableChannelEnabled = false
|
|
||||||
var AutomaticEnableChannelEnabled = false
|
|
||||||
var QuotaRemindThreshold = 1000
|
|
||||||
var PreConsumedQuota = 500
|
|
||||||
var ApproximateTokenEnabled = false
|
|
||||||
var RetryTimes = 0
|
|
||||||
|
|
||||||
var RootUserEmail = ""
|
|
||||||
|
|
||||||
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|
||||||
|
|
||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
|
||||||
|
|
||||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
|
||||||
|
|
||||||
var BatchUpdateEnabled = false
|
|
||||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
|
||||||
|
|
||||||
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
|
||||||
|
|
||||||
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
|
||||||
|
|
||||||
var Theme = GetOrDefaultString("THEME", "default")
|
|
||||||
var ValidThemes = map[string]bool{
|
|
||||||
"default": true,
|
|
||||||
"berry": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
RequestIdKey = "X-Oneapi-Request-Id"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
RoleGuestUser = 0
|
|
||||||
RoleCommonUser = 1
|
|
||||||
RoleAdminUser = 10
|
|
||||||
RoleRootUser = 100
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
FileUploadPermission = RoleGuestUser
|
|
||||||
FileDownloadPermission = RoleGuestUser
|
|
||||||
ImageUploadPermission = RoleGuestUser
|
|
||||||
ImageDownloadPermission = RoleGuestUser
|
|
||||||
)
|
|
||||||
|
|
||||||
// All duration's unit is seconds
|
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
|
||||||
var (
|
|
||||||
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
|
||||||
|
|
||||||
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
|
||||||
UploadRateLimitDuration int64 = 60
|
|
||||||
|
|
||||||
DownloadRateLimitNum = 10
|
|
||||||
DownloadRateLimitDuration int64 = 60
|
|
||||||
|
|
||||||
CriticalRateLimitNum = 20
|
|
||||||
CriticalRateLimitDuration int64 = 20 * 60
|
|
||||||
)
|
|
||||||
|
|
||||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
|
||||||
|
|
||||||
const (
|
|
||||||
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
|
|
||||||
UserStatusDisabled = 2 // also don't use 0
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
|
|
||||||
TokenStatusDisabled = 2 // also don't use 0
|
|
||||||
TokenStatusExpired = 3
|
|
||||||
TokenStatusExhausted = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
|
|
||||||
RedemptionCodeStatusDisabled = 2 // also don't use 0
|
|
||||||
RedemptionCodeStatusUsed = 3 // also don't use 0
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ChannelStatusUnknown = 0
|
|
||||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
|
||||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
|
||||||
ChannelStatusAutoDisabled = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ChannelTypeUnknown = 0
|
|
||||||
ChannelTypeOpenAI = 1
|
|
||||||
ChannelTypeAPI2D = 2
|
|
||||||
ChannelTypeAzure = 3
|
|
||||||
ChannelTypeCloseAI = 4
|
|
||||||
ChannelTypeOpenAISB = 5
|
|
||||||
ChannelTypeOpenAIMax = 6
|
|
||||||
ChannelTypeOhMyGPT = 7
|
|
||||||
ChannelTypeCustom = 8
|
|
||||||
ChannelTypeAILS = 9
|
|
||||||
ChannelTypeAIProxy = 10
|
|
||||||
ChannelTypePaLM = 11
|
|
||||||
ChannelTypeAPI2GPT = 12
|
|
||||||
ChannelTypeAIGC2D = 13
|
|
||||||
ChannelTypeAnthropic = 14
|
|
||||||
ChannelTypeBaidu = 15
|
|
||||||
ChannelTypeZhipu = 16
|
|
||||||
ChannelTypeAli = 17
|
|
||||||
ChannelTypeXunfei = 18
|
|
||||||
ChannelType360 = 19
|
|
||||||
ChannelTypeOpenRouter = 20
|
|
||||||
ChannelTypeAIProxyLibrary = 21
|
|
||||||
ChannelTypeFastGPT = 22
|
|
||||||
ChannelTypeTencent = 23
|
|
||||||
ChannelTypeGemini = 24
|
|
||||||
)
|
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
|
||||||
"", // 0
|
|
||||||
"https://api.openai.com", // 1
|
|
||||||
"https://oa.api2d.net", // 2
|
|
||||||
"", // 3
|
|
||||||
"https://api.closeai-proxy.xyz", // 4
|
|
||||||
"https://api.openai-sb.com", // 5
|
|
||||||
"https://api.openaimax.com", // 6
|
|
||||||
"https://api.ohmygpt.com", // 7
|
|
||||||
"", // 8
|
|
||||||
"https://api.caipacity.com", // 9
|
|
||||||
"https://api.aiproxy.io", // 10
|
|
||||||
"", // 11
|
|
||||||
"https://api.api2gpt.com", // 12
|
|
||||||
"https://api.aigc2d.com", // 13
|
|
||||||
"https://api.anthropic.com", // 14
|
|
||||||
"https://aip.baidubce.com", // 15
|
|
||||||
"https://open.bigmodel.cn", // 16
|
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
|
||||||
"", // 18
|
|
||||||
"https://ai.360.cn", // 19
|
|
||||||
"https://openrouter.ai/api", // 20
|
|
||||||
"https://api.aiproxy.io", // 21
|
|
||||||
"https://fastgpt.run/api/openapi", // 22
|
|
||||||
"https://hunyuan.cloud.tencent.com", //23
|
|
||||||
"", //24
|
|
||||||
}
|
|
||||||
|
6
common/conv/any.go
Normal file
6
common/conv/any.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package conv
|
||||||
|
|
||||||
|
func AsString(v any) string {
|
||||||
|
str, _ := v.(string)
|
||||||
|
return str
|
||||||
|
}
|
23
common/ctxkey/key.go
Normal file
23
common/ctxkey/key.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package ctxkey
|
||||||
|
|
||||||
|
const (
|
||||||
|
Config = "config"
|
||||||
|
Id = "id"
|
||||||
|
Username = "username"
|
||||||
|
Role = "role"
|
||||||
|
Status = "status"
|
||||||
|
Channel = "channel"
|
||||||
|
ChannelId = "channel_id"
|
||||||
|
SpecificChannelId = "specific_channel_id"
|
||||||
|
RequestModel = "request_model"
|
||||||
|
ConvertedRequest = "converted_request"
|
||||||
|
OriginalModel = "original_model"
|
||||||
|
Group = "group"
|
||||||
|
ModelMapping = "model_mapping"
|
||||||
|
ChannelName = "channel_name"
|
||||||
|
TokenId = "token_id"
|
||||||
|
TokenName = "token_name"
|
||||||
|
BaseURL = "base_url"
|
||||||
|
AvailableModels = "available_models"
|
||||||
|
KeyRequestBody = "key_request_body"
|
||||||
|
)
|
@ -1,7 +1,12 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
|
)
|
||||||
|
|
||||||
var UsingSQLite = false
|
var UsingSQLite = false
|
||||||
var UsingPostgreSQL = false
|
var UsingPostgreSQL = false
|
||||||
|
var UsingMySQL = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db"
|
var SQLitePath = "one-api.db"
|
||||||
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
|
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)
|
||||||
|
@ -1,86 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/smtp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func SendEmail(subject string, receiver string, content string) error {
|
|
||||||
if SMTPFrom == "" { // for compatibility
|
|
||||||
SMTPFrom = SMTPAccount
|
|
||||||
}
|
|
||||||
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
|
|
||||||
|
|
||||||
// Extract domain from SMTPFrom
|
|
||||||
parts := strings.Split(SMTPFrom, "@")
|
|
||||||
var domain string
|
|
||||||
if len(parts) > 1 {
|
|
||||||
domain = parts[1]
|
|
||||||
}
|
|
||||||
// Generate a unique Message-ID
|
|
||||||
buf := make([]byte, 16)
|
|
||||||
_, err := rand.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
|
|
||||||
|
|
||||||
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
|
||||||
"From: %s<%s>\r\n"+
|
|
||||||
"Subject: %s\r\n"+
|
|
||||||
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
|
|
||||||
"Date: %s\r\n"+
|
|
||||||
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
|
|
||||||
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
|
|
||||||
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
|
|
||||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
|
||||||
to := strings.Split(receiver, ";")
|
|
||||||
|
|
||||||
if SMTPPort == 465 {
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
ServerName: SMTPServer,
|
|
||||||
}
|
|
||||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
client, err := smtp.NewClient(conn, SMTPServer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
if err = client.Auth(auth); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = client.Mail(SMTPFrom); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
receiverEmails := strings.Split(receiver, ";")
|
|
||||||
for _, receiver := range receiverEmails {
|
|
||||||
if err = client.Rcpt(receiver); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w, err := client.Data()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = w.Write(mail)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = w.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
@ -15,10 +15,7 @@ type embedFileSystem struct {
|
|||||||
|
|
||||||
func (e embedFileSystem) Exists(prefix string, path string) bool {
|
func (e embedFileSystem) Exists(prefix string, path string) bool {
|
||||||
_, err := e.Open(path)
|
_, err := e.Open(path)
|
||||||
if err != nil {
|
return err == nil
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
|
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
|
||||||
|
42
common/env/helper.go
vendored
Normal file
42
common/env/helper.go
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Bool(env string, defaultValue bool) bool {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func Int(env string, defaultValue int) int {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.Atoi(os.Getenv(env))
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func Float64(env string, defaultValue float64) float64 {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func String(env string, defaultValue string) string {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env)
|
||||||
|
}
|
@ -4,30 +4,49 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||||
|
requestBody, _ := c.Get(ctxkey.KeyRequestBody)
|
||||||
|
if requestBody != nil {
|
||||||
|
return requestBody.([]byte), nil
|
||||||
|
}
|
||||||
requestBody, err := io.ReadAll(c.Request.Body)
|
requestBody, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = c.Request.Body.Close()
|
_ = c.Request.Body.Close()
|
||||||
|
c.Set(ctxkey.KeyRequestBody, requestBody)
|
||||||
|
return requestBody.([]byte), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||||
|
requestBody, err := GetRequestBody(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = json.Unmarshal(requestBody, &v)
|
err = json.Unmarshal(requestBody, &v)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
} else {
|
} else {
|
||||||
// skip for now
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
// TODO: someday non json request have variant model, we will need to implementation this
|
err = c.ShouldBind(&v)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Reset request body
|
// Reset request body
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetEventStreamHeaders(c *gin.Context) {
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
139
common/helper/helper.go
Normal file
139
common/helper/helper.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
|
"html/template"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func OpenBrowser(url string) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
err = exec.Command("xdg-open", url).Start()
|
||||||
|
case "windows":
|
||||||
|
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||||
|
case "darwin":
|
||||||
|
err = exec.Command("open", url).Start()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetIp() (ip string) {
|
||||||
|
ips, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, a := range ips {
|
||||||
|
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||||
|
if ipNet.IP.To4() != nil {
|
||||||
|
ip = ipNet.IP.String()
|
||||||
|
if strings.HasPrefix(ip, "10") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(ip, "172") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(ip, "192.168") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ip = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var sizeKB = 1024
|
||||||
|
var sizeMB = sizeKB * 1024
|
||||||
|
var sizeGB = sizeMB * 1024
|
||||||
|
|
||||||
|
func Bytes2Size(num int64) string {
|
||||||
|
numStr := ""
|
||||||
|
unit := "B"
|
||||||
|
if num/int64(sizeGB) > 1 {
|
||||||
|
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
|
||||||
|
unit = "GB"
|
||||||
|
} else if num/int64(sizeMB) > 1 {
|
||||||
|
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
|
||||||
|
unit = "MB"
|
||||||
|
} else if num/int64(sizeKB) > 1 {
|
||||||
|
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
|
||||||
|
unit = "KB"
|
||||||
|
} else {
|
||||||
|
numStr = fmt.Sprintf("%d", num)
|
||||||
|
}
|
||||||
|
return numStr + " " + unit
|
||||||
|
}
|
||||||
|
|
||||||
|
func Interface2String(inter interface{}) string {
|
||||||
|
switch inter := inter.(type) {
|
||||||
|
case string:
|
||||||
|
return inter
|
||||||
|
case int:
|
||||||
|
return fmt.Sprintf("%d", inter)
|
||||||
|
case float64:
|
||||||
|
return fmt.Sprintf("%f", inter)
|
||||||
|
}
|
||||||
|
return "Not Implemented"
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnescapeHTML(x string) interface{} {
|
||||||
|
return template.HTML(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IntMax(a int, b int) int {
|
||||||
|
if a >= b {
|
||||||
|
return a
|
||||||
|
} else {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenRequestID() string {
|
||||||
|
return GetTimeString() + random.GetRandomNumberString(8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetResponseID(c *gin.Context) string {
|
||||||
|
logID := c.GetString(RequestIdKey)
|
||||||
|
return fmt.Sprintf("chatcmpl-%s", logID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Max(a int, b int) int {
|
||||||
|
if a >= b {
|
||||||
|
return a
|
||||||
|
} else {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssignOrDefault(value string, defaultValue string) string {
|
||||||
|
if len(value) != 0 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessageWithRequestId(message string, id string) string {
|
||||||
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func String2Int(str string) int {
|
||||||
|
num, err := strconv.Atoi(str)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
5
common/helper/key.go
Normal file
5
common/helper/key.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
const (
|
||||||
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
|
)
|
15
common/helper/time.go
Normal file
15
common/helper/time.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetTimestamp() int64 {
|
||||||
|
return time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTimeString() string {
|
||||||
|
now := time.Now()
|
||||||
|
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||||
|
}
|
@ -3,6 +3,7 @@ package image
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"github.com/songquanpeng/one-api/common/client"
|
||||||
"image"
|
"image"
|
||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
@ -19,7 +20,7 @@ import (
|
|||||||
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
||||||
|
|
||||||
func IsImageUrl(url string) (bool, error) {
|
func IsImageUrl(url string) (bool, error) {
|
||||||
resp, err := http.Head(url)
|
resp, err := client.UserContentRequestHTTPClient.Head(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
|||||||
if !isImage {
|
if !isImage {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := http.Get(url)
|
resp, err := client.UserContentRequestHTTPClient.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package image_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"github.com/songquanpeng/one-api/common/client"
|
||||||
"image"
|
"image"
|
||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
@ -12,7 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
img "one-api/common/image"
|
img "github.com/songquanpeng/one-api/common/image"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
_ "golang.org/x/image/webp"
|
_ "golang.org/x/image/webp"
|
||||||
@ -44,6 +45,11 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
client.Init()
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecode(t *testing.T) {
|
func TestDecode(t *testing.T) {
|
||||||
// Bytes read: varies sometimes
|
// Bytes read: varies sometimes
|
||||||
// jpeg: 1063892
|
// jpeg: 1063892
|
||||||
|
@ -3,6 +3,8 @@ package common
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -22,7 +24,7 @@ func printHelp() {
|
|||||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func Init() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *PrintVersion {
|
if *PrintVersion {
|
||||||
@ -37,9 +39,9 @@ func init() {
|
|||||||
|
|
||||||
if os.Getenv("SESSION_SECRET") != "" {
|
if os.Getenv("SESSION_SECRET") != "" {
|
||||||
if os.Getenv("SESSION_SECRET") == "random_string" {
|
if os.Getenv("SESSION_SECRET") == "random_string" {
|
||||||
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
|
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
|
||||||
} else {
|
} else {
|
||||||
SessionSecret = os.Getenv("SESSION_SECRET")
|
config.SessionSecret = os.Getenv("SESSION_SECRET")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if os.Getenv("SQLITE_PATH") != "" {
|
if os.Getenv("SQLITE_PATH") != "" {
|
||||||
@ -57,5 +59,6 @@ func init() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
logger.LogDir = *LogDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
100
common/logger.go
100
common/logger.go
@ -1,100 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
loggerINFO = "INFO"
|
|
||||||
loggerWarn = "WARN"
|
|
||||||
loggerError = "ERR"
|
|
||||||
)
|
|
||||||
|
|
||||||
const maxLogCount = 1000000
|
|
||||||
|
|
||||||
var logCount int
|
|
||||||
var setupLogLock sync.Mutex
|
|
||||||
var setupLogWorking bool
|
|
||||||
|
|
||||||
func SetupLogger() {
|
|
||||||
if *LogDir != "" {
|
|
||||||
ok := setupLogLock.TryLock()
|
|
||||||
if !ok {
|
|
||||||
log.Println("setup log is already working")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
setupLogLock.Unlock()
|
|
||||||
setupLogWorking = false
|
|
||||||
}()
|
|
||||||
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
|
|
||||||
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("failed to open log file")
|
|
||||||
}
|
|
||||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
|
||||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SysLog(s string) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SysError(s string) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogInfo(ctx context.Context, msg string) {
|
|
||||||
logHelper(ctx, loggerINFO, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogWarn(ctx context.Context, msg string) {
|
|
||||||
logHelper(ctx, loggerWarn, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogError(ctx context.Context, msg string) {
|
|
||||||
logHelper(ctx, loggerError, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logHelper(ctx context.Context, level string, msg string) {
|
|
||||||
writer := gin.DefaultErrorWriter
|
|
||||||
if level == loggerINFO {
|
|
||||||
writer = gin.DefaultWriter
|
|
||||||
}
|
|
||||||
id := ctx.Value(RequestIdKey)
|
|
||||||
now := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
|
||||||
logCount++ // we don't need accurate count, so no lock here
|
|
||||||
if logCount > maxLogCount && !setupLogWorking {
|
|
||||||
logCount = 0
|
|
||||||
setupLogWorking = true
|
|
||||||
go func() {
|
|
||||||
SetupLogger()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FatalLog(v ...any) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogQuota(quota int) string {
|
|
||||||
if DisplayInCurrencyEnabled {
|
|
||||||
return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%d 点额度", quota)
|
|
||||||
}
|
|
||||||
}
|
|
3
common/logger/constants.go
Normal file
3
common/logger/constants.go
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
var LogDir string
|
116
common/logger/logger.go
Normal file
116
common/logger/logger.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
loggerDEBUG = "DEBUG"
|
||||||
|
loggerINFO = "INFO"
|
||||||
|
loggerWarn = "WARN"
|
||||||
|
loggerError = "ERR"
|
||||||
|
)
|
||||||
|
|
||||||
|
var setupLogOnce sync.Once
|
||||||
|
|
||||||
|
func SetupLogger() {
|
||||||
|
setupLogOnce.Do(func() {
|
||||||
|
if LogDir != "" {
|
||||||
|
var logPath string
|
||||||
|
if config.OnlyOneLogFile {
|
||||||
|
logPath = filepath.Join(LogDir, "oneapi.log")
|
||||||
|
} else {
|
||||||
|
logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
|
||||||
|
}
|
||||||
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("failed to open log file")
|
||||||
|
}
|
||||||
|
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||||
|
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysLog(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysLogf(format string, a ...any) {
|
||||||
|
SysLog(fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysError(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysErrorf(format string, a ...any) {
|
||||||
|
SysError(fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Debug(ctx context.Context, msg string) {
|
||||||
|
if config.DebugEnabled {
|
||||||
|
logHelper(ctx, loggerDEBUG, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Info(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerINFO, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Warn(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerWarn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Error(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerError, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Debugf(ctx context.Context, format string, a ...any) {
|
||||||
|
Debug(ctx, fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Infof(ctx context.Context, format string, a ...any) {
|
||||||
|
Info(ctx, fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Warnf(ctx context.Context, format string, a ...any) {
|
||||||
|
Warn(ctx, fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Errorf(ctx context.Context, format string, a ...any) {
|
||||||
|
Error(ctx, fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func logHelper(ctx context.Context, level string, msg string) {
|
||||||
|
writer := gin.DefaultErrorWriter
|
||||||
|
if level == loggerINFO {
|
||||||
|
writer = gin.DefaultWriter
|
||||||
|
}
|
||||||
|
id := ctx.Value(helper.RequestIdKey)
|
||||||
|
if id == nil {
|
||||||
|
id = helper.GenRequestID()
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
|
SetupLogger()
|
||||||
|
}
|
||||||
|
|
||||||
|
func FatalLog(v ...any) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
105
common/message/email.go
Normal file
105
common/message/email.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"net"
|
||||||
|
"net/smtp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func shouldAuth() bool {
|
||||||
|
return config.SMTPAccount != "" || config.SMTPToken != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendEmail(subject string, receiver string, content string) error {
|
||||||
|
if receiver == "" {
|
||||||
|
return fmt.Errorf("receiver is empty")
|
||||||
|
}
|
||||||
|
if config.SMTPFrom == "" { // for compatibility
|
||||||
|
config.SMTPFrom = config.SMTPAccount
|
||||||
|
}
|
||||||
|
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
|
||||||
|
|
||||||
|
// Extract domain from SMTPFrom
|
||||||
|
parts := strings.Split(config.SMTPFrom, "@")
|
||||||
|
var domain string
|
||||||
|
if len(parts) > 1 {
|
||||||
|
domain = parts[1]
|
||||||
|
}
|
||||||
|
// Generate a unique Message-ID
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
_, err := rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
|
||||||
|
|
||||||
|
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
||||||
|
"From: %s<%s>\r\n"+
|
||||||
|
"Subject: %s\r\n"+
|
||||||
|
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
|
||||||
|
"Date: %s\r\n"+
|
||||||
|
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
|
||||||
|
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
|
||||||
|
|
||||||
|
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
|
||||||
|
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
|
||||||
|
to := strings.Split(receiver, ";")
|
||||||
|
|
||||||
|
if config.SMTPPort == 465 || !shouldAuth() {
|
||||||
|
// need advanced client
|
||||||
|
var conn net.Conn
|
||||||
|
var err error
|
||||||
|
if config.SMTPPort == 465 {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: config.SMTPServer,
|
||||||
|
}
|
||||||
|
conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
|
||||||
|
} else {
|
||||||
|
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
client, err := smtp.NewClient(conn, config.SMTPServer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
if shouldAuth() {
|
||||||
|
if err = client.Auth(auth); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = client.Mail(config.SMTPFrom); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
receiverEmails := strings.Split(receiver, ";")
|
||||||
|
for _, receiver := range receiverEmails {
|
||||||
|
if err = client.Rcpt(receiver); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w, err := client.Data()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = w.Write(mail)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = w.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
22
common/message/main.go
Normal file
22
common/message/main.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ByAll = "all"
|
||||||
|
ByEmail = "email"
|
||||||
|
ByMessagePusher = "message_pusher"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Notify(by string, title string, description string, content string) error {
|
||||||
|
if by == ByEmail {
|
||||||
|
return SendEmail(title, config.RootUserEmail, content)
|
||||||
|
}
|
||||||
|
if by == ByMessagePusher {
|
||||||
|
return SendMessage(title, description, content)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unknown notify method: %s", by)
|
||||||
|
}
|
53
common/message/message-pusher.go
Normal file
53
common/message/message-pusher.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type request struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Channel string `json:"channel"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type response struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendMessage(title string, description string, content string) error {
|
||||||
|
if config.MessagePusherAddress == "" {
|
||||||
|
return errors.New("message pusher address is not set")
|
||||||
|
}
|
||||||
|
req := request{
|
||||||
|
Title: title,
|
||||||
|
Description: description,
|
||||||
|
Content: content,
|
||||||
|
Token: config.MessagePusherToken,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp, err := http.Post(config.MessagePusherAddress,
|
||||||
|
"application/json", bytes.NewBuffer(data))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var res response
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&res)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !res.Success {
|
||||||
|
return errors.New(res.Message)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,161 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var DalleSizeRatios = map[string]map[string]float64{
|
|
||||||
"dall-e-2": {
|
|
||||||
"256x256": 1,
|
|
||||||
"512x512": 1.125,
|
|
||||||
"1024x1024": 1.25,
|
|
||||||
},
|
|
||||||
"dall-e-3": {
|
|
||||||
"1024x1024": 1,
|
|
||||||
"1024x1792": 2,
|
|
||||||
"1792x1024": 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var DalleGenerationImageAmounts = map[string][2]int{
|
|
||||||
"dall-e-2": {1, 10},
|
|
||||||
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
|
|
||||||
}
|
|
||||||
|
|
||||||
var DalleImagePromptLengthLimitations = map[string]int{
|
|
||||||
"dall-e-2": 1000,
|
|
||||||
"dall-e-3": 4000,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelRatio
|
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
|
||||||
// https://openai.com/pricing
|
|
||||||
// TODO: when a new api is enabled, check the pricing here
|
|
||||||
// 1 === $0.002 / 1K tokens
|
|
||||||
// 1 === ¥0.014 / 1k tokens
|
|
||||||
var ModelRatio = map[string]float64{
|
|
||||||
"gpt-4": 15,
|
|
||||||
"gpt-4-0314": 15,
|
|
||||||
"gpt-4-0613": 15,
|
|
||||||
"gpt-4-32k": 30,
|
|
||||||
"gpt-4-32k-0314": 30,
|
|
||||||
"gpt-4-32k-0613": 30,
|
|
||||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
|
||||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
|
||||||
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
|
||||||
"gpt-3.5-turbo-0301": 0.75,
|
|
||||||
"gpt-3.5-turbo-0613": 0.75,
|
|
||||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
|
||||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
|
||||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
|
||||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
|
||||||
"davinci-002": 1, // $0.002 / 1K tokens
|
|
||||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
|
||||||
"text-ada-001": 0.2,
|
|
||||||
"text-babbage-001": 0.25,
|
|
||||||
"text-curie-001": 1,
|
|
||||||
"text-davinci-002": 10,
|
|
||||||
"text-davinci-003": 10,
|
|
||||||
"text-davinci-edit-001": 10,
|
|
||||||
"code-davinci-edit-001": 10,
|
|
||||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
|
||||||
"tts-1": 7.5, // $0.015 / 1K characters
|
|
||||||
"tts-1-1106": 7.5,
|
|
||||||
"tts-1-hd": 15, // $0.030 / 1K characters
|
|
||||||
"tts-1-hd-1106": 15,
|
|
||||||
"davinci": 10,
|
|
||||||
"curie": 10,
|
|
||||||
"babbage": 10,
|
|
||||||
"ada": 10,
|
|
||||||
"text-embedding-ada-002": 0.05,
|
|
||||||
"text-search-ada-doc-001": 10,
|
|
||||||
"text-moderation-stable": 0.1,
|
|
||||||
"text-moderation-latest": 0.1,
|
|
||||||
"dall-e-2": 8, // $0.016 - $0.020 / image
|
|
||||||
"dall-e-3": 20, // $0.040 - $0.120 / image
|
|
||||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
|
||||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
|
||||||
"claude-2.0": 5.51, // $11.02 / 1M tokens
|
|
||||||
"claude-2.1": 5.51, // $11.02 / 1M tokens
|
|
||||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
|
||||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
|
||||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
|
||||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
|
||||||
"PaLM-2": 1,
|
|
||||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
|
||||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
|
||||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
|
||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
|
||||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
|
||||||
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
|
|
||||||
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
|
|
||||||
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
|
|
||||||
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
|
|
||||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
|
||||||
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
|
||||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
|
||||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
|
||||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
|
||||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
|
||||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
|
||||||
jsonBytes, err := json.Marshal(ModelRatio)
|
|
||||||
if err != nil {
|
|
||||||
SysError("error marshalling model ratio: " + err.Error())
|
|
||||||
}
|
|
||||||
return string(jsonBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateModelRatioByJSONString(jsonStr string) error {
|
|
||||||
ModelRatio = make(map[string]float64)
|
|
||||||
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetModelRatio(name string) float64 {
|
|
||||||
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
|
|
||||||
name = strings.TrimSuffix(name, "-internet")
|
|
||||||
}
|
|
||||||
ratio, ok := ModelRatio[name]
|
|
||||||
if !ok {
|
|
||||||
SysError("model ratio not found: " + name)
|
|
||||||
return 30
|
|
||||||
}
|
|
||||||
return ratio
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetCompletionRatio(name string) float64 {
|
|
||||||
if strings.HasPrefix(name, "gpt-3.5") {
|
|
||||||
if strings.HasSuffix(name, "1106") {
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
|
|
||||||
// TODO: clear this after 2023-12-11
|
|
||||||
now := time.Now()
|
|
||||||
// https://platform.openai.com/docs/models/continuous-model-upgrades
|
|
||||||
// if after 2023-12-11, use 2
|
|
||||||
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1.333333
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "gpt-4") {
|
|
||||||
if strings.HasSuffix(name, "preview") {
|
|
||||||
return 3
|
|
||||||
}
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "claude-instant-1") {
|
|
||||||
return 3.38
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "claude-2") {
|
|
||||||
return 2.965517
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
52
common/network/ip.go
Normal file
52
common/network/ip.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func splitSubnets(subnets string) []string {
|
||||||
|
res := strings.Split(subnets, ",")
|
||||||
|
for i := 0; i < len(res); i++ {
|
||||||
|
res[i] = strings.TrimSpace(res[i])
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidSubnet(subnet string) error {
|
||||||
|
_, _, err := net.ParseCIDR(subnet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse subnet: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
|
||||||
|
_, ipNet, err := net.ParseCIDR(subnet)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ipNet.Contains(net.ParseIP(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsValidSubnets(subnets string) error {
|
||||||
|
for _, subnet := range splitSubnets(subnets) {
|
||||||
|
if err := isValidSubnet(subnet); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
|
||||||
|
for _, subnet := range splitSubnets(subnets) {
|
||||||
|
if isIpInSubnet(ctx, ip, subnet) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
19
common/network/ip_test.go
Normal file
19
common/network/ip_test.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsIpInSubnet(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ip1 := "192.168.0.5"
|
||||||
|
ip2 := "125.216.250.89"
|
||||||
|
subnet := "192.168.0.0/24"
|
||||||
|
Convey("TestIsIpInSubnet", t, func() {
|
||||||
|
So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
|
||||||
|
So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
|
||||||
|
})
|
||||||
|
}
|
61
common/random/main.go
Normal file
61
common/random/main.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package random
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetUUID() string {
|
||||||
|
code := uuid.New().String()
|
||||||
|
code = strings.Replace(code, "-", "", -1)
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
const keyNumbers = "0123456789"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateKey() string {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
key := make([]byte, 48)
|
||||||
|
for i := 0; i < 16; i++ {
|
||||||
|
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||||
|
}
|
||||||
|
uuid_ := GetUUID()
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
c := uuid_[i]
|
||||||
|
if i%2 == 0 && c >= 'a' && c <= 'z' {
|
||||||
|
c = c - 'a' + 'A'
|
||||||
|
}
|
||||||
|
key[i+16] = c
|
||||||
|
}
|
||||||
|
return string(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRandomString(length int) string {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
key := make([]byte, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||||
|
}
|
||||||
|
return string(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRandomNumberString(length int) string {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
key := make([]byte, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
|
||||||
|
}
|
||||||
|
return string(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandRange returns a random number between min and max (max is not included)
|
||||||
|
func RandRange(min, max int) int {
|
||||||
|
return min + rand.Intn(max-min)
|
||||||
|
}
|
@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -14,18 +15,18 @@ var RedisEnabled = true
|
|||||||
func InitRedisClient() (err error) {
|
func InitRedisClient() (err error) {
|
||||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||||
RedisEnabled = false
|
RedisEnabled = false
|
||||||
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
|
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if os.Getenv("SYNC_FREQUENCY") == "" {
|
if os.Getenv("SYNC_FREQUENCY") == "" {
|
||||||
RedisEnabled = false
|
RedisEnabled = false
|
||||||
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
|
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
SysLog("Redis is enabled")
|
logger.SysLog("Redis is enabled")
|
||||||
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||||
}
|
}
|
||||||
RDB = redis.NewClient(opt)
|
RDB = redis.NewClient(opt)
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
|
|||||||
|
|
||||||
_, err = RDB.Ping(ctx).Result()
|
_, err = RDB.Ping(ctx).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog("Redis ping test failed: " + err.Error())
|
logger.FatalLog("Redis ping test failed: " + err.Error())
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
|
|||||||
func ParseRedisOption() *redis.Options {
|
func ParseRedisOption() *redis.Options {
|
||||||
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||||
}
|
}
|
||||||
return opt
|
return opt
|
||||||
}
|
}
|
||||||
|
29
common/render/render.go
Normal file
29
common/render/render.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package render
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StringData(c *gin.Context, str string) {
|
||||||
|
str = strings.TrimPrefix(str, "data: ")
|
||||||
|
str = strings.TrimSuffix(str, "\r")
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + str})
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func ObjectData(c *gin.Context, object interface{}) error {
|
||||||
|
jsonData, err := json.Marshal(object)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshalling object: %w", err)
|
||||||
|
}
|
||||||
|
StringData(c, string(jsonData))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Done(c *gin.Context) {
|
||||||
|
StringData(c, "[DONE]")
|
||||||
|
}
|
212
common/utils.go
212
common/utils.go
@ -2,215 +2,13 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"html/template"
|
|
||||||
"log"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func OpenBrowser(url string) {
|
func LogQuota(quota int64) string {
|
||||||
var err error
|
if config.DisplayInCurrencyEnabled {
|
||||||
|
return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
|
||||||
switch runtime.GOOS {
|
|
||||||
case "linux":
|
|
||||||
err = exec.Command("xdg-open", url).Start()
|
|
||||||
case "windows":
|
|
||||||
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
|
||||||
case "darwin":
|
|
||||||
err = exec.Command("open", url).Start()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetIp() (ip string) {
|
|
||||||
ips, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, a := range ips {
|
|
||||||
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
|
||||||
if ipNet.IP.To4() != nil {
|
|
||||||
ip = ipNet.IP.String()
|
|
||||||
if strings.HasPrefix(ip, "10") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(ip, "172") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(ip, "192.168") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ip = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var sizeKB = 1024
|
|
||||||
var sizeMB = sizeKB * 1024
|
|
||||||
var sizeGB = sizeMB * 1024
|
|
||||||
|
|
||||||
func Bytes2Size(num int64) string {
|
|
||||||
numStr := ""
|
|
||||||
unit := "B"
|
|
||||||
if num/int64(sizeGB) > 1 {
|
|
||||||
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
|
|
||||||
unit = "GB"
|
|
||||||
} else if num/int64(sizeMB) > 1 {
|
|
||||||
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
|
|
||||||
unit = "MB"
|
|
||||||
} else if num/int64(sizeKB) > 1 {
|
|
||||||
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
|
|
||||||
unit = "KB"
|
|
||||||
} else {
|
} else {
|
||||||
numStr = fmt.Sprintf("%d", num)
|
return fmt.Sprintf("%d 点额度", quota)
|
||||||
}
|
|
||||||
return numStr + " " + unit
|
|
||||||
}
|
|
||||||
|
|
||||||
func Seconds2Time(num int) (time string) {
|
|
||||||
if num/31104000 > 0 {
|
|
||||||
time += strconv.Itoa(num/31104000) + " 年 "
|
|
||||||
num %= 31104000
|
|
||||||
}
|
|
||||||
if num/2592000 > 0 {
|
|
||||||
time += strconv.Itoa(num/2592000) + " 个月 "
|
|
||||||
num %= 2592000
|
|
||||||
}
|
|
||||||
if num/86400 > 0 {
|
|
||||||
time += strconv.Itoa(num/86400) + " 天 "
|
|
||||||
num %= 86400
|
|
||||||
}
|
|
||||||
if num/3600 > 0 {
|
|
||||||
time += strconv.Itoa(num/3600) + " 小时 "
|
|
||||||
num %= 3600
|
|
||||||
}
|
|
||||||
if num/60 > 0 {
|
|
||||||
time += strconv.Itoa(num/60) + " 分钟 "
|
|
||||||
num %= 60
|
|
||||||
}
|
|
||||||
time += strconv.Itoa(num) + " 秒"
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func Interface2String(inter interface{}) string {
|
|
||||||
switch inter.(type) {
|
|
||||||
case string:
|
|
||||||
return inter.(string)
|
|
||||||
case int:
|
|
||||||
return fmt.Sprintf("%d", inter.(int))
|
|
||||||
case float64:
|
|
||||||
return fmt.Sprintf("%f", inter.(float64))
|
|
||||||
}
|
|
||||||
return "Not Implemented"
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnescapeHTML(x string) interface{} {
|
|
||||||
return template.HTML(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
func IntMax(a int, b int) int {
|
|
||||||
if a >= b {
|
|
||||||
return a
|
|
||||||
} else {
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUUID() string {
|
|
||||||
code := uuid.New().String()
|
|
||||||
code = strings.Replace(code, "-", "", -1)
|
|
||||||
return code
|
|
||||||
}
|
|
||||||
|
|
||||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
func GenerateKey() string {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
key := make([]byte, 48)
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
|
||||||
}
|
|
||||||
uuid_ := GetUUID()
|
|
||||||
for i := 0; i < 32; i++ {
|
|
||||||
c := uuid_[i]
|
|
||||||
if i%2 == 0 && c >= 'a' && c <= 'z' {
|
|
||||||
c = c - 'a' + 'A'
|
|
||||||
}
|
|
||||||
key[i+16] = c
|
|
||||||
}
|
|
||||||
return string(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetRandomString(length int) string {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
key := make([]byte, length)
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
|
||||||
}
|
|
||||||
return string(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetTimestamp() int64 {
|
|
||||||
return time.Now().Unix()
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetTimeString() string {
|
|
||||||
now := time.Now()
|
|
||||||
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Max(a int, b int) int {
|
|
||||||
if a >= b {
|
|
||||||
return a
|
|
||||||
} else {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefault(env string, defaultValue int) int {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.Atoi(os.Getenv(env))
|
|
||||||
if err != nil {
|
|
||||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultString(env string, defaultValue string) string {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MessageWithRequestId(message string, id string) string {
|
|
||||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func String2Int(str string) int {
|
|
||||||
num, err := strconv.Atoi(str)
|
|
||||||
if err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package controller
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -7,9 +7,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
|
"github.com/songquanpeng/one-api/controller"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("无效的参数")
|
return nil, errors.New("无效的参数")
|
||||||
}
|
}
|
||||||
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
|
values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
|
||||||
jsonData, err := json.Marshal(values)
|
jsonData, err := json.Marshal(values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||||
res2, err := client.Do(req)
|
res2, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !common.GitHubOAuthEnabled {
|
if !config.GitHubOAuthEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||||
@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if common.RegisterEnabled {
|
if config.RegisterEnabled {
|
||||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
if githubUser.Name != "" {
|
if githubUser.Name != "" {
|
||||||
user.DisplayName = githubUser.Name
|
user.DisplayName = githubUser.Name
|
||||||
@ -130,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
user.DisplayName = "GitHub User"
|
user.DisplayName = "GitHub User"
|
||||||
}
|
}
|
||||||
user.Email = githubUser.Email
|
user.Email = githubUser.Email
|
||||||
user.Role = common.RoleCommonUser
|
user.Role = model.RoleCommonUser
|
||||||
user.Status = common.UserStatusEnabled
|
user.Status = model.UserStatusEnabled
|
||||||
|
|
||||||
if err := user.Insert(0); err != nil {
|
if err := user.Insert(0); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -149,18 +152,18 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Status != common.UserStatusEnabled {
|
if user.Status != model.UserStatusEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "用户已被封禁",
|
"message": "用户已被封禁",
|
||||||
"success": false,
|
"success": false,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
setupLogin(&user, c)
|
controller.SetupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GitHubBind(c *gin.Context) {
|
func GitHubBind(c *gin.Context) {
|
||||||
if !common.GitHubOAuthEnabled {
|
if !config.GitHubOAuthEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||||
@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) {
|
|||||||
|
|
||||||
func GenerateOAuthCode(c *gin.Context) {
|
func GenerateOAuthCode(c *gin.Context) {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
state := common.GetRandomString(12)
|
state := random.GetRandomString(12)
|
||||||
session.Set("oauth_state", state)
|
session.Set("oauth_state", state)
|
||||||
err := session.Save()
|
err := session.Save()
|
||||||
if err != nil {
|
if err != nil {
|
200
controller/auth/lark.go
Normal file
200
controller/auth/lark.go
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/controller"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LarkOAuthResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LarkUser struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
OpenID string `json:"open_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
|
||||||
|
if code == "" {
|
||||||
|
return nil, errors.New("无效的参数")
|
||||||
|
}
|
||||||
|
values := map[string]string{
|
||||||
|
"client_id": config.LarkClientId,
|
||||||
|
"client_secret": config.LarkClientSecret,
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(values)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysLog(err.Error())
|
||||||
|
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
var oAuthResponse LarkOAuthResponse
|
||||||
|
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||||
|
res2, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysLog(err.Error())
|
||||||
|
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||||
|
}
|
||||||
|
var larkUser LarkUser
|
||||||
|
err = json.NewDecoder(res2.Body).Decode(&larkUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &larkUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LarkOAuth(c *gin.Context) {
|
||||||
|
session := sessions.Default(c)
|
||||||
|
state := c.Query("state")
|
||||||
|
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "state is empty or not same",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
username := session.Get("username")
|
||||||
|
if username != nil {
|
||||||
|
LarkBind(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := c.Query("code")
|
||||||
|
larkUser, err := getLarkUserInfoByCode(code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user := model.User{
|
||||||
|
LarkId: larkUser.OpenID,
|
||||||
|
}
|
||||||
|
if model.IsLarkIdAlreadyTaken(user.LarkId) {
|
||||||
|
err := user.FillUserByLarkId()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if config.RegisterEnabled {
|
||||||
|
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
|
if larkUser.Name != "" {
|
||||||
|
user.DisplayName = larkUser.Name
|
||||||
|
} else {
|
||||||
|
user.DisplayName = "Lark User"
|
||||||
|
}
|
||||||
|
user.Role = model.RoleCommonUser
|
||||||
|
user.Status = model.UserStatusEnabled
|
||||||
|
|
||||||
|
if err := user.Insert(0); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员关闭了新用户注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Status != model.UserStatusEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "用户已被封禁",
|
||||||
|
"success": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
controller.SetupLogin(&user, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LarkBind(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
larkUser, err := getLarkUserInfoByCode(code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user := model.User{
|
||||||
|
LarkId: larkUser.OpenID,
|
||||||
|
}
|
||||||
|
if model.IsLarkIdAlreadyTaken(user.LarkId) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "该飞书账户已被绑定",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session := sessions.Default(c)
|
||||||
|
id := session.Get("id")
|
||||||
|
// id := c.GetInt("id") // critical bug!
|
||||||
|
user.Id = id.(int)
|
||||||
|
err = user.FillUserById()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user.LarkId = larkUser.OpenID
|
||||||
|
err = user.Update(false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "bind",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
225
controller/auth/oidc.go
Normal file
225
controller/auth/oidc.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/controller"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OidcResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcUser struct {
|
||||||
|
OpenID string `json:"sub"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
PreferredUsername string `json:"preferred_username"`
|
||||||
|
Picture string `json:"picture"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||||
|
if code == "" {
|
||||||
|
return nil, errors.New("无效的参数")
|
||||||
|
}
|
||||||
|
values := map[string]string{
|
||||||
|
"client_id": config.OidcClientId,
|
||||||
|
"client_secret": config.OidcClientSecret,
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(values)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysLog(err.Error())
|
||||||
|
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
var oidcResponse OidcResponse
|
||||||
|
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||||
|
res2, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysLog(err.Error())
|
||||||
|
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||||
|
}
|
||||||
|
var oidcUser OidcUser
|
||||||
|
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &oidcUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OidcAuth(c *gin.Context) {
|
||||||
|
session := sessions.Default(c)
|
||||||
|
state := c.Query("state")
|
||||||
|
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "state is empty or not same",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
username := session.Get("username")
|
||||||
|
if username != nil {
|
||||||
|
OidcBind(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !config.OidcEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := c.Query("code")
|
||||||
|
oidcUser, err := getOidcUserInfoByCode(code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user := model.User{
|
||||||
|
OidcId: oidcUser.OpenID,
|
||||||
|
}
|
||||||
|
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||||
|
err := user.FillUserByOidcId()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if config.RegisterEnabled {
|
||||||
|
user.Email = oidcUser.Email
|
||||||
|
if oidcUser.PreferredUsername != "" {
|
||||||
|
user.Username = oidcUser.PreferredUsername
|
||||||
|
} else {
|
||||||
|
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
|
}
|
||||||
|
if oidcUser.Name != "" {
|
||||||
|
user.DisplayName = oidcUser.Name
|
||||||
|
} else {
|
||||||
|
user.DisplayName = "OIDC User"
|
||||||
|
}
|
||||||
|
err := user.Insert(0)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员关闭了新用户注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Status != model.UserStatusEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "用户已被封禁",
|
||||||
|
"success": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
controller.SetupLogin(&user, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func OidcBind(c *gin.Context) {
|
||||||
|
if !config.OidcEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := c.Query("code")
|
||||||
|
oidcUser, err := getOidcUserInfoByCode(code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user := model.User{
|
||||||
|
OidcId: oidcUser.OpenID,
|
||||||
|
}
|
||||||
|
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "该 OIDC 账户已被绑定",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session := sessions.Default(c)
|
||||||
|
id := session.Get("id")
|
||||||
|
// id := c.GetInt("id") // critical bug!
|
||||||
|
user.Id = id.(int)
|
||||||
|
err = user.FillUserById()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user.OidcId = oidcUser.OpenID
|
||||||
|
err = user.Update(false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "bind",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
@ -1,13 +1,15 @@
|
|||||||
package controller
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/controller"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -22,11 +24,11 @@ func getWeChatIdByCode(code string) (string, error) {
|
|||||||
if code == "" {
|
if code == "" {
|
||||||
return "", errors.New("无效的参数")
|
return "", errors.New("无效的参数")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
|
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", common.WeChatServerToken)
|
req.Header.Set("Authorization", config.WeChatServerToken)
|
||||||
client := http.Client{
|
client := http.Client{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
@ -50,7 +52,7 @@ func getWeChatIdByCode(code string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WeChatAuth(c *gin.Context) {
|
func WeChatAuth(c *gin.Context) {
|
||||||
if !common.WeChatAuthEnabled {
|
if !config.WeChatAuthEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "管理员未开启通过微信登录以及注册",
|
"message": "管理员未开启通过微信登录以及注册",
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -79,11 +81,11 @@ func WeChatAuth(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if common.RegisterEnabled {
|
if config.RegisterEnabled {
|
||||||
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
|
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
user.DisplayName = "WeChat User"
|
user.DisplayName = "WeChat User"
|
||||||
user.Role = common.RoleCommonUser
|
user.Role = model.RoleCommonUser
|
||||||
user.Status = common.UserStatusEnabled
|
user.Status = model.UserStatusEnabled
|
||||||
|
|
||||||
if err := user.Insert(0); err != nil {
|
if err := user.Insert(0); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -101,18 +103,18 @@ func WeChatAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Status != common.UserStatusEnabled {
|
if user.Status != model.UserStatusEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "用户已被封禁",
|
"message": "用户已被封禁",
|
||||||
"success": false,
|
"success": false,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
setupLogin(&user, c)
|
controller.SetupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WeChatBind(c *gin.Context) {
|
func WeChatBind(c *gin.Context) {
|
||||||
if !common.WeChatAuthEnabled {
|
if !config.WeChatAuthEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "管理员未开启通过微信登录以及注册",
|
"message": "管理员未开启通过微信登录以及注册",
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -135,7 +137,7 @@ func WeChatBind(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := c.GetInt("id")
|
id := c.GetInt(ctxkey.Id)
|
||||||
user := model.User{
|
user := model.User{
|
||||||
Id: id,
|
Id: id,
|
||||||
}
|
}
|
@ -2,44 +2,50 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"one-api/model"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetSubscription(c *gin.Context) {
|
func GetSubscription(c *gin.Context) {
|
||||||
var remainQuota int
|
var remainQuota int64
|
||||||
var usedQuota int
|
var usedQuota int64
|
||||||
var err error
|
var err error
|
||||||
var token *model.Token
|
var token *model.Token
|
||||||
var expiredTime int64
|
var expiredTime int64
|
||||||
if common.DisplayTokenStatEnabled {
|
if config.DisplayTokenStatEnabled {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt(ctxkey.TokenId)
|
||||||
token, err = model.GetTokenById(tokenId)
|
token, err = model.GetTokenById(tokenId)
|
||||||
|
if err == nil {
|
||||||
expiredTime = token.ExpiredTime
|
expiredTime = token.ExpiredTime
|
||||||
remainQuota = token.RemainQuota
|
remainQuota = token.RemainQuota
|
||||||
usedQuota = token.UsedQuota
|
usedQuota = token.UsedQuota
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
remainQuota, err = model.GetUserQuota(userId)
|
remainQuota, err = model.GetUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if expiredTime <= 0 {
|
if expiredTime <= 0 {
|
||||||
expiredTime = 0
|
expiredTime = 0
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
Error := relaymodel.Error{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "upstream_error",
|
Type: "upstream_error",
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"error": openAIError,
|
"error": Error,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
quota := remainQuota + usedQuota
|
quota := remainQuota + usedQuota
|
||||||
amount := float64(quota)
|
amount := float64(quota)
|
||||||
if common.DisplayInCurrencyEnabled {
|
if config.DisplayInCurrencyEnabled {
|
||||||
amount /= common.QuotaPerUnit
|
amount /= config.QuotaPerUnit
|
||||||
}
|
}
|
||||||
if token != nil && token.UnlimitedQuota {
|
if token != nil && token.UnlimitedQuota {
|
||||||
amount = 100000000
|
amount = 100000000
|
||||||
@ -57,30 +63,30 @@ func GetSubscription(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsage(c *gin.Context) {
|
func GetUsage(c *gin.Context) {
|
||||||
var quota int
|
var quota int64
|
||||||
var err error
|
var err error
|
||||||
var token *model.Token
|
var token *model.Token
|
||||||
if common.DisplayTokenStatEnabled {
|
if config.DisplayTokenStatEnabled {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt(ctxkey.TokenId)
|
||||||
token, err = model.GetTokenById(tokenId)
|
token, err = model.GetTokenById(tokenId)
|
||||||
quota = token.UsedQuota
|
quota = token.UsedQuota
|
||||||
} else {
|
} else {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
quota, err = model.GetUserUsedQuota(userId)
|
quota, err = model.GetUserUsedQuota(userId)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
Error := relaymodel.Error{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "one_api_error",
|
Type: "one_api_error",
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"error": openAIError,
|
"error": Error,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
amount := float64(quota)
|
amount := float64(quota)
|
||||||
if common.DisplayInCurrencyEnabled {
|
if config.DisplayInCurrencyEnabled {
|
||||||
amount /= common.QuotaPerUnit
|
amount /= config.QuotaPerUnit
|
||||||
}
|
}
|
||||||
usage := OpenAIUsageResponse{
|
usage := OpenAIUsageResponse{
|
||||||
Object: "list",
|
Object: "list",
|
||||||
|
@ -4,10 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/client"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"github.com/songquanpeng/one-api/monitor"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -77,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
|
|||||||
TotalUsed float64 `json:"total_used"`
|
TotalUsed float64 `json:"total_used"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SiliconFlowUsageResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Status bool `json:"status"`
|
||||||
|
Data struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Image string `json:"image"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
IsAdmin bool `json:"isAdmin"`
|
||||||
|
Balance string `json:"balance"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Introduction string `json:"introduction"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
ChargeBalance string `json:"chargeBalance"`
|
||||||
|
TotalBalance string `json:"totalBalance"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
// GetAuthHeader get auth header
|
// GetAuthHeader get auth header
|
||||||
func GetAuthHeader(token string) http.Header {
|
func GetAuthHeader(token string) http.Header {
|
||||||
h := http.Header{}
|
h := http.Header{}
|
||||||
@ -92,7 +116,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
for k := range headers {
|
for k := range headers {
|
||||||
req.Header.Add(k, headers.Get(k))
|
req.Header.Add(k, headers.Get(k))
|
||||||
}
|
}
|
||||||
res, err := httpClient.Do(req)
|
res, err := client.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -199,30 +223,54 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
|||||||
return response.TotalAvailable, nil
|
return response.TotalAvailable, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
|
||||||
|
url := "https://api.siliconflow.cn/v1/user/info"
|
||||||
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
response := SiliconFlowUsageResponse{}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if response.Code != 20000 {
|
||||||
|
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
|
||||||
|
}
|
||||||
|
balance, err := strconv.ParseFloat(response.Data.Balance, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
channel.UpdateBalance(balance)
|
||||||
|
return balance, nil
|
||||||
|
}
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := channeltype.ChannelBaseURLs[channel.Type]
|
||||||
if channel.GetBaseURL() == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = &baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case channeltype.OpenAI:
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
case common.ChannelTypeAzure:
|
case channeltype.Azure:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case channeltype.Custom:
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
case common.ChannelTypeCloseAI:
|
case channeltype.CloseAI:
|
||||||
return updateChannelCloseAIBalance(channel)
|
return updateChannelCloseAIBalance(channel)
|
||||||
case common.ChannelTypeOpenAISB:
|
case channeltype.OpenAISB:
|
||||||
return updateChannelOpenAISBBalance(channel)
|
return updateChannelOpenAISBBalance(channel)
|
||||||
case common.ChannelTypeAIProxy:
|
case channeltype.AIProxy:
|
||||||
return updateChannelAIProxyBalance(channel)
|
return updateChannelAIProxyBalance(channel)
|
||||||
case common.ChannelTypeAPI2GPT:
|
case channeltype.API2GPT:
|
||||||
return updateChannelAPI2GPTBalance(channel)
|
return updateChannelAPI2GPTBalance(channel)
|
||||||
case common.ChannelTypeAIGC2D:
|
case channeltype.AIGC2D:
|
||||||
return updateChannelAIGC2DBalance(channel)
|
return updateChannelAIGC2DBalance(channel)
|
||||||
|
case channeltype.SiliconFlow:
|
||||||
|
return updateChannelSiliconFlowBalance(channel)
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
}
|
}
|
||||||
@ -292,16 +340,16 @@ func UpdateChannelBalance(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateAllChannelsBalance() error {
|
func updateAllChannelsBalance() error {
|
||||||
channels, err := model.GetAllChannels(0, 0, true)
|
channels, err := model.GetAllChannels(0, 0, "all")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != model.ChannelStatusEnabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// TODO: support Azure
|
// TODO: support Azure
|
||||||
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
balance, err := updateChannelBalance(channel)
|
balance, err := updateChannelBalance(channel)
|
||||||
@ -310,24 +358,23 @@ func updateAllChannelsBalance() error {
|
|||||||
} else {
|
} else {
|
||||||
// err is nil & balance <= 0 means quota is used up
|
// err is nil & balance <= 0 means quota is used up
|
||||||
if balance <= 0 {
|
if balance <= 0 {
|
||||||
disableChannel(channel.Id, channel.Name, "余额不足")
|
monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(config.RequestInterval)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAllChannelsBalance(c *gin.Context) {
|
func UpdateAllChannelsBalance(c *gin.Context) {
|
||||||
// TODO: make it async
|
//err := updateAllChannelsBalance()
|
||||||
err := updateAllChannelsBalance()
|
//if err != nil {
|
||||||
if err != nil {
|
// c.JSON(http.StatusOK, gin.H{
|
||||||
c.JSON(http.StatusOK, gin.H{
|
// "success": false,
|
||||||
"success": false,
|
// "message": err.Error(),
|
||||||
"message": err.Error(),
|
// })
|
||||||
})
|
// return
|
||||||
return
|
//}
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@ -338,8 +385,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
func AutomaticallyUpdateChannels(frequency int) {
|
func AutomaticallyUpdateChannels(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
common.SysLog("updating all channels")
|
logger.SysLog("updating all channels")
|
||||||
_ = updateAllChannelsBalance()
|
_ = updateAllChannelsBalance()
|
||||||
common.SysLog("channels update done")
|
logger.SysLog("channels update done")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,96 +7,38 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"net/http/httptest"
|
||||||
"one-api/model"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/message"
|
||||||
|
"github.com/songquanpeng/one-api/middleware"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"github.com/songquanpeng/one-api/monitor"
|
||||||
|
relay "github.com/songquanpeng/one-api/relay"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
|
"github.com/songquanpeng/one-api/relay/controller"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
|
||||||
switch channel.Type {
|
if model == "" {
|
||||||
case common.ChannelTypePaLM:
|
model = "gpt-3.5-turbo"
|
||||||
fallthrough
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
fallthrough
|
|
||||||
case common.ChannelTypeAnthropic:
|
|
||||||
fallthrough
|
|
||||||
case common.ChannelTypeBaidu:
|
|
||||||
fallthrough
|
|
||||||
case common.ChannelTypeZhipu:
|
|
||||||
fallthrough
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
fallthrough
|
|
||||||
case common.ChannelType360:
|
|
||||||
fallthrough
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
request.Model = "gpt-35-turbo"
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
|
||||||
}
|
}
|
||||||
}()
|
testRequest := &relaymodel.GeneralOpenAIRequest{
|
||||||
default:
|
MaxTokens: 2,
|
||||||
request.Model = "gpt-3.5-turbo"
|
Model: model,
|
||||||
}
|
}
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
testMessage := relaymodel.Message{
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
|
||||||
} else {
|
|
||||||
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
|
||||||
requestURL = baseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
req.Header.Set("api-key", channel.Key)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
var response TextResponse
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
|
|
||||||
}
|
|
||||||
if response.Usage.CompletionTokens == 0 {
|
|
||||||
if response.Error.Message == "" {
|
|
||||||
response.Error.Message = "补全 tokens 非预期返回 0"
|
|
||||||
}
|
|
||||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildTestRequest() *ChatRequest {
|
|
||||||
testRequest := &ChatRequest{
|
|
||||||
Model: "", // this will be set later
|
|
||||||
MaxTokens: 1,
|
|
||||||
}
|
|
||||||
testMessage := Message{
|
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: "hi",
|
Content: "hi",
|
||||||
}
|
}
|
||||||
@ -104,6 +46,78 @@ func buildTestRequest() *ChatRequest {
|
|||||||
return testRequest
|
return testRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = &http.Request{
|
||||||
|
Method: "POST",
|
||||||
|
URL: &url.URL{Path: "/v1/chat/completions"},
|
||||||
|
Body: nil,
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Set(ctxkey.Channel, channel.Type)
|
||||||
|
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||||
|
cfg, _ := channel.LoadConfig()
|
||||||
|
c.Set(ctxkey.Config, cfg)
|
||||||
|
middleware.SetupContextForSelectedChannel(c, channel, "")
|
||||||
|
meta := meta.GetByContext(c)
|
||||||
|
apiType := channeltype.ToAPIType(channel.Type)
|
||||||
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
|
if adaptor == nil {
|
||||||
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||||
|
}
|
||||||
|
adaptor.Init(meta)
|
||||||
|
modelName := request.Model
|
||||||
|
modelMap := channel.GetModelMapping()
|
||||||
|
if modelName == "" || !strings.Contains(channel.Models, modelName) {
|
||||||
|
modelNames := strings.Split(channel.Models, ",")
|
||||||
|
if len(modelNames) > 0 {
|
||||||
|
modelName = modelNames[0]
|
||||||
|
}
|
||||||
|
if modelMap != nil && modelMap[modelName] != "" {
|
||||||
|
modelName = modelMap[modelName]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
|
||||||
|
request.Model = modelName
|
||||||
|
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
logger.SysLog(string(jsonData))
|
||||||
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
|
c.Request.Body = io.NopCloser(requestBody)
|
||||||
|
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||||
|
err := controller.RelayErrorHandler(resp)
|
||||||
|
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||||
|
}
|
||||||
|
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||||
|
if respErr != nil {
|
||||||
|
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||||
|
}
|
||||||
|
if usage == nil {
|
||||||
|
return errors.New("usage is nil"), nil
|
||||||
|
}
|
||||||
|
result := w.Result()
|
||||||
|
// print result.Body
|
||||||
|
respBody, err := io.ReadAll(result.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestChannel(c *gin.Context) {
|
func TestChannel(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -121,11 +135,15 @@ func TestChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
testRequest := buildTestRequest()
|
model := c.Query("model")
|
||||||
|
testRequest := buildTestRequest(model)
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, _ = testChannel(channel, *testRequest)
|
err, _ = testChannel(channel, testRequest)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
if err != nil {
|
||||||
|
milliseconds = 0
|
||||||
|
}
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -133,6 +151,7 @@ func TestChannel(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
|
"model": model,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -140,6 +159,7 @@ func TestChannel(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
|
"model": model,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -147,35 +167,9 @@ func TestChannel(c *gin.Context) {
|
|||||||
var testAllChannelsLock sync.Mutex
|
var testAllChannelsLock sync.Mutex
|
||||||
var testAllChannelsRunning bool = false
|
var testAllChannelsRunning bool = false
|
||||||
|
|
||||||
func notifyRootUser(subject string, content string) {
|
func testChannels(notify bool, scope string) error {
|
||||||
if common.RootUserEmail == "" {
|
if config.RootUserEmail == "" {
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
config.RootUserEmail = model.GetRootUserEmail()
|
||||||
}
|
|
||||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// disable & notify
|
|
||||||
func disableChannel(channelId int, channelName string, reason string) {
|
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
|
||||||
notifyRootUser(subject, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
// enable & notify
|
|
||||||
func enableChannel(channelId int, channelName string) {
|
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
|
||||||
notifyRootUser(subject, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAllChannels(notify bool) error {
|
|
||||||
if common.RootUserEmail == "" {
|
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
|
||||||
}
|
}
|
||||||
testAllChannelsLock.Lock()
|
testAllChannelsLock.Lock()
|
||||||
if testAllChannelsRunning {
|
if testAllChannelsRunning {
|
||||||
@ -184,50 +178,58 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
testAllChannelsRunning = true
|
testAllChannelsRunning = true
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
channels, err := model.GetAllChannels(0, 0, true)
|
channels, err := model.GetAllChannels(0, 0, scope)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
testRequest := buildTestRequest()
|
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
|
||||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
|
||||||
if disableThreshold == 0 {
|
if disableThreshold == 0 {
|
||||||
disableThreshold = 10000000 // a impossible value
|
disableThreshold = 10000000 // a impossible value
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, openaiErr := testChannel(channel, *testRequest)
|
testRequest := buildTestRequest("")
|
||||||
|
err, openaiErr := testChannel(channel, testRequest)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
if isChannelEnabled && milliseconds > disableThreshold {
|
if isChannelEnabled && milliseconds > disableThreshold {
|
||||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
if config.AutomaticDisableChannelEnabled {
|
||||||
|
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||||
|
} else {
|
||||||
|
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error())
|
||||||
}
|
}
|
||||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
|
|
||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
|
||||||
}
|
}
|
||||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
|
||||||
enableChannel(channel.Id, channel.Name)
|
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||||
|
}
|
||||||
|
if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
|
||||||
|
monitor.EnableChannel(channel.Id, channel.Name)
|
||||||
}
|
}
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(config.RequestInterval)
|
||||||
}
|
}
|
||||||
testAllChannelsLock.Lock()
|
testAllChannelsLock.Lock()
|
||||||
testAllChannelsRunning = false
|
testAllChannelsRunning = false
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
if notify {
|
if notify {
|
||||||
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllChannels(c *gin.Context) {
|
func TestChannels(c *gin.Context) {
|
||||||
err := testAllChannels(true)
|
scope := c.Query("scope")
|
||||||
|
if scope == "" {
|
||||||
|
scope = "all"
|
||||||
|
}
|
||||||
|
err := testChannels(true, scope)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -245,8 +247,8 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
func AutomaticallyTestChannels(frequency int) {
|
func AutomaticallyTestChannels(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
common.SysLog("testing all channels")
|
logger.SysLog("testing all channels")
|
||||||
_ = testAllChannels(false)
|
_ = testChannels(false, "all")
|
||||||
common.SysLog("channel test finished")
|
logger.SysLog("channel test finished")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,10 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
if p < 0 {
|
if p < 0 {
|
||||||
p = 0
|
p = 0
|
||||||
}
|
}
|
||||||
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
|
channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.CreatedTime = common.GetTimestamp()
|
channel.CreatedTime = helper.GetTimestamp()
|
||||||
keys := strings.Split(channel.Key, "\n")
|
keys := strings.Split(channel.Key, "\n")
|
||||||
channels := make([]model.Channel, 0, len(keys))
|
channels := make([]model.Channel, 0, len(keys))
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
@ -2,13 +2,13 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetGroups(c *gin.Context) {
|
func GetGroups(c *gin.Context) {
|
||||||
groupNames := make([]string, 0)
|
groupNames := make([]string, 0)
|
||||||
for groupName, _ := range common.GroupRatio {
|
for groupName := range billingratio.GroupRatio {
|
||||||
groupNames = append(groupNames, groupName)
|
groupNames = append(groupNames, groupName)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
@ -2,9 +2,10 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
|
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -41,13 +42,13 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
if p < 0 {
|
if p < 0 {
|
||||||
p = 0
|
p = 0
|
||||||
}
|
}
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
logType, _ := strconv.Atoi(c.Query("type"))
|
logType, _ := strconv.Atoi(c.Query("type"))
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -83,7 +84,7 @@ func SearchAllLogs(c *gin.Context) {
|
|||||||
|
|
||||||
func SearchUserLogs(c *gin.Context) {
|
func SearchUserLogs(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
logs, err := model.SearchUserLogs(userId, keyword)
|
logs, err := model.SearchUserLogs(userId, keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -122,7 +123,7 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsSelfStat(c *gin.Context) {
|
func GetLogsSelfStat(c *gin.Context) {
|
||||||
username := c.GetString("username")
|
username := c.GetString(ctxkey.Username)
|
||||||
logType, _ := strconv.Atoi(c.Query("type"))
|
logType, _ := strconv.Atoi(c.Query("type"))
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
|
@ -3,9 +3,11 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/message"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -18,55 +20,62 @@ func GetStatus(c *gin.Context) {
|
|||||||
"data": gin.H{
|
"data": gin.H{
|
||||||
"version": common.Version,
|
"version": common.Version,
|
||||||
"start_time": common.StartTime,
|
"start_time": common.StartTime,
|
||||||
"email_verification": common.EmailVerificationEnabled,
|
"email_verification": config.EmailVerificationEnabled,
|
||||||
"github_oauth": common.GitHubOAuthEnabled,
|
"github_oauth": config.GitHubOAuthEnabled,
|
||||||
"github_client_id": common.GitHubClientId,
|
"github_client_id": config.GitHubClientId,
|
||||||
"system_name": common.SystemName,
|
"lark_client_id": config.LarkClientId,
|
||||||
"logo": common.Logo,
|
"system_name": config.SystemName,
|
||||||
"footer_html": common.Footer,
|
"logo": config.Logo,
|
||||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
"footer_html": config.Footer,
|
||||||
"wechat_login": common.WeChatAuthEnabled,
|
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
|
||||||
"server_address": common.ServerAddress,
|
"wechat_login": config.WeChatAuthEnabled,
|
||||||
"turnstile_check": common.TurnstileCheckEnabled,
|
"server_address": config.ServerAddress,
|
||||||
"turnstile_site_key": common.TurnstileSiteKey,
|
"turnstile_check": config.TurnstileCheckEnabled,
|
||||||
"top_up_link": common.TopUpLink,
|
"turnstile_site_key": config.TurnstileSiteKey,
|
||||||
"chat_link": common.ChatLink,
|
"top_up_link": config.TopUpLink,
|
||||||
"quota_per_unit": common.QuotaPerUnit,
|
"chat_link": config.ChatLink,
|
||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"quota_per_unit": config.QuotaPerUnit,
|
||||||
|
"display_in_currency": config.DisplayInCurrencyEnabled,
|
||||||
|
"oidc": config.OidcEnabled,
|
||||||
|
"oidc_client_id": config.OidcClientId,
|
||||||
|
"oidc_well_known": config.OidcWellKnown,
|
||||||
|
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
|
||||||
|
"oidc_token_endpoint": config.OidcTokenEndpoint,
|
||||||
|
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNotice(c *gin.Context) {
|
func GetNotice(c *gin.Context) {
|
||||||
common.OptionMapRWMutex.RLock()
|
config.OptionMapRWMutex.RLock()
|
||||||
defer common.OptionMapRWMutex.RUnlock()
|
defer config.OptionMapRWMutex.RUnlock()
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["Notice"],
|
"data": config.OptionMap["Notice"],
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAbout(c *gin.Context) {
|
func GetAbout(c *gin.Context) {
|
||||||
common.OptionMapRWMutex.RLock()
|
config.OptionMapRWMutex.RLock()
|
||||||
defer common.OptionMapRWMutex.RUnlock()
|
defer config.OptionMapRWMutex.RUnlock()
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["About"],
|
"data": config.OptionMap["About"],
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHomePageContent(c *gin.Context) {
|
func GetHomePageContent(c *gin.Context) {
|
||||||
common.OptionMapRWMutex.RLock()
|
config.OptionMapRWMutex.RLock()
|
||||||
defer common.OptionMapRWMutex.RUnlock()
|
defer config.OptionMapRWMutex.RUnlock()
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["HomePageContent"],
|
"data": config.OptionMap["HomePageContent"],
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -80,9 +89,9 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if common.EmailDomainRestrictionEnabled {
|
if config.EmailDomainRestrictionEnabled {
|
||||||
allowed := false
|
allowed := false
|
||||||
for _, domain := range common.EmailDomainWhitelist {
|
for _, domain := range config.EmailDomainWhitelist {
|
||||||
if strings.HasSuffix(email, "@"+domain) {
|
if strings.HasSuffix(email, "@"+domain) {
|
||||||
allowed = true
|
allowed = true
|
||||||
break
|
break
|
||||||
@ -105,11 +114,11 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
code := common.GenerateVerificationCode(6)
|
code := common.GenerateVerificationCode(6)
|
||||||
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
|
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
|
||||||
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
|
subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
|
||||||
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
|
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
|
||||||
"<p>您的验证码为: <strong>%s</strong></p>"+
|
"<p>您的验证码为: <strong>%s</strong></p>"+
|
||||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := message.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -142,13 +151,13 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
code := common.GenerateVerificationCode(0)
|
code := common.GenerateVerificationCode(0)
|
||||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
|
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
|
||||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
subject := fmt.Sprintf("%s密码重置", config.SystemName)
|
||||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||||
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := message.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
|
@ -2,8 +2,17 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
relay "github.com/songquanpeng/one-api/relay"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/apitype"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@ -33,8 +42,9 @@ type OpenAIModels struct {
|
|||||||
Parent *string `json:"parent"`
|
Parent *string `json:"parent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var openAIModels []OpenAIModels
|
var models []OpenAIModels
|
||||||
var openAIModelsMap map[string]OpenAIModels
|
var modelsMap map[string]OpenAIModels
|
||||||
|
var channelId2Models map[int][]string
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var permission []OpenAIModelPermission
|
var permission []OpenAIModelPermission
|
||||||
@ -53,574 +63,151 @@ func init() {
|
|||||||
IsBlocking: false,
|
IsBlocking: false,
|
||||||
})
|
})
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
openAIModels = []OpenAIModels{
|
for i := 0; i < apitype.Dummy; i++ {
|
||||||
{
|
if i == apitype.AIProxyLibrary {
|
||||||
Id: "dall-e-2",
|
continue
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "dall-e-2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "dall-e-3",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "dall-e-3",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "whisper-1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "whisper-1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1-1106",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1-1106",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1-hd",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1-hd",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1-hd-1106",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1-hd-1106",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-0301",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-0301",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-16k",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-16k",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-16k-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-16k-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-1106",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1699593571,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-1106",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-instruct",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-instruct",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-0314",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-0314",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-32k",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-32k",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-32k-0314",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-32k-0314",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-32k-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-32k-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-1106-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1699593571,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-1106-preview",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-vision-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1699593571,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-vision-preview",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-embedding-ada-002",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-embedding-ada-002",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-davinci-003",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-davinci-003",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-davinci-002",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-davinci-002",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-curie-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-curie-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-babbage-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-babbage-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-ada-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-ada-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-moderation-latest",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-moderation-latest",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-moderation-stable",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-moderation-stable",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-davinci-edit-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-davinci-edit-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "code-davinci-edit-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "code-davinci-edit-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "davinci-002",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "davinci-002",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "babbage-002",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "babbage-002",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-instant-1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-instant-1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-2",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-2.1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-2.1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-2.0",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-2.0",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ERNIE-Bot",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ERNIE-Bot",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ERNIE-Bot-turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ERNIE-Bot-turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ERNIE-Bot-4",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ERNIE-Bot-4",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "Embedding-V1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "Embedding-V1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "PaLM-2",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "google palm",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "PaLM-2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gemini-pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "google gemini",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gemini-pro",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gemini-pro-vision",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "google gemini",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gemini-pro-vision",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_pro",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_std",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_std",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_lite",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_lite",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-plus",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-plus",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-max",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-max",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-max-longcontext",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-max-longcontext",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-embedding-v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-embedding-v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "SparkDesk",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "xunfei",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "SparkDesk",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "360GPT_S2_V9",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "360GPT_S2_V9",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "embedding-bert-512-v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "embedding-bert-512-v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "embedding_s1_v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "embedding_s1_v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "semantic_similarity_s1_v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "semantic_similarity_s1_v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "hunyuan",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "tencent",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "hunyuan",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
adaptor := relay.GetAdaptor(i)
|
||||||
for _, model := range openAIModels {
|
channelName := adaptor.GetChannelName()
|
||||||
openAIModelsMap[model.Id] = model
|
modelNames := adaptor.GetModelList()
|
||||||
|
for _, modelName := range modelNames {
|
||||||
|
models = append(models, OpenAIModels{
|
||||||
|
Id: modelName,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1626777600,
|
||||||
|
OwnedBy: channelName,
|
||||||
|
Permission: permission,
|
||||||
|
Root: modelName,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, channelType := range openai.CompatibleChannels {
|
||||||
|
if channelType == channeltype.Azure {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
|
||||||
|
for _, modelName := range channelModelList {
|
||||||
|
models = append(models, OpenAIModels{
|
||||||
|
Id: modelName,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1626777600,
|
||||||
|
OwnedBy: channelName,
|
||||||
|
Permission: permission,
|
||||||
|
Root: modelName,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelsMap = make(map[string]OpenAIModels)
|
||||||
|
for _, model := range models {
|
||||||
|
modelsMap[model.Id] = model
|
||||||
|
}
|
||||||
|
channelId2Models = make(map[int][]string)
|
||||||
|
for i := 1; i < channeltype.Dummy; i++ {
|
||||||
|
adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
|
||||||
|
meta := &meta.Meta{
|
||||||
|
ChannelType: i,
|
||||||
|
}
|
||||||
|
adaptor.Init(meta)
|
||||||
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func DashboardListModels(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": channelId2Models,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListAllModels(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": openAIModels,
|
"data": models,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListModels(c *gin.Context) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
var availableModels []string
|
||||||
|
if c.GetString(ctxkey.AvailableModels) != "" {
|
||||||
|
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
|
||||||
|
} else {
|
||||||
|
userId := c.GetInt(ctxkey.Id)
|
||||||
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
|
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
|
||||||
|
}
|
||||||
|
modelSet := make(map[string]bool)
|
||||||
|
for _, availableModel := range availableModels {
|
||||||
|
modelSet[availableModel] = true
|
||||||
|
}
|
||||||
|
availableOpenAIModels := make([]OpenAIModels, 0)
|
||||||
|
for _, model := range models {
|
||||||
|
if _, ok := modelSet[model.Id]; ok {
|
||||||
|
modelSet[model.Id] = false
|
||||||
|
availableOpenAIModels = append(availableOpenAIModels, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for modelName, ok := range modelSet {
|
||||||
|
if ok {
|
||||||
|
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
|
||||||
|
Id: modelName,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1626777600,
|
||||||
|
OwnedBy: "custom",
|
||||||
|
Root: modelName,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"object": "list",
|
||||||
|
"data": availableOpenAIModels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context) {
|
||||||
modelId := c.Param("model")
|
modelId := c.Param("model")
|
||||||
if model, ok := openAIModelsMap[modelId]; ok {
|
if model, ok := modelsMap[modelId]; ok {
|
||||||
c.JSON(200, model)
|
c.JSON(200, model)
|
||||||
} else {
|
} else {
|
||||||
openAIError := OpenAIError{
|
Error := relaymodel.Error{
|
||||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||||
Type: "invalid_request_error",
|
Type: "invalid_request_error",
|
||||||
Param: "model",
|
Param: "model",
|
||||||
Code: "model_not_found",
|
Code: "model_not_found",
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"error": openAIError,
|
"error": Error,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUserAvailableModels(c *gin.Context) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
id := c.GetInt(ctxkey.Id)
|
||||||
|
userGroup, err := model.CacheGetUserGroup(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
models, err := model.CacheGetGroupModels(ctx, userGroup)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": models,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -2,9 +2,10 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -12,17 +13,17 @@ import (
|
|||||||
|
|
||||||
func GetOptions(c *gin.Context) {
|
func GetOptions(c *gin.Context) {
|
||||||
var options []*model.Option
|
var options []*model.Option
|
||||||
common.OptionMapRWMutex.Lock()
|
config.OptionMapRWMutex.Lock()
|
||||||
for k, v := range common.OptionMap {
|
for k, v := range config.OptionMap {
|
||||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
|
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
options = append(options, &model.Option{
|
options = append(options, &model.Option{
|
||||||
Key: k,
|
Key: k,
|
||||||
Value: common.Interface2String(v),
|
Value: helper.Interface2String(v),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
common.OptionMapRWMutex.Unlock()
|
config.OptionMapRWMutex.Unlock()
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@ -43,7 +44,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
switch option.Key {
|
switch option.Key {
|
||||||
case "Theme":
|
case "Theme":
|
||||||
if !common.ValidThemes[option.Value] {
|
if !config.ValidThemes[option.Value] {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无效的主题",
|
"message": "无效的主题",
|
||||||
@ -51,7 +52,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "GitHubOAuthEnabled":
|
case "GitHubOAuthEnabled":
|
||||||
if option.Value == "true" && common.GitHubClientId == "" {
|
if option.Value == "true" && config.GitHubClientId == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
||||||
@ -59,7 +60,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "EmailDomainRestrictionEnabled":
|
case "EmailDomainRestrictionEnabled":
|
||||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
|
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
|
||||||
@ -67,7 +68,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "WeChatAuthEnabled":
|
case "WeChatAuthEnabled":
|
||||||
if option.Value == "true" && common.WeChatServerAddress == "" {
|
if option.Value == "true" && config.WeChatServerAddress == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
|
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
|
||||||
@ -75,7 +76,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "TurnstileCheckEnabled":
|
case "TurnstileCheckEnabled":
|
||||||
if option.Value == "true" && common.TurnstileSiteKey == "" {
|
if option.Value == "true" && config.TurnstileSiteKey == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
|
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
|
||||||
|
@ -2,9 +2,12 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,7 +16,7 @@ func GetAllRedemptions(c *gin.Context) {
|
|||||||
if p < 0 {
|
if p < 0 {
|
||||||
p = 0
|
p = 0
|
||||||
}
|
}
|
||||||
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
|
redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -105,12 +108,12 @@ func AddRedemption(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var keys []string
|
var keys []string
|
||||||
for i := 0; i < redemption.Count; i++ {
|
for i := 0; i < redemption.Count; i++ {
|
||||||
key := common.GetUUID()
|
key := random.GetUUID()
|
||||||
cleanRedemption := model.Redemption{
|
cleanRedemption := model.Redemption{
|
||||||
UserId: c.GetInt("id"),
|
UserId: c.GetInt(ctxkey.Id),
|
||||||
Name: redemption.Name,
|
Name: redemption.Name,
|
||||||
Key: key,
|
Key: key,
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: helper.GetTimestamp(),
|
||||||
Quota: redemption.Quota,
|
Quota: redemption.Quota,
|
||||||
}
|
}
|
||||||
err = cleanRedemption.Insert()
|
err = cleanRedemption.Insert()
|
||||||
|
@ -1,220 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
|
||||||
|
|
||||||
type AIProxyLibraryRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Query string `json:"query"`
|
|
||||||
LibraryId string `json:"libraryId"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryError struct {
|
|
||||||
ErrCode int `json:"errCode"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryDocument struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryResponse struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Answer string `json:"answer"`
|
|
||||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
|
||||||
AIProxyLibraryError
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryStreamResponse struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Finish bool `json:"finish"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
|
||||||
query := ""
|
|
||||||
if len(request.Messages) != 0 {
|
|
||||||
query = request.Messages[len(request.Messages)-1].StringContent()
|
|
||||||
}
|
|
||||||
return &AIProxyLibraryRequest{
|
|
||||||
Model: request.Model,
|
|
||||||
Stream: request.Stream,
|
|
||||||
Query: query,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
|
||||||
if len(documents) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
content := "\n\n参考文档:\n"
|
|
||||||
for i, document := range documents {
|
|
||||||
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
|
||||||
}
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
|
||||||
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: content,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: common.GetUUID(),
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
return &ChatCompletionsStreamResponse{
|
|
||||||
Id: common.GetUUID(),
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = response.Content
|
|
||||||
return &ChatCompletionsStreamResponse{
|
|
||||||
Id: common.GetUUID(),
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: response.Model,
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 5 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:5] != "data:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[5:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
var documents []AIProxyLibraryDocument
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if len(AIProxyLibraryResponse.Documents) != 0 {
|
|
||||||
documents = AIProxyLibraryResponse.Documents
|
|
||||||
}
|
|
||||||
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
response := documentsAIProxyLibrary(documents)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var AIProxyLibraryResponse AIProxyLibraryResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if AIProxyLibraryResponse.ErrCode != 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: AIProxyLibraryResponse.Message,
|
|
||||||
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
|
||||||
Code: AIProxyLibraryResponse.ErrCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
@ -1,322 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
|
||||||
|
|
||||||
type AliMessage struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliInput struct {
|
|
||||||
//Prompt string `json:"prompt"`
|
|
||||||
Messages []AliMessage `json:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliParameters struct {
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
Seed uint64 `json:"seed,omitempty"`
|
|
||||||
EnableSearch bool `json:"enable_search,omitempty"`
|
|
||||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliChatRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input AliInput `json:"input"`
|
|
||||||
Parameters AliParameters `json:"parameters,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliEmbeddingRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input struct {
|
|
||||||
Texts []string `json:"texts"`
|
|
||||||
} `json:"input"`
|
|
||||||
Parameters *struct {
|
|
||||||
TextType string `json:"text_type,omitempty"`
|
|
||||||
} `json:"parameters,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliEmbedding struct {
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
TextIndex int `json:"text_index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliEmbeddingResponse struct {
|
|
||||||
Output struct {
|
|
||||||
Embeddings []AliEmbedding `json:"embeddings"`
|
|
||||||
} `json:"output"`
|
|
||||||
Usage AliUsage `json:"usage"`
|
|
||||||
AliError
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliError struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
RequestId string `json:"request_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliUsage struct {
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliOutput struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliChatResponse struct {
|
|
||||||
Output AliOutput `json:"output"`
|
|
||||||
Usage AliUsage `json:"usage"`
|
|
||||||
AliError
|
|
||||||
}
|
|
||||||
|
|
||||||
const AliEnableSearchModelSuffix = "-internet"
|
|
||||||
|
|
||||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
|
||||||
messages := make([]AliMessage, 0, len(request.Messages))
|
|
||||||
for i := 0; i < len(request.Messages); i++ {
|
|
||||||
message := request.Messages[i]
|
|
||||||
messages = append(messages, AliMessage{
|
|
||||||
Content: message.StringContent(),
|
|
||||||
Role: strings.ToLower(message.Role),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
enableSearch := false
|
|
||||||
aliModel := request.Model
|
|
||||||
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
|
|
||||||
enableSearch = true
|
|
||||||
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
|
|
||||||
}
|
|
||||||
return &AliChatRequest{
|
|
||||||
Model: aliModel,
|
|
||||||
Input: AliInput{
|
|
||||||
Messages: messages,
|
|
||||||
},
|
|
||||||
Parameters: AliParameters{
|
|
||||||
EnableSearch: enableSearch,
|
|
||||||
IncrementalOutput: request.Stream,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
|
||||||
return &AliEmbeddingRequest{
|
|
||||||
Model: "text-embedding-v1",
|
|
||||||
Input: struct {
|
|
||||||
Texts []string `json:"texts"`
|
|
||||||
}{
|
|
||||||
Texts: request.ParseInput(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var aliResponse AliEmbeddingResponse
|
|
||||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if aliResponse.Code != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: aliResponse.Message,
|
|
||||||
Type: aliResponse.Code,
|
|
||||||
Param: aliResponse.RequestId,
|
|
||||||
Code: aliResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
|
||||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
|
||||||
Object: "list",
|
|
||||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
|
||||||
Model: "text-embedding-v1",
|
|
||||||
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range response.Output.Embeddings {
|
|
||||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
|
||||||
Object: `embedding`,
|
|
||||||
Index: item.TextIndex,
|
|
||||||
Embedding: item.Embedding,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &openAIEmbeddingResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Output.Text,
|
|
||||||
},
|
|
||||||
FinishReason: response.Output.FinishReason,
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: response.RequestId,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
Usage: Usage{
|
|
||||||
PromptTokens: response.Usage.InputTokens,
|
|
||||||
CompletionTokens: response.Usage.OutputTokens,
|
|
||||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = aliResponse.Output.Text
|
|
||||||
if aliResponse.Output.FinishReason != "null" {
|
|
||||||
finishReason := aliResponse.Output.FinishReason
|
|
||||||
choice.FinishReason = &finishReason
|
|
||||||
}
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: aliResponse.RequestId,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "qwen",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 5 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:5] != "data:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[5:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
//lastResponseText := ""
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var aliResponse AliChatResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if aliResponse.Usage.OutputTokens != 0 {
|
|
||||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
|
||||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
|
||||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
|
||||||
}
|
|
||||||
response := streamResponseAli2OpenAI(&aliResponse)
|
|
||||||
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
|
||||||
//lastResponseText = aliResponse.Output.Text
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var aliResponse AliChatResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if aliResponse.Code != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: aliResponse.Message,
|
|
||||||
Type: aliResponse.Code,
|
|
||||||
Param: aliResponse.RequestId,
|
|
||||||
Code: aliResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
|
||||||
fullTextResponse.Model = "qwen"
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
@ -1,262 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
||||||
audioModel := "whisper-1"
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
channelType := c.GetInt("channel")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
group := c.GetString("group")
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
|
|
||||||
var ttsRequest TextToSpeechRequest
|
|
||||||
if relayMode == RelayModeAudioSpeech {
|
|
||||||
// Read JSON
|
|
||||||
err := common.UnmarshalBodyReusable(c, &ttsRequest)
|
|
||||||
// Check if JSON is valid
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
audioModel = ttsRequest.Model
|
|
||||||
// Check if text is too long 4096
|
|
||||||
if len(ttsRequest.Input) > 4096 {
|
|
||||||
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
modelRatio := common.GetModelRatio(audioModel)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
var quota int
|
|
||||||
var preConsumedQuota int
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeAudioSpeech:
|
|
||||||
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
|
|
||||||
quota = preConsumedQuota
|
|
||||||
default:
|
|
||||||
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
|
|
||||||
}
|
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if user quota is enough
|
|
||||||
if userQuota-preConsumedQuota < 0 {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if userQuota > 100*preConsumedQuota {
|
|
||||||
// in this case, we do not pre-consume quota
|
|
||||||
// because the user has enough quota
|
|
||||||
preConsumedQuota = 0
|
|
||||||
}
|
|
||||||
if preConsumedQuota > 0 {
|
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// map model name
|
|
||||||
modelMapping := c.GetString("model_mapping")
|
|
||||||
if modelMapping != "" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[audioModel] != "" {
|
|
||||||
audioModel = modelMap[audioModel]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
if c.GetString("base_url") != "" {
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
|
||||||
apiVersion := GetAPIVersion(c)
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
requestBody := &bytes.Buffer{}
|
|
||||||
_, err = io.Copy(requestBody, c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
|
||||||
responseFormat := c.DefaultPostForm("response_format", "json")
|
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
req.Header.Set("api-key", apiKey)
|
|
||||||
req.ContentLength = c.Request.ContentLength
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if relayMode != RelayModeAudioSpeech {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
var openAIErr TextResponse
|
|
||||||
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
|
||||||
if openAIErr.Error.Message != "" {
|
|
||||||
return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var text string
|
|
||||||
switch responseFormat {
|
|
||||||
case "json":
|
|
||||||
text, err = getTextFromJSON(responseBody)
|
|
||||||
case "text":
|
|
||||||
text, err = getTextFromText(responseBody)
|
|
||||||
case "srt":
|
|
||||||
text, err = getTextFromSRT(responseBody)
|
|
||||||
case "verbose_json":
|
|
||||||
text, err = getTextFromVerboseJSON(responseBody)
|
|
||||||
case "vtt":
|
|
||||||
text, err = getTextFromVTT(responseBody)
|
|
||||||
default:
|
|
||||||
return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
quota = countTokenText(text, audioModel)
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if preConsumedQuota > 0 {
|
|
||||||
// we need to roll back the pre-consumed quota
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
go func() {
|
|
||||||
// negative means add quota back for token & user
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}(c.Request.Context())
|
|
||||||
}
|
|
||||||
return relayErrorHandler(resp)
|
|
||||||
}
|
|
||||||
quotaDelta := quota - preConsumedQuota
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
|
||||||
}(c.Request.Context())
|
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromVTT(body []byte) (string, error) {
|
|
||||||
return getTextFromSRT(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromVerboseJSON(body []byte) (string, error) {
|
|
||||||
var whisperResponse WhisperVerboseJSONResponse
|
|
||||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
|
||||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
|
||||||
}
|
|
||||||
return whisperResponse.Text, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromSRT(body []byte) (string, error) {
|
|
||||||
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
|
||||||
var builder strings.Builder
|
|
||||||
var textLine bool
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if textLine {
|
|
||||||
builder.WriteString(line)
|
|
||||||
textLine = false
|
|
||||||
continue
|
|
||||||
} else if strings.Contains(line, "-->") {
|
|
||||||
textLine = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return builder.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromText(body []byte) (string, error) {
|
|
||||||
return strings.TrimSuffix(string(body), "\n"), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromJSON(body []byte) (string, error) {
|
|
||||||
var whisperResponse WhisperJSONResponse
|
|
||||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
|
||||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
|
||||||
}
|
|
||||||
return whisperResponse.Text, nil
|
|
||||||
}
|
|
@ -1,360 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
|
||||||
|
|
||||||
type BaiduTokenResponse struct {
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduChatRequest struct {
|
|
||||||
Messages []BaiduMessage `json:"messages"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
UserId string `json:"user_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduError struct {
|
|
||||||
ErrorCode int `json:"error_code"`
|
|
||||||
ErrorMsg string `json:"error_msg"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduChatResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
IsTruncated bool `json:"is_truncated"`
|
|
||||||
NeedClearHistory bool `json:"need_clear_history"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
BaiduError
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduChatStreamResponse struct {
|
|
||||||
BaiduChatResponse
|
|
||||||
SentenceId int `json:"sentence_id"`
|
|
||||||
IsEnd bool `json:"is_end"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduEmbeddingRequest struct {
|
|
||||||
Input []string `json:"input"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduEmbeddingData struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduEmbeddingResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Data []BaiduEmbeddingData `json:"data"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
BaiduError
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduAccessToken struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
ErrorDescription string `json:"error_description,omitempty"`
|
|
||||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
|
||||||
ExpiresAt time.Time `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var baiduTokenStore sync.Map
|
|
||||||
|
|
||||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
|
||||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, BaiduMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, BaiduMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages = append(messages, BaiduMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &BaiduChatRequest{
|
|
||||||
Messages: messages,
|
|
||||||
Stream: request.Stream,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Result,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: response.Id,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: response.Created,
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
Usage: response.Usage,
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = baiduResponse.Result
|
|
||||||
if baiduResponse.IsEnd {
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
}
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: baiduResponse.Id,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: baiduResponse.Created,
|
|
||||||
Model: "ernie-bot",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
|
||||||
return &BaiduEmbeddingRequest{
|
|
||||||
Input: request.ParseInput(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
|
||||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
|
||||||
Object: "list",
|
|
||||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
|
||||||
Model: "baidu-embedding",
|
|
||||||
Usage: response.Usage,
|
|
||||||
}
|
|
||||||
for _, item := range response.Data {
|
|
||||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
|
||||||
Object: item.Object,
|
|
||||||
Index: item.Index,
|
|
||||||
Embedding: item.Embedding,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &openAIEmbeddingResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[6:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var baiduResponse BaiduChatStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if baiduResponse.Usage.TotalTokens != 0 {
|
|
||||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
|
||||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
|
||||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
|
||||||
}
|
|
||||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var baiduResponse BaiduChatResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if baiduResponse.ErrorMsg != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: baiduResponse.ErrorMsg,
|
|
||||||
Type: "baidu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: baiduResponse.ErrorCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
|
||||||
fullTextResponse.Model = "ernie-bot"
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var baiduResponse BaiduEmbeddingResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if baiduResponse.ErrorMsg != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: baiduResponse.ErrorMsg,
|
|
||||||
Type: "baidu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: baiduResponse.ErrorCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBaiduAccessToken(apiKey string) (string, error) {
|
|
||||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
|
||||||
var accessToken BaiduAccessToken
|
|
||||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
|
||||||
// soon this will expire
|
|
||||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
|
||||||
go func() {
|
|
||||||
_, _ = getBaiduAccessTokenHelper(apiKey)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return accessToken.AccessToken, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if accessToken == nil {
|
|
||||||
return "", errors.New("getBaiduAccessToken return a nil token")
|
|
||||||
}
|
|
||||||
return (*accessToken).AccessToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
|
||||||
parts := strings.Split(apiKey, "|")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil, errors.New("invalid baidu apikey")
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
|
||||||
parts[0], parts[1]), nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
req.Header.Add("Accept", "application/json")
|
|
||||||
res, err := impatientHTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
var accessToken BaiduAccessToken
|
|
||||||
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accessToken.Error != "" {
|
|
||||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
|
||||||
}
|
|
||||||
if accessToken.AccessToken == "" {
|
|
||||||
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
|
||||||
}
|
|
||||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
|
||||||
baiduTokenStore.Store(apiKey, accessToken)
|
|
||||||
return &accessToken, nil
|
|
||||||
}
|
|
@ -1,223 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ClaudeMetadata struct {
|
|
||||||
UserId string `json:"user_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeError struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeResponse struct {
|
|
||||||
Completion string `json:"completion"`
|
|
||||||
StopReason string `json:"stop_reason"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Error ClaudeError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func stopReasonClaude2OpenAI(reason string) string {
|
|
||||||
switch reason {
|
|
||||||
case "stop_sequence":
|
|
||||||
return "stop"
|
|
||||||
case "max_tokens":
|
|
||||||
return "length"
|
|
||||||
default:
|
|
||||||
return reason
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
|
||||||
claudeRequest := ClaudeRequest{
|
|
||||||
Model: textRequest.Model,
|
|
||||||
Prompt: "",
|
|
||||||
MaxTokensToSample: textRequest.MaxTokens,
|
|
||||||
StopSequences: nil,
|
|
||||||
Temperature: textRequest.Temperature,
|
|
||||||
TopP: textRequest.TopP,
|
|
||||||
Stream: textRequest.Stream,
|
|
||||||
}
|
|
||||||
if claudeRequest.MaxTokensToSample == 0 {
|
|
||||||
claudeRequest.MaxTokensToSample = 1000000
|
|
||||||
}
|
|
||||||
prompt := ""
|
|
||||||
for _, message := range textRequest.Messages {
|
|
||||||
if message.Role == "user" {
|
|
||||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
|
||||||
} else if message.Role == "assistant" {
|
|
||||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
|
||||||
} else if message.Role == "system" {
|
|
||||||
if prompt == "" {
|
|
||||||
prompt = message.StringContent()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prompt += "\n\nAssistant:"
|
|
||||||
claudeRequest.Prompt = prompt
|
|
||||||
return &claudeRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = claudeResponse.Completion
|
|
||||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
|
||||||
if finishReason != "null" {
|
|
||||||
choice.FinishReason = &finishReason
|
|
||||||
}
|
|
||||||
var response ChatCompletionsStreamResponse
|
|
||||||
response.Object = "chat.completion.chunk"
|
|
||||||
response.Model = claudeResponse.Model
|
|
||||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
|
||||||
Name: nil,
|
|
||||||
},
|
|
||||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
||||||
createdTime := common.GetTimestamp()
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
|
||||||
return i + 4, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if !strings.HasPrefix(data, "event: completion") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
// some implementations may add \r at the end of data
|
|
||||||
data = strings.TrimSuffix(data, "\r")
|
|
||||||
var claudeResponse ClaudeResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
responseText += claudeResponse.Completion
|
|
||||||
response := streamResponseClaude2OpenAI(&claudeResponse)
|
|
||||||
response.Id = responseId
|
|
||||||
response.Created = createdTime
|
|
||||||
jsonStr, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var claudeResponse ClaudeResponse
|
|
||||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if claudeResponse.Error.Type != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: claudeResponse.Error.Message,
|
|
||||||
Type: claudeResponse.Error.Type,
|
|
||||||
Param: "",
|
|
||||||
Code: claudeResponse.Error.Type,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
|
||||||
fullTextResponse.Model = model
|
|
||||||
completionTokens := countTokenText(claudeResponse.Completion, model)
|
|
||||||
usage := Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
fullTextResponse.Usage = usage
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
@ -1,337 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/common/image"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
|
|
||||||
|
|
||||||
const (
|
|
||||||
GeminiVisionMaxImageNum = 16
|
|
||||||
)
|
|
||||||
|
|
||||||
type GeminiChatRequest struct {
|
|
||||||
Contents []GeminiChatContent `json:"contents"`
|
|
||||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
|
||||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
|
||||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiInlineData struct {
|
|
||||||
MimeType string `json:"mimeType"`
|
|
||||||
Data string `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiPart struct {
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatContent struct {
|
|
||||||
Role string `json:"role,omitempty"`
|
|
||||||
Parts []GeminiPart `json:"parts"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatSafetySettings struct {
|
|
||||||
Category string `json:"category"`
|
|
||||||
Threshold string `json:"threshold"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatTools struct {
|
|
||||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatGenerationConfig struct {
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"topP,omitempty"`
|
|
||||||
TopK float64 `json:"topK,omitempty"`
|
|
||||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
|
||||||
CandidateCount int `json:"candidateCount,omitempty"`
|
|
||||||
StopSequences []string `json:"stopSequences,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
|
||||||
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
|
||||||
geminiRequest := GeminiChatRequest{
|
|
||||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
|
||||||
SafetySettings: []GeminiChatSafetySettings{
|
|
||||||
{
|
|
||||||
Category: "HARM_CATEGORY_HARASSMENT",
|
|
||||||
Threshold: common.GeminiSafetySetting,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Category: "HARM_CATEGORY_HATE_SPEECH",
|
|
||||||
Threshold: common.GeminiSafetySetting,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
|
||||||
Threshold: common.GeminiSafetySetting,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
|
||||||
Threshold: common.GeminiSafetySetting,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
GenerationConfig: GeminiChatGenerationConfig{
|
|
||||||
Temperature: textRequest.Temperature,
|
|
||||||
TopP: textRequest.TopP,
|
|
||||||
MaxOutputTokens: textRequest.MaxTokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if textRequest.Functions != nil {
|
|
||||||
geminiRequest.Tools = []GeminiChatTools{
|
|
||||||
{
|
|
||||||
FunctionDeclarations: textRequest.Functions,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
shouldAddDummyModelMessage := false
|
|
||||||
for _, message := range textRequest.Messages {
|
|
||||||
content := GeminiChatContent{
|
|
||||||
Role: message.Role,
|
|
||||||
Parts: []GeminiPart{
|
|
||||||
{
|
|
||||||
Text: message.StringContent(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
openaiContent := message.ParseContent()
|
|
||||||
var parts []GeminiPart
|
|
||||||
imageNum := 0
|
|
||||||
for _, part := range openaiContent {
|
|
||||||
if part.Type == ContentTypeText {
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
Text: part.Text,
|
|
||||||
})
|
|
||||||
} else if part.Type == ContentTypeImageURL {
|
|
||||||
imageNum += 1
|
|
||||||
if imageNum > GeminiVisionMaxImageNum {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
InlineData: &GeminiInlineData{
|
|
||||||
MimeType: mimeType,
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content.Parts = parts
|
|
||||||
|
|
||||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
|
||||||
if content.Role == "assistant" {
|
|
||||||
content.Role = "model"
|
|
||||||
}
|
|
||||||
// Converting system prompt to prompt from user for the same reason
|
|
||||||
if content.Role == "system" {
|
|
||||||
content.Role = "user"
|
|
||||||
shouldAddDummyModelMessage = true
|
|
||||||
}
|
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
|
||||||
|
|
||||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
|
||||||
if shouldAddDummyModelMessage {
|
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
|
||||||
Role: "model",
|
|
||||||
Parts: []GeminiPart{
|
|
||||||
{
|
|
||||||
Text: "Okay",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
shouldAddDummyModelMessage = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &geminiRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatResponse struct {
|
|
||||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
|
||||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *GeminiChatResponse) GetResponseText() string {
|
|
||||||
if g == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
|
||||||
return g.Candidates[0].Content.Parts[0].Text
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatCandidate struct {
|
|
||||||
Content GeminiChatContent `json:"content"`
|
|
||||||
FinishReason string `json:"finishReason"`
|
|
||||||
Index int64 `json:"index"`
|
|
||||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatSafetyRating struct {
|
|
||||||
Category string `json:"category"`
|
|
||||||
Probability string `json:"probability"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeminiChatPromptFeedback struct {
|
|
||||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
|
||||||
}
|
|
||||||
for i, candidate := range response.Candidates {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: i,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
FinishReason: stopFinishReason,
|
|
||||||
}
|
|
||||||
if len(candidate.Content.Parts) > 0 {
|
|
||||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
var response ChatCompletionsStreamResponse
|
|
||||||
response.Object = "chat.completion.chunk"
|
|
||||||
response.Model = "gemini"
|
|
||||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
data = strings.TrimSpace(data)
|
|
||||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
|
||||||
data = strings.TrimSuffix(data, "\"")
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
// this is used to prevent annoying \ related format bug
|
|
||||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
|
||||||
type dummyStruct struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
var dummy dummyStruct
|
|
||||||
err := json.Unmarshal([]byte(data), &dummy)
|
|
||||||
responseText += dummy.Content
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = dummy.Content
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "gemini-pro",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var geminiResponse GeminiChatResponse
|
|
||||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if len(geminiResponse.Candidates) == 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: "No candidates returned",
|
|
||||||
Type: "server_error",
|
|
||||||
Param: "",
|
|
||||||
Code: 500,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
|
||||||
fullTextResponse.Model = model
|
|
||||||
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
|
|
||||||
usage := Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
fullTextResponse.Usage = usage
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
@ -1,222 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isWithinRange(element string, value int) bool {
|
|
||||||
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
min := common.DalleGenerationImageAmounts[element][0]
|
|
||||||
max := common.DalleGenerationImageAmounts[element][1]
|
|
||||||
|
|
||||||
return value >= min && value <= max
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
||||||
imageModel := "dall-e-2"
|
|
||||||
imageSize := "1024x1024"
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
channelType := c.GetInt("channel")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
group := c.GetString("group")
|
|
||||||
|
|
||||||
var imageRequest ImageRequest
|
|
||||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageRequest.N == 0 {
|
|
||||||
imageRequest.N = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size validation
|
|
||||||
if imageRequest.Size != "" {
|
|
||||||
imageSize = imageRequest.Size
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model validation
|
|
||||||
if imageRequest.Model != "" {
|
|
||||||
imageModel = imageRequest.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
|
|
||||||
|
|
||||||
// Check if model is supported
|
|
||||||
if hasValidSize {
|
|
||||||
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
|
|
||||||
if imageSize == "1024x1024" {
|
|
||||||
imageCostRatio *= 2
|
|
||||||
} else {
|
|
||||||
imageCostRatio *= 1.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prompt validation
|
|
||||||
if imageRequest.Prompt == "" {
|
|
||||||
return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check prompt length
|
|
||||||
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
|
|
||||||
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Number of generated images validation
|
|
||||||
if isWithinRange(imageModel, imageRequest.N) == false {
|
|
||||||
// channel not azure
|
|
||||||
if channelType != common.ChannelTypeAzure {
|
|
||||||
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// map model name
|
|
||||||
modelMapping := c.GetString("model_mapping")
|
|
||||||
isModelMapped := false
|
|
||||||
if modelMapping != "" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[imageModel] != "" {
|
|
||||||
imageModel = modelMap[imageModel]
|
|
||||||
isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
if c.GetString("base_url") != "" {
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
if channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
|
||||||
apiVersion := GetAPIVersion(c)
|
|
||||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
var requestBody io.Reader
|
|
||||||
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
|
|
||||||
jsonStr, err := json.Marshal(imageRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
} else {
|
|
||||||
requestBody = c.Request.Body
|
|
||||||
}
|
|
||||||
|
|
||||||
modelRatio := common.GetModelRatio(imageModel)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
|
||||||
|
|
||||||
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
|
||||||
|
|
||||||
if userQuota-quota < 0 {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
token := c.Request.Header.Get("Authorization")
|
|
||||||
if channelType == common.ChannelTypeAzure { // Azure authentication
|
|
||||||
token = strings.TrimPrefix(token, "Bearer ")
|
|
||||||
req.Header.Set("api-key", token)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", token)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
var textResponse ImageResponse
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
|
||||||
}
|
|
||||||
if quota != 0 {
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}(c.Request.Context())
|
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,143 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dataChan <- data
|
|
||||||
data = data[6:]
|
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
var streamResponse ChatCompletionsStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
continue // just ignore the error
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseText += choice.Delta.Content
|
|
||||||
}
|
|
||||||
case RelayModeCompletions:
|
|
||||||
var streamResponse CompletionsStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseText += choice.Text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
if strings.HasPrefix(data, "data: [DONE]") {
|
|
||||||
data = data[:12]
|
|
||||||
}
|
|
||||||
// some implementations may add \r at the end of data
|
|
||||||
data = strings.TrimSuffix(data, "\r")
|
|
||||||
c.Render(-1, common.CustomEvent{Data: data})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var textResponse TextResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if textResponse.Error.Type != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: textResponse.Error,
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if textResponse.Usage.TotalTokens == 0 {
|
|
||||||
completionTokens := 0
|
|
||||||
for _, choice := range textResponse.Choices {
|
|
||||||
completionTokens += countTokenText(choice.Message.StringContent(), model)
|
|
||||||
}
|
|
||||||
textResponse.Usage = Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, &textResponse.Usage
|
|
||||||
}
|
|
@ -1,206 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
|
||||||
|
|
||||||
type PaLMChatMessage struct {
|
|
||||||
Author string `json:"author"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMFilter struct {
|
|
||||||
Reason string `json:"reason"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMPrompt struct {
|
|
||||||
Messages []PaLMChatMessage `json:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMChatRequest struct {
|
|
||||||
Prompt PaLMPrompt `json:"prompt"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
CandidateCount int `json:"candidateCount,omitempty"`
|
|
||||||
TopP float64 `json:"topP,omitempty"`
|
|
||||||
TopK int `json:"topK,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMError struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMChatResponse struct {
|
|
||||||
Candidates []PaLMChatMessage `json:"candidates"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
Filters []PaLMFilter `json:"filters"`
|
|
||||||
Error PaLMError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
|
||||||
palmRequest := PaLMChatRequest{
|
|
||||||
Prompt: PaLMPrompt{
|
|
||||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
|
||||||
},
|
|
||||||
Temperature: textRequest.Temperature,
|
|
||||||
CandidateCount: textRequest.N,
|
|
||||||
TopP: textRequest.TopP,
|
|
||||||
TopK: textRequest.MaxTokens,
|
|
||||||
}
|
|
||||||
for _, message := range textRequest.Messages {
|
|
||||||
palmMessage := PaLMChatMessage{
|
|
||||||
Content: message.StringContent(),
|
|
||||||
}
|
|
||||||
if message.Role == "user" {
|
|
||||||
palmMessage.Author = "0"
|
|
||||||
} else {
|
|
||||||
palmMessage.Author = "1"
|
|
||||||
}
|
|
||||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
|
||||||
}
|
|
||||||
return &palmRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
|
||||||
}
|
|
||||||
for i, candidate := range response.Candidates {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: i,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: candidate.Content,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
if len(palmResponse.Candidates) > 0 {
|
|
||||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
|
||||||
}
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
var response ChatCompletionsStreamResponse
|
|
||||||
response.Object = "chat.completion.chunk"
|
|
||||||
response.Model = "palm2"
|
|
||||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
||||||
createdTime := common.GetTimestamp()
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error reading stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error closing stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var palmResponse PaLMChatResponse
|
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
|
|
||||||
fullTextResponse.Id = responseId
|
|
||||||
fullTextResponse.Created = createdTime
|
|
||||||
if len(palmResponse.Candidates) > 0 {
|
|
||||||
responseText = palmResponse.Candidates[0].Content
|
|
||||||
}
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dataChan <- string(jsonResponse)
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + data})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var palmResponse PaLMChatResponse
|
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: palmResponse.Error.Message,
|
|
||||||
Type: palmResponse.Error.Status,
|
|
||||||
Param: "",
|
|
||||||
Code: palmResponse.Error.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
|
||||||
fullTextResponse.Model = model
|
|
||||||
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
|
|
||||||
usage := Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
fullTextResponse.Usage = usage
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
@ -1,288 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://cloud.tencent.com/document/product/1729/97732
|
|
||||||
|
|
||||||
type TencentMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TencentChatRequest struct {
|
|
||||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
|
||||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
|
||||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
|
||||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
|
||||||
Timestamp int64 `json:"timestamp"`
|
|
||||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
|
||||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
|
||||||
Expired int64 `json:"expired"`
|
|
||||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
|
||||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
|
||||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
|
||||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
|
||||||
Temperature float64 `json:"temperature"`
|
|
||||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
|
||||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
|
||||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
|
||||||
TopP float64 `json:"top_p"`
|
|
||||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
|
||||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
|
||||||
Stream int `json:"stream"`
|
|
||||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
|
||||||
// 输入 content 总数最大支持 3000 token。
|
|
||||||
Messages []TencentMessage `json:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TencentError struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TencentUsage struct {
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TencentResponseChoices struct {
|
|
||||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
|
||||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
|
||||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
|
||||||
}
|
|
||||||
|
|
||||||
type TencentChatResponse struct {
|
|
||||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
|
||||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
|
||||||
Id string `json:"id,omitempty"` // 会话 id
|
|
||||||
Usage Usage `json:"usage,omitempty"` // token 数量
|
|
||||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
|
||||||
Note string `json:"note,omitempty"` // 注释
|
|
||||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
|
||||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
|
||||||
for i := 0; i < len(request.Messages); i++ {
|
|
||||||
message := request.Messages[i]
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, TencentMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, TencentMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
messages = append(messages, TencentMessage{
|
|
||||||
Content: message.StringContent(),
|
|
||||||
Role: message.Role,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
stream := 0
|
|
||||||
if request.Stream {
|
|
||||||
stream = 1
|
|
||||||
}
|
|
||||||
return &TencentChatRequest{
|
|
||||||
Timestamp: common.GetTimestamp(),
|
|
||||||
Expired: common.GetTimestamp() + 24*60*60,
|
|
||||||
QueryID: common.GetUUID(),
|
|
||||||
Temperature: request.Temperature,
|
|
||||||
TopP: request.TopP,
|
|
||||||
Stream: stream,
|
|
||||||
Messages: messages,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Usage: response.Usage,
|
|
||||||
}
|
|
||||||
if len(response.Choices) > 0 {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Choices[0].Messages.Content,
|
|
||||||
},
|
|
||||||
FinishReason: response.Choices[0].FinishReason,
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "tencent-hunyuan",
|
|
||||||
}
|
|
||||||
if len(TencentResponse.Choices) > 0 {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
|
||||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
}
|
|
||||||
response.Choices = append(response.Choices, choice)
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
var responseText string
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 5 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:5] != "data:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[5:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var TencentResponse TencentChatResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
|
||||||
if len(response.Choices) != 0 {
|
|
||||||
responseText += response.Choices[0].Delta.Content
|
|
||||||
}
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var TencentResponse TencentChatResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if TencentResponse.Error.Code != 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: TencentResponse.Error.Message,
|
|
||||||
Code: TencentResponse.Error.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
|
||||||
fullTextResponse.Model = "hunyuan"
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
|
||||||
parts := strings.Split(config, "|")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
err = errors.New("invalid tencent config")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
|
||||||
secretId = parts[1]
|
|
||||||
secretKey = parts[2]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTencentSign(req TencentChatRequest, secretKey string) string {
|
|
||||||
params := make([]string, 0)
|
|
||||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
|
||||||
params = append(params, "secret_id="+req.SecretId)
|
|
||||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
|
||||||
params = append(params, "query_id="+req.QueryID)
|
|
||||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
|
||||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
|
||||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
|
||||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
|
||||||
|
|
||||||
var messageStr string
|
|
||||||
for _, msg := range req.Messages {
|
|
||||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
|
||||||
}
|
|
||||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
|
||||||
params = append(params, "messages=["+messageStr+"]")
|
|
||||||
|
|
||||||
sort.Sort(sort.StringSlice(params))
|
|
||||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
|
||||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
|
||||||
signURL := url
|
|
||||||
mac.Write([]byte(signURL))
|
|
||||||
sign := mac.Sum([]byte(nil))
|
|
||||||
return base64.StdEncoding.EncodeToString(sign)
|
|
||||||
}
|
|
@ -1,689 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
APITypeOpenAI = iota
|
|
||||||
APITypeClaude
|
|
||||||
APITypePaLM
|
|
||||||
APITypeBaidu
|
|
||||||
APITypeZhipu
|
|
||||||
APITypeAli
|
|
||||||
APITypeXunfei
|
|
||||||
APITypeAIProxyLibrary
|
|
||||||
APITypeTencent
|
|
||||||
APITypeGemini
|
|
||||||
)
|
|
||||||
|
|
||||||
var httpClient *http.Client
|
|
||||||
var impatientHTTPClient *http.Client
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if common.RelayTimeout == 0 {
|
|
||||||
httpClient = &http.Client{}
|
|
||||||
} else {
|
|
||||||
httpClient = &http.Client{
|
|
||||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impatientHTTPClient = &http.Client{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
||||||
channelType := c.GetInt("channel")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
group := c.GetString("group")
|
|
||||||
var textRequest GeneralOpenAIRequest
|
|
||||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
|
|
||||||
return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
|
||||||
textRequest.Model = "text-moderation-latest"
|
|
||||||
}
|
|
||||||
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
|
|
||||||
textRequest.Model = c.Param("model")
|
|
||||||
}
|
|
||||||
// request validation
|
|
||||||
if textRequest.Model == "" {
|
|
||||||
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeCompletions:
|
|
||||||
if textRequest.Prompt == "" {
|
|
||||||
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
|
||||||
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
case RelayModeModerations:
|
|
||||||
if textRequest.Input == "" {
|
|
||||||
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
case RelayModeEdits:
|
|
||||||
if textRequest.Instruction == "" {
|
|
||||||
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// map model name
|
|
||||||
modelMapping := c.GetString("model_mapping")
|
|
||||||
isModelMapped := false
|
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[textRequest.Model] != "" {
|
|
||||||
textRequest.Model = modelMap[textRequest.Model]
|
|
||||||
isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
apiType := APITypeOpenAI
|
|
||||||
switch channelType {
|
|
||||||
case common.ChannelTypeAnthropic:
|
|
||||||
apiType = APITypeClaude
|
|
||||||
case common.ChannelTypeBaidu:
|
|
||||||
apiType = APITypeBaidu
|
|
||||||
case common.ChannelTypePaLM:
|
|
||||||
apiType = APITypePaLM
|
|
||||||
case common.ChannelTypeZhipu:
|
|
||||||
apiType = APITypeZhipu
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
apiType = APITypeAli
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
apiType = APITypeXunfei
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
apiType = APITypeAIProxyLibrary
|
|
||||||
case common.ChannelTypeTencent:
|
|
||||||
apiType = APITypeTencent
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
apiType = APITypeGemini
|
|
||||||
}
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
if c.GetString("base_url") != "" {
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
switch apiType {
|
|
||||||
case APITypeOpenAI:
|
|
||||||
if channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
|
||||||
apiVersion := GetAPIVersion(c)
|
|
||||||
requestURL := strings.Split(requestURL, "?")[0]
|
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
|
||||||
model_ := textRequest.Model
|
|
||||||
model_ = strings.Replace(model_, ".", "", -1)
|
|
||||||
// https://github.com/songquanpeng/one-api/issues/67
|
|
||||||
model_ = strings.TrimSuffix(model_, "-0301")
|
|
||||||
model_ = strings.TrimSuffix(model_, "-0314")
|
|
||||||
model_ = strings.TrimSuffix(model_, "-0613")
|
|
||||||
|
|
||||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
|
||||||
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
}
|
|
||||||
case APITypeClaude:
|
|
||||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
|
||||||
if baseURL != "" {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
|
||||||
}
|
|
||||||
case APITypeBaidu:
|
|
||||||
switch textRequest.Model {
|
|
||||||
case "ERNIE-Bot":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
|
||||||
case "ERNIE-Bot-turbo":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
|
||||||
case "ERNIE-Bot-4":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
|
||||||
case "BLOOMZ-7B":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
|
||||||
case "Embedding-V1":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
|
||||||
}
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
var err error
|
|
||||||
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
|
||||||
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
fullRequestURL += "?access_token=" + apiKey
|
|
||||||
case APITypePaLM:
|
|
||||||
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
|
||||||
if baseURL != "" {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
|
|
||||||
}
|
|
||||||
case APITypeGemini:
|
|
||||||
requestBaseURL := "https://generativelanguage.googleapis.com"
|
|
||||||
if baseURL != "" {
|
|
||||||
requestBaseURL = baseURL
|
|
||||||
}
|
|
||||||
version := "v1"
|
|
||||||
if c.GetString("api_version") != "" {
|
|
||||||
version = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
action := "generateContent"
|
|
||||||
if textRequest.Stream {
|
|
||||||
action = "streamGenerateContent"
|
|
||||||
}
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
|
|
||||||
case APITypeZhipu:
|
|
||||||
method := "invoke"
|
|
||||||
if textRequest.Stream {
|
|
||||||
method = "sse-invoke"
|
|
||||||
}
|
|
||||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
|
||||||
case APITypeAli:
|
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
|
||||||
if relayMode == RelayModeEmbeddings {
|
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
|
||||||
}
|
|
||||||
case APITypeTencent:
|
|
||||||
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
|
||||||
case APITypeAIProxyLibrary:
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
|
||||||
}
|
|
||||||
var promptTokens int
|
|
||||||
var completionTokens int
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
|
|
||||||
case RelayModeCompletions:
|
|
||||||
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
|
||||||
case RelayModeModerations:
|
|
||||||
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
|
||||||
}
|
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
|
||||||
if textRequest.MaxTokens != 0 {
|
|
||||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
|
||||||
}
|
|
||||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if userQuota-preConsumedQuota < 0 {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if userQuota > 100*preConsumedQuota {
|
|
||||||
// in this case, we do not pre-consume quota
|
|
||||||
// because the user has enough quota
|
|
||||||
preConsumedQuota = 0
|
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
|
||||||
}
|
|
||||||
if preConsumedQuota > 0 {
|
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var requestBody io.Reader
|
|
||||||
if isModelMapped {
|
|
||||||
jsonStr, err := json.Marshal(textRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
} else {
|
|
||||||
requestBody = c.Request.Body
|
|
||||||
}
|
|
||||||
switch apiType {
|
|
||||||
case APITypeClaude:
|
|
||||||
claudeRequest := requestOpenAI2Claude(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(claudeRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeBaidu:
|
|
||||||
var jsonData []byte
|
|
||||||
var err error
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
|
|
||||||
jsonData, err = json.Marshal(baiduEmbeddingRequest)
|
|
||||||
default:
|
|
||||||
baiduRequest := requestOpenAI2Baidu(textRequest)
|
|
||||||
jsonData, err = json.Marshal(baiduRequest)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
|
||||||
case APITypePaLM:
|
|
||||||
palmRequest := requestOpenAI2PaLM(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(palmRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeGemini:
|
|
||||||
geminiChatRequest := requestOpenAI2Gemini(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeZhipu:
|
|
||||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(zhipuRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeAli:
|
|
||||||
var jsonStr []byte
|
|
||||||
var err error
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
|
||||||
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
|
||||||
default:
|
|
||||||
aliRequest := requestOpenAI2Ali(textRequest)
|
|
||||||
jsonStr, err = json.Marshal(aliRequest)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeTencent:
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
tencentRequest := requestOpenAI2Tencent(textRequest)
|
|
||||||
tencentRequest.AppId = appId
|
|
||||||
tencentRequest.SecretId = secretId
|
|
||||||
jsonStr, err := json.Marshal(tencentRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
sign := getTencentSign(*tencentRequest, secretKey)
|
|
||||||
c.Request.Header.Set("Authorization", sign)
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeAIProxyLibrary:
|
|
||||||
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
|
||||||
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
|
||||||
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
var req *http.Request
|
|
||||||
var resp *http.Response
|
|
||||||
isStream := textRequest.Stream
|
|
||||||
|
|
||||||
if apiType != APITypeXunfei { // cause xunfei use websocket
|
|
||||||
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
switch apiType {
|
|
||||||
case APITypeOpenAI:
|
|
||||||
if channelType == common.ChannelTypeAzure {
|
|
||||||
req.Header.Set("api-key", apiKey)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
if channelType == common.ChannelTypeOpenRouter {
|
|
||||||
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
|
||||||
req.Header.Set("X-Title", "One API")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case APITypeClaude:
|
|
||||||
req.Header.Set("x-api-key", apiKey)
|
|
||||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
|
||||||
if anthropicVersion == "" {
|
|
||||||
anthropicVersion = "2023-06-01"
|
|
||||||
}
|
|
||||||
req.Header.Set("anthropic-version", anthropicVersion)
|
|
||||||
case APITypeZhipu:
|
|
||||||
token := getZhipuToken(apiKey)
|
|
||||||
req.Header.Set("Authorization", token)
|
|
||||||
case APITypeAli:
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
if textRequest.Stream {
|
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
|
||||||
}
|
|
||||||
if c.GetString("plugin") != "" {
|
|
||||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
|
||||||
}
|
|
||||||
case APITypeTencent:
|
|
||||||
req.Header.Set("Authorization", apiKey)
|
|
||||||
case APITypePaLM:
|
|
||||||
req.Header.Set("x-goog-api-key", apiKey)
|
|
||||||
case APITypeGemini:
|
|
||||||
req.Header.Set("x-goog-api-key", apiKey)
|
|
||||||
default:
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
|
||||||
req.Header.Set("Accept", "text/event-stream")
|
|
||||||
}
|
|
||||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
|
||||||
resp, err = httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if preConsumedQuota != 0 {
|
|
||||||
go func(ctx context.Context) {
|
|
||||||
// return pre-consumed quota
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
|
||||||
}
|
|
||||||
}(c.Request.Context())
|
|
||||||
}
|
|
||||||
return relayErrorHandler(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var textResponse TextResponse
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
// c.Writer.Flush()
|
|
||||||
go func() {
|
|
||||||
quota := 0
|
|
||||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
|
||||||
promptTokens = textResponse.Usage.PromptTokens
|
|
||||||
completionTokens = textResponse.Usage.CompletionTokens
|
|
||||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
|
||||||
if ratio != 0 && quota <= 0 {
|
|
||||||
quota = 1
|
|
||||||
}
|
|
||||||
totalTokens := promptTokens + completionTokens
|
|
||||||
if totalTokens == 0 {
|
|
||||||
// in this case, must be some error happened
|
|
||||||
// we cannot just return, because we may have to return the pre-consumed quota
|
|
||||||
quota = 0
|
|
||||||
}
|
|
||||||
quotaDelta := quota - preConsumedQuota
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
|
||||||
}
|
|
||||||
if quota != 0 {
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
|
|
||||||
}()
|
|
||||||
}(c.Request.Context())
|
|
||||||
switch apiType {
|
|
||||||
case APITypeOpenAI:
|
|
||||||
if isStream {
|
|
||||||
err, responseText := openaiStreamHandler(c, resp, relayMode)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeClaude:
|
|
||||||
if isStream {
|
|
||||||
err, responseText := claudeStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeBaidu:
|
|
||||||
if isStream {
|
|
||||||
err, usage := baiduStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
var usage *Usage
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
err, usage = baiduEmbeddingHandler(c, resp)
|
|
||||||
default:
|
|
||||||
err, usage = baiduHandler(c, resp)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypePaLM:
|
|
||||||
if textRequest.Stream { // PaLM2 API does not support stream
|
|
||||||
err, responseText := palmStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeGemini:
|
|
||||||
if textRequest.Stream {
|
|
||||||
err, responseText := geminiChatStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeZhipu:
|
|
||||||
if isStream {
|
|
||||||
err, usage := zhipuStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
// zhipu's API does not return prompt tokens & completion tokens
|
|
||||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := zhipuHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
// zhipu's API does not return prompt tokens & completion tokens
|
|
||||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeAli:
|
|
||||||
if isStream {
|
|
||||||
err, usage := aliStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
var usage *Usage
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
err, usage = aliEmbeddingHandler(c, resp)
|
|
||||||
default:
|
|
||||||
err, usage = aliHandler(c, resp)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeXunfei:
|
|
||||||
auth := c.Request.Header.Get("Authorization")
|
|
||||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
|
||||||
splits := strings.Split(auth, "|")
|
|
||||||
if len(splits) != 3 {
|
|
||||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
var usage *Usage
|
|
||||||
if isStream {
|
|
||||||
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
||||||
} else {
|
|
||||||
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case APITypeAIProxyLibrary:
|
|
||||||
if isStream {
|
|
||||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := aiProxyLibraryHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeTencent:
|
|
||||||
if isStream {
|
|
||||||
err, responseText := tencentStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := tencentHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,385 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/common/image"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
var stopFinishReason = "stop"
|
|
||||||
|
|
||||||
// tokenEncoderMap won't grow after initialization
|
|
||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
|
||||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
|
||||||
|
|
||||||
func InitTokenEncoders() {
|
|
||||||
common.SysLog("initializing token encoders")
|
|
||||||
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
|
||||||
}
|
|
||||||
defaultTokenEncoder = gpt35TokenEncoder
|
|
||||||
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
|
||||||
}
|
|
||||||
for model, _ := range common.ModelRatio {
|
|
||||||
if strings.HasPrefix(model, "gpt-3.5") {
|
|
||||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
|
||||||
} else if strings.HasPrefix(model, "gpt-4") {
|
|
||||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
|
||||||
} else {
|
|
||||||
tokenEncoderMap[model] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
common.SysLog("token encoders initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
||||||
tokenEncoder, ok := tokenEncoderMap[model]
|
|
||||||
if ok && tokenEncoder != nil {
|
|
||||||
return tokenEncoder
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
||||||
tokenEncoder = defaultTokenEncoder
|
|
||||||
}
|
|
||||||
tokenEncoderMap[model] = tokenEncoder
|
|
||||||
return tokenEncoder
|
|
||||||
}
|
|
||||||
return defaultTokenEncoder
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
||||||
if common.ApproximateTokenEnabled {
|
|
||||||
return int(float64(len(text)) * 0.38)
|
|
||||||
}
|
|
||||||
return len(tokenEncoder.Encode(text, nil, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenMessages(messages []Message, model string) int {
|
|
||||||
tokenEncoder := getTokenEncoder(model)
|
|
||||||
// Reference:
|
|
||||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
|
||||||
//
|
|
||||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
|
||||||
var tokensPerMessage int
|
|
||||||
var tokensPerName int
|
|
||||||
if model == "gpt-3.5-turbo-0301" {
|
|
||||||
tokensPerMessage = 4
|
|
||||||
tokensPerName = -1 // If there's a name, the role is omitted
|
|
||||||
} else {
|
|
||||||
tokensPerMessage = 3
|
|
||||||
tokensPerName = 1
|
|
||||||
}
|
|
||||||
tokenNum := 0
|
|
||||||
for _, message := range messages {
|
|
||||||
tokenNum += tokensPerMessage
|
|
||||||
switch v := message.Content.(type) {
|
|
||||||
case string:
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, v)
|
|
||||||
case []any:
|
|
||||||
for _, it := range v {
|
|
||||||
m := it.(map[string]any)
|
|
||||||
switch m["type"] {
|
|
||||||
case "text":
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
|
|
||||||
case "image_url":
|
|
||||||
imageUrl, ok := m["image_url"].(map[string]any)
|
|
||||||
if ok {
|
|
||||||
url := imageUrl["url"].(string)
|
|
||||||
detail := ""
|
|
||||||
if imageUrl["detail"] != nil {
|
|
||||||
detail = imageUrl["detail"].(string)
|
|
||||||
}
|
|
||||||
imageTokens, err := countImageTokens(url, detail)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error counting image tokens: " + err.Error())
|
|
||||||
} else {
|
|
||||||
tokenNum += imageTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
|
||||||
if message.Name != nil {
|
|
||||||
tokenNum += tokensPerName
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
|
||||||
return tokenNum
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
lowDetailCost = 85
|
|
||||||
highDetailCostPerTile = 170
|
|
||||||
additionalCost = 85
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
|
|
||||||
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
func countImageTokens(url string, detail string) (_ int, err error) {
|
|
||||||
var fetchSize = true
|
|
||||||
var width, height int
|
|
||||||
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
|
|
||||||
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
|
|
||||||
// According to the official guide, "low" disable the high-res model,
|
|
||||||
// and only receive low-res 512px x 512px version of the image, indicating
|
|
||||||
// that image is treated as low-res when size is smaller than 512px x 512px,
|
|
||||||
// then we can assume that image size larger than 512px x 512px is treated
|
|
||||||
// as high-res. Then we have the following logic:
|
|
||||||
// if detail == "" || detail == "auto" {
|
|
||||||
// width, height, err = image.GetImageSize(url)
|
|
||||||
// if err != nil {
|
|
||||||
// return 0, err
|
|
||||||
// }
|
|
||||||
// fetchSize = false
|
|
||||||
// // not sure if this is correct
|
|
||||||
// if width > 512 || height > 512 {
|
|
||||||
// detail = "high"
|
|
||||||
// } else {
|
|
||||||
// detail = "low"
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// However, in my test, it seems to be always the same as "high".
|
|
||||||
// The following image, which is 125x50, is still treated as high-res, taken
|
|
||||||
// 255 tokens in the response of non-stream chat completion api.
|
|
||||||
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
|
|
||||||
if detail == "" || detail == "auto" {
|
|
||||||
// assume by test, not sure if this is correct
|
|
||||||
detail = "high"
|
|
||||||
}
|
|
||||||
switch detail {
|
|
||||||
case "low":
|
|
||||||
return lowDetailCost, nil
|
|
||||||
case "high":
|
|
||||||
if fetchSize {
|
|
||||||
width, height, err = image.GetImageSize(url)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if width > 2048 || height > 2048 { // max(width, height) > 2048
|
|
||||||
ratio := float64(2048) / math.Max(float64(width), float64(height))
|
|
||||||
width = int(float64(width) * ratio)
|
|
||||||
height = int(float64(height) * ratio)
|
|
||||||
}
|
|
||||||
if width > 768 && height > 768 { // min(width, height) > 768
|
|
||||||
ratio := float64(768) / math.Min(float64(width), float64(height))
|
|
||||||
width = int(float64(width) * ratio)
|
|
||||||
height = int(float64(height) * ratio)
|
|
||||||
}
|
|
||||||
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
|
|
||||||
result := numSquares*highDetailCostPerTile + additionalCost
|
|
||||||
return result, nil
|
|
||||||
default:
|
|
||||||
return 0, errors.New("invalid detail option")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenInput(input any, model string) int {
|
|
||||||
switch v := input.(type) {
|
|
||||||
case string:
|
|
||||||
return countTokenText(v, model)
|
|
||||||
case []string:
|
|
||||||
text := ""
|
|
||||||
for _, s := range v {
|
|
||||||
text += s
|
|
||||||
}
|
|
||||||
return countTokenText(text, model)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenText(text string, model string) int {
|
|
||||||
tokenEncoder := getTokenEncoder(model)
|
|
||||||
return getTokenNum(tokenEncoder, text)
|
|
||||||
}
|
|
||||||
|
|
||||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
|
|
||||||
openAIError := OpenAIError{
|
|
||||||
Message: err.Error(),
|
|
||||||
Type: "one_api_error",
|
|
||||||
Code: code,
|
|
||||||
}
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: openAIError,
|
|
||||||
StatusCode: statusCode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
|
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if statusCode == http.StatusUnauthorized {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
|
|
||||||
if !common.AutomaticEnableChannelEnabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if openAIErr != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func setEventStreamHeaders(c *gin.Context) {
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
||||||
c.Writer.Header().Set("Connection", "keep-alive")
|
|
||||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
}
|
|
||||||
|
|
||||||
type GeneralErrorResponse struct {
|
|
||||||
Error OpenAIError `json:"error"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Msg string `json:"msg"`
|
|
||||||
Err string `json:"err"`
|
|
||||||
ErrorMsg string `json:"error_msg"`
|
|
||||||
Header struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"header"`
|
|
||||||
Response struct {
|
|
||||||
Error struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
} `json:"response"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e GeneralErrorResponse) ToMessage() string {
|
|
||||||
if e.Error.Message != "" {
|
|
||||||
return e.Error.Message
|
|
||||||
}
|
|
||||||
if e.Message != "" {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
if e.Msg != "" {
|
|
||||||
return e.Msg
|
|
||||||
}
|
|
||||||
if e.Err != "" {
|
|
||||||
return e.Err
|
|
||||||
}
|
|
||||||
if e.ErrorMsg != "" {
|
|
||||||
return e.ErrorMsg
|
|
||||||
}
|
|
||||||
if e.Header.Message != "" {
|
|
||||||
return e.Header.Message
|
|
||||||
}
|
|
||||||
if e.Response.Error.Message != "" {
|
|
||||||
return e.Response.Error.Message
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
|
|
||||||
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: "",
|
|
||||||
Type: "upstream_error",
|
|
||||||
Code: "bad_response_status_code",
|
|
||||||
Param: strconv.Itoa(resp.StatusCode),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var errResponse GeneralErrorResponse
|
|
||||||
err = json.Unmarshal(responseBody, &errResponse)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if errResponse.Error.Message != "" {
|
|
||||||
// OpenAI format error, so we override the default one
|
|
||||||
openAIErrorWithStatusCode.OpenAIError = errResponse.Error
|
|
||||||
} else {
|
|
||||||
openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage()
|
|
||||||
}
|
|
||||||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
|
||||||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
||||||
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
||||||
switch channelType {
|
|
||||||
case common.ChannelTypeOpenAI:
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fullRequestURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
|
||||||
// quotaDelta is remaining quota to be consumed
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
|
||||||
}
|
|
||||||
// totalQuota is total quota consumed
|
|
||||||
if totalQuota != 0 {
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
|
||||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
|
||||||
}
|
|
||||||
if totalQuota <= 0 {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAPIVersion(c *gin.Context) string {
|
|
||||||
query := c.Request.URL.Query()
|
|
||||||
apiVersion := query.Get("api-version")
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
return apiVersion
|
|
||||||
}
|
|
@ -1,312 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://console.xfyun.cn/services/cbm
|
|
||||||
// https://www.xfyun.cn/doc/spark/Web.html
|
|
||||||
|
|
||||||
type XunfeiMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XunfeiChatRequest struct {
|
|
||||||
Header struct {
|
|
||||||
AppId string `json:"app_id"`
|
|
||||||
} `json:"header"`
|
|
||||||
Parameter struct {
|
|
||||||
Chat struct {
|
|
||||||
Domain string `json:"domain,omitempty"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
Auditing bool `json:"auditing,omitempty"`
|
|
||||||
} `json:"chat"`
|
|
||||||
} `json:"parameter"`
|
|
||||||
Payload struct {
|
|
||||||
Message struct {
|
|
||||||
Text []XunfeiMessage `json:"text"`
|
|
||||||
} `json:"message"`
|
|
||||||
} `json:"payload"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XunfeiChatResponseTextItem struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XunfeiChatResponse struct {
|
|
||||||
Header struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Sid string `json:"sid"`
|
|
||||||
Status int `json:"status"`
|
|
||||||
} `json:"header"`
|
|
||||||
Payload struct {
|
|
||||||
Choices struct {
|
|
||||||
Status int `json:"status"`
|
|
||||||
Seq int `json:"seq"`
|
|
||||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
|
||||||
} `json:"choices"`
|
|
||||||
Usage struct {
|
|
||||||
//Text struct {
|
|
||||||
// QuestionTokens string `json:"question_tokens"`
|
|
||||||
// PromptTokens string `json:"prompt_tokens"`
|
|
||||||
// CompletionTokens string `json:"completion_tokens"`
|
|
||||||
// TotalTokens string `json:"total_tokens"`
|
|
||||||
//} `json:"text"`
|
|
||||||
Text Usage `json:"text"`
|
|
||||||
} `json:"usage"`
|
|
||||||
} `json:"payload"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
|
||||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, XunfeiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, XunfeiMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages = append(messages, XunfeiMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
xunfeiRequest := XunfeiChatRequest{}
|
|
||||||
xunfeiRequest.Header.AppId = xunfeiAppId
|
|
||||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
|
||||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
|
||||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
|
||||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
|
||||||
xunfeiRequest.Payload.Message.Text = messages
|
|
||||||
return &xunfeiRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
|
||||||
if len(response.Payload.Choices.Text) == 0 {
|
|
||||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
||||||
{
|
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Payload.Choices.Text[0].Content,
|
|
||||||
},
|
|
||||||
FinishReason: stopFinishReason,
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
Usage: response.Payload.Usage.Text,
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
||||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
||||||
{
|
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
|
||||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
}
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "SparkDesk",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|
||||||
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
|
||||||
mac := hmac.New(sha256.New, []byte(key))
|
|
||||||
mac.Write([]byte(data))
|
|
||||||
encodeData := mac.Sum(nil)
|
|
||||||
return base64.StdEncoding.EncodeToString(encodeData)
|
|
||||||
}
|
|
||||||
ul, err := url.Parse(hostUrl)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
date := time.Now().UTC().Format(time.RFC1123)
|
|
||||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
|
||||||
sign := strings.Join(signString, "\n")
|
|
||||||
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
|
||||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
|
||||||
"hmac-sha256", "host date request-line", sha)
|
|
||||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
|
||||||
v := url.Values{}
|
|
||||||
v.Add("host", ul.Host)
|
|
||||||
v.Add("date", date)
|
|
||||||
v.Add("authorization", authorization)
|
|
||||||
callUrl := hostUrl + "?" + v.Encode()
|
|
||||||
return callUrl
|
|
||||||
}
|
|
||||||
|
|
||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
||||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
var usage Usage
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case xunfeiResponse := <-dataChan:
|
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
||||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
||||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
||||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var usage Usage
|
|
||||||
var content string
|
|
||||||
var xunfeiResponse XunfeiChatResponse
|
|
||||||
stop := false
|
|
||||||
for !stop {
|
|
||||||
select {
|
|
||||||
case xunfeiResponse = <-dataChan:
|
|
||||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
||||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
||||||
case stop = <-stopChan:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
||||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
||||||
{
|
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
|
||||||
|
|
||||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
|
||||||
d := websocket.Dialer{
|
|
||||||
HandshakeTimeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
conn, resp, err := d.Dial(authUrl, nil)
|
|
||||||
if err != nil || resp.StatusCode != 101 {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
|
||||||
err = conn.WriteJSON(data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dataChan := make(chan XunfeiChatResponse)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
_, msg, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error reading stream response: " + err.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
var response XunfeiChatResponse
|
|
||||||
err = json.Unmarshal(msg, &response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
dataChan <- response
|
|
||||||
if response.Payload.Choices.Status == 2 {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error closing websocket connection: " + err.Error())
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
return dataChan, stopChan, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
|
||||||
query := c.Request.URL.Query()
|
|
||||||
apiVersion := query.Get("api-version")
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = "v1.1"
|
|
||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
|
||||||
}
|
|
||||||
domain := "general"
|
|
||||||
if apiVersion != "v1.1" {
|
|
||||||
domain += strings.Split(apiVersion, ".")[0]
|
|
||||||
}
|
|
||||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
|
||||||
return domain, authUrl
|
|
||||||
}
|
|
@ -1,302 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
|
||||||
// chatglm_std, chatglm_lite
|
|
||||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
|
||||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
|
||||||
|
|
||||||
type ZhipuMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuRequest struct {
|
|
||||||
Prompt []ZhipuMessage `json:"prompt"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
RequestId string `json:"request_id,omitempty"`
|
|
||||||
Incremental bool `json:"incremental,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuResponseData struct {
|
|
||||||
TaskId string `json:"task_id"`
|
|
||||||
RequestId string `json:"request_id"`
|
|
||||||
TaskStatus string `json:"task_status"`
|
|
||||||
Choices []ZhipuMessage `json:"choices"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Msg string `json:"msg"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Data ZhipuResponseData `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuStreamMetaResponse struct {
|
|
||||||
RequestId string `json:"request_id"`
|
|
||||||
TaskId string `json:"task_id"`
|
|
||||||
TaskStatus string `json:"task_status"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type zhipuTokenData struct {
|
|
||||||
Token string
|
|
||||||
ExpiryTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
var zhipuTokens sync.Map
|
|
||||||
var expSeconds int64 = 24 * 3600
|
|
||||||
|
|
||||||
func getZhipuToken(apikey string) string {
|
|
||||||
data, ok := zhipuTokens.Load(apikey)
|
|
||||||
if ok {
|
|
||||||
tokenData := data.(zhipuTokenData)
|
|
||||||
if time.Now().Before(tokenData.ExpiryTime) {
|
|
||||||
return tokenData.Token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
split := strings.Split(apikey, ".")
|
|
||||||
if len(split) != 2 {
|
|
||||||
common.SysError("invalid zhipu key: " + apikey)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
id := split[0]
|
|
||||||
secret := split[1]
|
|
||||||
|
|
||||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
|
||||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
|
||||||
|
|
||||||
timestamp := time.Now().UnixNano() / 1e6
|
|
||||||
|
|
||||||
payload := jwt.MapClaims{
|
|
||||||
"api_key": id,
|
|
||||||
"exp": expMillis,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
|
||||||
|
|
||||||
token.Header["alg"] = "HS256"
|
|
||||||
token.Header["sign_type"] = "SIGN"
|
|
||||||
|
|
||||||
tokenString, err := token.SignedString([]byte(secret))
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
zhipuTokens.Store(apikey, zhipuTokenData{
|
|
||||||
Token: tokenString,
|
|
||||||
ExpiryTime: expiryTime,
|
|
||||||
})
|
|
||||||
|
|
||||||
return tokenString
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
|
||||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, ZhipuMessage{
|
|
||||||
Role: "system",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, ZhipuMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages = append(messages, ZhipuMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &ZhipuRequest{
|
|
||||||
Prompt: messages,
|
|
||||||
Temperature: request.Temperature,
|
|
||||||
TopP: request.TopP,
|
|
||||||
Incremental: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: response.Data.TaskId,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
|
|
||||||
Usage: response.Data.Usage,
|
|
||||||
}
|
|
||||||
for i, choice := range response.Data.Choices {
|
|
||||||
openaiChoice := OpenAITextResponseChoice{
|
|
||||||
Index: i,
|
|
||||||
Message: Message{
|
|
||||||
Role: choice.Role,
|
|
||||||
Content: strings.Trim(choice.Content, "\""),
|
|
||||||
},
|
|
||||||
FinishReason: "",
|
|
||||||
}
|
|
||||||
if i == len(response.Data.Choices)-1 {
|
|
||||||
openaiChoice.FinishReason = "stop"
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = zhipuResponse
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "chatglm",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = ""
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: zhipuResponse.RequestId,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "chatglm",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response, &zhipuResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage *Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
|
||||||
return i + 2, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
metaChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
lines := strings.Split(data, "\n")
|
|
||||||
for i, line := range lines {
|
|
||||||
if len(line) < 5 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if line[:5] == "data:" {
|
|
||||||
dataChan <- line[5:]
|
|
||||||
if i != len(lines)-1 {
|
|
||||||
dataChan <- "\n"
|
|
||||||
}
|
|
||||||
} else if line[:5] == "meta:" {
|
|
||||||
metaChan <- line[5:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
response := streamResponseZhipu2OpenAI(data)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case data := <-metaChan:
|
|
||||||
var zhipuResponse ZhipuStreamMetaResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
usage = zhipuUsage
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var zhipuResponse ZhipuResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if !zhipuResponse.Success {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: zhipuResponse.Msg,
|
|
||||||
Type: "zhipu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: zhipuResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
|
||||||
fullTextResponse.Model = "chatglm"
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
@ -1,384 +1,139 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
type Message struct {
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
Role string `json:"role"`
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
Content any `json:"content"`
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
Name *string `json:"name,omitempty"`
|
"github.com/songquanpeng/one-api/middleware"
|
||||||
}
|
dbmodel "github.com/songquanpeng/one-api/model"
|
||||||
|
"github.com/songquanpeng/one-api/monitor"
|
||||||
type ImageURL struct {
|
"github.com/songquanpeng/one-api/relay/controller"
|
||||||
Url string `json:"url,omitempty"`
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
Detail string `json:"detail,omitempty"`
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
}
|
|
||||||
|
|
||||||
type TextContent struct {
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageContent struct {
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ContentTypeText = "text"
|
|
||||||
ContentTypeImageURL = "image_url"
|
|
||||||
)
|
|
||||||
|
|
||||||
type OpenAIMessageContent struct {
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Message) IsStringContent() bool {
|
|
||||||
_, ok := m.Content.(string)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Message) StringContent() string {
|
|
||||||
content, ok := m.Content.(string)
|
|
||||||
if ok {
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
contentList, ok := m.Content.([]any)
|
|
||||||
if ok {
|
|
||||||
var contentStr string
|
|
||||||
for _, contentItem := range contentList {
|
|
||||||
contentMap, ok := contentItem.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if contentMap["type"] == ContentTypeText {
|
|
||||||
if subStr, ok := contentMap["text"].(string); ok {
|
|
||||||
contentStr += subStr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return contentStr
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Message) ParseContent() []OpenAIMessageContent {
|
|
||||||
var contentList []OpenAIMessageContent
|
|
||||||
content, ok := m.Content.(string)
|
|
||||||
if ok {
|
|
||||||
contentList = append(contentList, OpenAIMessageContent{
|
|
||||||
Type: ContentTypeText,
|
|
||||||
Text: content,
|
|
||||||
})
|
|
||||||
return contentList
|
|
||||||
}
|
|
||||||
anyList, ok := m.Content.([]any)
|
|
||||||
if ok {
|
|
||||||
for _, contentItem := range anyList {
|
|
||||||
contentMap, ok := contentItem.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch contentMap["type"] {
|
|
||||||
case ContentTypeText:
|
|
||||||
if subStr, ok := contentMap["text"].(string); ok {
|
|
||||||
contentList = append(contentList, OpenAIMessageContent{
|
|
||||||
Type: ContentTypeText,
|
|
||||||
Text: subStr,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case ContentTypeImageURL:
|
|
||||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
|
||||||
contentList = append(contentList, OpenAIMessageContent{
|
|
||||||
Type: ContentTypeImageURL,
|
|
||||||
ImageURL: &ImageURL{
|
|
||||||
Url: subObj["url"].(string),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return contentList
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
RelayModeUnknown = iota
|
|
||||||
RelayModeChatCompletions
|
|
||||||
RelayModeCompletions
|
|
||||||
RelayModeEmbeddings
|
|
||||||
RelayModeModerations
|
|
||||||
RelayModeImagesGenerations
|
|
||||||
RelayModeEdits
|
|
||||||
RelayModeAudioSpeech
|
|
||||||
RelayModeAudioTranscription
|
|
||||||
RelayModeAudioTranslation
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/chat
|
// https://platform.openai.com/docs/api-reference/chat
|
||||||
|
|
||||||
type ResponseFormat struct {
|
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||||
Type string `json:"type,omitempty"`
|
var err *model.ErrorWithStatusCode
|
||||||
|
switch relayMode {
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
err = controller.RelayImageHelper(c, relayMode)
|
||||||
|
case relaymode.AudioSpeech:
|
||||||
|
fallthrough
|
||||||
|
case relaymode.AudioTranslation:
|
||||||
|
fallthrough
|
||||||
|
case relaymode.AudioTranscription:
|
||||||
|
err = controller.RelayAudioHelper(c, relayMode)
|
||||||
|
case relaymode.Proxy:
|
||||||
|
err = controller.RelayProxyHelper(c, relayMode)
|
||||||
|
default:
|
||||||
|
err = controller.RelayTextHelper(c)
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
type GeneralOpenAIRequest struct {
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Messages []Message `json:"messages,omitempty"`
|
|
||||||
Prompt any `json:"prompt,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Input any `json:"input,omitempty"`
|
|
||||||
Instruction string `json:"instruction,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Functions any `json:"functions,omitempty"`
|
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
|
||||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
|
||||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
|
||||||
Seed float64 `json:"seed,omitempty"`
|
|
||||||
Tools any `json:"tools,omitempty"`
|
|
||||||
ToolChoice any `json:"tool_choice,omitempty"`
|
|
||||||
User string `json:"user,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
|
||||||
if r.Input == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var input []string
|
|
||||||
switch r.Input.(type) {
|
|
||||||
case string:
|
|
||||||
input = []string{r.Input.(string)}
|
|
||||||
case []any:
|
|
||||||
input = make([]string, 0, len(r.Input.([]any)))
|
|
||||||
for _, item := range r.Input.([]any) {
|
|
||||||
if str, ok := item.(string); ok {
|
|
||||||
input = append(input, str)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return input
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
//Stream bool `json:"stream"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
|
|
||||||
type ImageRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt" binding:"required"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Quality string `json:"quality,omitempty"`
|
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
|
||||||
Style string `json:"style,omitempty"`
|
|
||||||
User string `json:"user,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WhisperJSONResponse struct {
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WhisperVerboseJSONResponse struct {
|
|
||||||
Task string `json:"task,omitempty"`
|
|
||||||
Language string `json:"language,omitempty"`
|
|
||||||
Duration float64 `json:"duration,omitempty"`
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
Segments []Segment `json:"segments,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Segment struct {
|
|
||||||
Id int `json:"id"`
|
|
||||||
Seek int `json:"seek"`
|
|
||||||
Start float64 `json:"start"`
|
|
||||||
End float64 `json:"end"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
Tokens []int `json:"tokens"`
|
|
||||||
Temperature float64 `json:"temperature"`
|
|
||||||
AvgLogprob float64 `json:"avg_logprob"`
|
|
||||||
CompressionRatio float64 `json:"compression_ratio"`
|
|
||||||
NoSpeechProb float64 `json:"no_speech_prob"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextToSpeechRequest struct {
|
|
||||||
Model string `json:"model" binding:"required"`
|
|
||||||
Input string `json:"input" binding:"required"`
|
|
||||||
Voice string `json:"voice" binding:"required"`
|
|
||||||
Speed float64 `json:"speed"`
|
|
||||||
ResponseFormat string `json:"response_format"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Usage struct {
|
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIError struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Param string `json:"param"`
|
|
||||||
Code any `json:"code"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIErrorWithStatusCode struct {
|
|
||||||
OpenAIError
|
|
||||||
StatusCode int `json:"status_code"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextResponse struct {
|
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
Error OpenAIError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAITextResponseChoice struct {
|
|
||||||
Index int `json:"index"`
|
|
||||||
Message `json:"message"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAITextResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIEmbeddingResponseItem struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIEmbeddingResponse struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageResponse struct {
|
|
||||||
Created int `json:"created"`
|
|
||||||
Data []struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionsStreamResponseChoice struct {
|
|
||||||
Delta struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
} `json:"delta"`
|
|
||||||
FinishReason *string `json:"finish_reason,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionsStreamResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompletionsStreamResponse struct {
|
|
||||||
Choices []struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
} `json:"choices"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := RelayModeUnknown
|
ctx := c.Request.Context()
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
relayMode := relaymode.GetByPath(c.Request.URL.Path)
|
||||||
relayMode = RelayModeChatCompletions
|
if config.DebugEnabled {
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
relayMode = RelayModeCompletions
|
logger.Debugf(ctx, "request body: %s", string(requestBody))
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
|
||||||
relayMode = RelayModeEmbeddings
|
|
||||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
|
||||||
relayMode = RelayModeEmbeddings
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
||||||
relayMode = RelayModeModerations
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
|
||||||
relayMode = RelayModeImagesGenerations
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
|
||||||
relayMode = RelayModeEdits
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
|
||||||
relayMode = RelayModeAudioSpeech
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
|
||||||
relayMode = RelayModeAudioTranscription
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
|
||||||
relayMode = RelayModeAudioTranslation
|
|
||||||
}
|
}
|
||||||
var err *OpenAIErrorWithStatusCode
|
channelId := c.GetInt(ctxkey.ChannelId)
|
||||||
switch relayMode {
|
userId := c.GetInt(ctxkey.Id)
|
||||||
case RelayModeImagesGenerations:
|
bizErr := relayHelper(c, relayMode)
|
||||||
err = relayImageHelper(c, relayMode)
|
if bizErr == nil {
|
||||||
case RelayModeAudioSpeech:
|
monitor.Emit(channelId, true)
|
||||||
fallthrough
|
return
|
||||||
case RelayModeAudioTranslation:
|
|
||||||
fallthrough
|
|
||||||
case RelayModeAudioTranscription:
|
|
||||||
err = relayAudioHelper(c, relayMode)
|
|
||||||
default:
|
|
||||||
err = relayTextHelper(c, relayMode)
|
|
||||||
}
|
}
|
||||||
|
lastFailedChannelId := channelId
|
||||||
|
channelName := c.GetString(ctxkey.ChannelName)
|
||||||
|
group := c.GetString(ctxkey.Group)
|
||||||
|
originalModel := c.GetString(ctxkey.OriginalModel)
|
||||||
|
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||||
|
requestId := c.GetString(helper.RequestIdKey)
|
||||||
|
retryTimes := config.RetryTimes
|
||||||
|
if !shouldRetry(c, bizErr.StatusCode) {
|
||||||
|
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
|
||||||
|
retryTimes = 0
|
||||||
|
}
|
||||||
|
for i := retryTimes; i > 0; i-- {
|
||||||
|
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
|
||||||
retryTimesStr := c.Query("retry")
|
break
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
|
||||||
if retryTimesStr == "" {
|
|
||||||
retryTimes = common.RetryTimes
|
|
||||||
}
|
}
|
||||||
if retryTimes > 0 {
|
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
if channel.Id == lastFailedChannelId {
|
||||||
} else {
|
continue
|
||||||
if err.StatusCode == http.StatusTooManyRequests {
|
|
||||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
||||||
}
|
}
|
||||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
c.JSON(err.StatusCode, gin.H{
|
requestBody, err := common.GetRequestBody(c)
|
||||||
"error": err.OpenAIError,
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
bizErr = relayHelper(c, relayMode)
|
||||||
|
if bizErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channelId := c.GetInt(ctxkey.ChannelId)
|
||||||
|
lastFailedChannelId = channelId
|
||||||
|
channelName := c.GetString(ctxkey.ChannelName)
|
||||||
|
// BUG: bizErr is in race condition
|
||||||
|
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||||
|
}
|
||||||
|
if bizErr != nil {
|
||||||
|
if bizErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BUG: bizErr is in race condition
|
||||||
|
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
|
||||||
|
c.JSON(bizErr.StatusCode, gin.H{
|
||||||
|
"error": bizErr.Error,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
|
||||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
channelName := c.GetString("channel_name")
|
|
||||||
disableChannel(channelId, channelName, err.Message)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||||
|
if _, ok := c.Get(ctxkey.SpecificChannelId); ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if statusCode/100 == 5 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode/100 == 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
|
||||||
|
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
|
||||||
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
|
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||||
|
monitor.DisableChannel(channelId, channelName, err.Message)
|
||||||
|
} else {
|
||||||
|
monitor.Emit(channelId, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayNotImplemented(c *gin.Context) {
|
func RelayNotImplemented(c *gin.Context) {
|
||||||
err := OpenAIError{
|
err := model.Error{
|
||||||
Message: "API not implemented",
|
Message: "API not implemented",
|
||||||
Type: "one_api_error",
|
Type: "one_api_error",
|
||||||
Param: "",
|
Param: "",
|
||||||
@ -390,7 +145,7 @@ func RelayNotImplemented(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayNotFound(c *gin.Context) {
|
func RelayNotFound(c *gin.Context) {
|
||||||
err := OpenAIError{
|
err := model.Error{
|
||||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||||
Type: "invalid_request_error",
|
Type: "invalid_request_error",
|
||||||
Param: "",
|
Param: "",
|
||||||
|
@ -1,20 +1,28 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/network"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllTokens(c *gin.Context) {
|
func GetAllTokens(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
p, _ := strconv.Atoi(c.Query("p"))
|
||||||
if p < 0 {
|
if p < 0 {
|
||||||
p = 0
|
p = 0
|
||||||
}
|
}
|
||||||
tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
|
|
||||||
|
order := c.Query("order")
|
||||||
|
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -31,7 +39,7 @@ func GetAllTokens(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SearchTokens(c *gin.Context) {
|
func SearchTokens(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
tokens, err := model.SearchUserTokens(userId, keyword)
|
tokens, err := model.SearchUserTokens(userId, keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -51,7 +59,7 @@ func SearchTokens(c *gin.Context) {
|
|||||||
|
|
||||||
func GetToken(c *gin.Context) {
|
func GetToken(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -76,8 +84,8 @@ func GetToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetTokenStatus(c *gin.Context) {
|
func GetTokenStatus(c *gin.Context) {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt(ctxkey.TokenId)
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
token, err := model.GetTokenByIds(tokenId, userId)
|
token, err := model.GetTokenByIds(tokenId, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -99,6 +107,19 @@ func GetTokenStatus(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateToken(c *gin.Context, token model.Token) error {
|
||||||
|
if len(token.Name) > 30 {
|
||||||
|
return fmt.Errorf("令牌名称过长")
|
||||||
|
}
|
||||||
|
if token.Subnet != nil && *token.Subnet != "" {
|
||||||
|
err := network.IsValidSubnets(*token.Subnet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("无效的网段:%s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func AddToken(c *gin.Context) {
|
func AddToken(c *gin.Context) {
|
||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
@ -109,22 +130,26 @@ func AddToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
err = validateToken(c, token)
|
||||||
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌名称过长",
|
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
UserId: c.GetInt("id"),
|
UserId: c.GetInt(ctxkey.Id),
|
||||||
Name: token.Name,
|
Name: token.Name,
|
||||||
Key: common.GenerateKey(),
|
Key: random.GenerateKey(),
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: helper.GetTimestamp(),
|
||||||
AccessedTime: common.GetTimestamp(),
|
AccessedTime: helper.GetTimestamp(),
|
||||||
ExpiredTime: token.ExpiredTime,
|
ExpiredTime: token.ExpiredTime,
|
||||||
RemainQuota: token.RemainQuota,
|
RemainQuota: token.RemainQuota,
|
||||||
UnlimitedQuota: token.UnlimitedQuota,
|
UnlimitedQuota: token.UnlimitedQuota,
|
||||||
|
Models: token.Models,
|
||||||
|
Subnet: token.Subnet,
|
||||||
}
|
}
|
||||||
err = cleanToken.Insert()
|
err = cleanToken.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -137,13 +162,14 @@ func AddToken(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
"data": cleanToken,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteToken(c *gin.Context) {
|
func DeleteToken(c *gin.Context) {
|
||||||
id, _ := strconv.Atoi(c.Param("id"))
|
id, _ := strconv.Atoi(c.Param("id"))
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
err := model.DeleteTokenById(id, userId)
|
err := model.DeleteTokenById(id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -160,7 +186,7 @@ func DeleteToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateToken(c *gin.Context) {
|
func UpdateToken(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
statusOnly := c.Query("status_only")
|
statusOnly := c.Query("status_only")
|
||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
@ -171,10 +197,11 @@ func UpdateToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
err = validateToken(c, token)
|
||||||
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌名称过长",
|
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -186,15 +213,15 @@ func UpdateToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if token.Status == common.TokenStatusEnabled {
|
if token.Status == model.TokenStatusEnabled {
|
||||||
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
|
if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
|
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
|
||||||
@ -210,6 +237,8 @@ func UpdateToken(c *gin.Context) {
|
|||||||
cleanToken.ExpiredTime = token.ExpiredTime
|
cleanToken.ExpiredTime = token.ExpiredTime
|
||||||
cleanToken.RemainQuota = token.RemainQuota
|
cleanToken.RemainQuota = token.RemainQuota
|
||||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||||
|
cleanToken.Models = token.Models
|
||||||
|
cleanToken.Subnet = token.Subnet
|
||||||
}
|
}
|
||||||
err = cleanToken.Update()
|
err = cleanToken.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -3,9 +3,12 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -19,7 +22,7 @@ type LoginRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Login(c *gin.Context) {
|
func Login(c *gin.Context) {
|
||||||
if !common.PasswordLoginEnabled {
|
if !config.PasswordLoginEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "管理员关闭了密码登录",
|
"message": "管理员关闭了密码登录",
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -56,11 +59,11 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
setupLogin(&user, c)
|
SetupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup session & cookies and then return user info
|
// setup session & cookies and then return user info
|
||||||
func setupLogin(user *model.User, c *gin.Context) {
|
func SetupLogin(user *model.User, c *gin.Context) {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
session.Set("id", user.Id)
|
session.Set("id", user.Id)
|
||||||
session.Set("username", user.Username)
|
session.Set("username", user.Username)
|
||||||
@ -106,14 +109,14 @@ func Logout(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Register(c *gin.Context) {
|
func Register(c *gin.Context) {
|
||||||
if !common.RegisterEnabled {
|
if !config.RegisterEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "管理员关闭了新用户注册",
|
"message": "管理员关闭了新用户注册",
|
||||||
"success": false,
|
"success": false,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !common.PasswordRegisterEnabled {
|
if !config.PasswordRegisterEnabled {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
|
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -136,7 +139,7 @@ func Register(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if common.EmailVerificationEnabled {
|
if config.EmailVerificationEnabled {
|
||||||
if user.Email == "" || user.VerificationCode == "" {
|
if user.Email == "" || user.VerificationCode == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -160,7 +163,7 @@ func Register(c *gin.Context) {
|
|||||||
DisplayName: user.Username,
|
DisplayName: user.Username,
|
||||||
InviterId: inviterId,
|
InviterId: inviterId,
|
||||||
}
|
}
|
||||||
if common.EmailVerificationEnabled {
|
if config.EmailVerificationEnabled {
|
||||||
cleanUser.Email = user.Email
|
cleanUser.Email = user.Email
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(inviterId); err != nil {
|
if err := cleanUser.Insert(inviterId); err != nil {
|
||||||
@ -170,6 +173,7 @@ func Register(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@ -182,7 +186,10 @@ func GetAllUsers(c *gin.Context) {
|
|||||||
if p < 0 {
|
if p < 0 {
|
||||||
p = 0
|
p = 0
|
||||||
}
|
}
|
||||||
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
|
|
||||||
|
order := c.DefaultQuery("order", "")
|
||||||
|
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -190,12 +197,12 @@ func GetAllUsers(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": users,
|
"data": users,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUsers(c *gin.Context) {
|
func SearchUsers(c *gin.Context) {
|
||||||
@ -233,8 +240,8 @@ func GetUser(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt(ctxkey.Role)
|
||||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
if myRole <= user.Role && myRole != model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无权获取同级或更高等级用户的信息",
|
"message": "无权获取同级或更高等级用户的信息",
|
||||||
@ -250,7 +257,7 @@ func GetUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUserDashboard(c *gin.Context) {
|
func GetUserDashboard(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt(ctxkey.Id)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
|
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
|
||||||
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
|
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
|
||||||
@ -273,7 +280,7 @@ func GetUserDashboard(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenerateAccessToken(c *gin.Context) {
|
func GenerateAccessToken(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt(ctxkey.Id)
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -282,7 +289,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.AccessToken = common.GetUUID()
|
user.AccessToken = random.GetUUID()
|
||||||
|
|
||||||
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -309,7 +316,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAffCode(c *gin.Context) {
|
func GetAffCode(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt(ctxkey.Id)
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -319,7 +326,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.AffCode == "" {
|
if user.AffCode == "" {
|
||||||
user.AffCode = common.GetRandomString(4)
|
user.AffCode = random.GetRandomString(4)
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -337,7 +344,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetSelf(c *gin.Context) {
|
func GetSelf(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt(ctxkey.Id)
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -382,15 +389,15 @@ func UpdateUser(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt(ctxkey.Role)
|
||||||
if myRole <= originUser.Role && myRole != common.RoleRootUser {
|
if myRole <= originUser.Role && myRole != model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
|
if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
|
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
|
||||||
@ -440,7 +447,7 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cleanUser := model.User{
|
cleanUser := model.User{
|
||||||
Id: c.GetInt("id"),
|
Id: c.GetInt(ctxkey.Id),
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.DisplayName,
|
DisplayName: user.DisplayName,
|
||||||
@ -504,7 +511,7 @@ func DeleteSelf(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, _ := model.GetUserById(id, false)
|
user, _ := model.GetUserById(id, false)
|
||||||
|
|
||||||
if user.Role == common.RoleRootUser {
|
if user.Role == model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "不能删除超级管理员账户",
|
"message": "不能删除超级管理员账户",
|
||||||
@ -606,7 +613,7 @@ func ManageUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
if myRole <= user.Role && myRole != model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||||
@ -615,8 +622,8 @@ func ManageUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
switch req.Action {
|
switch req.Action {
|
||||||
case "disable":
|
case "disable":
|
||||||
user.Status = common.UserStatusDisabled
|
user.Status = model.UserStatusDisabled
|
||||||
if user.Role == common.RoleRootUser {
|
if user.Role == model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法禁用超级管理员用户",
|
"message": "无法禁用超级管理员用户",
|
||||||
@ -624,9 +631,9 @@ func ManageUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "enable":
|
case "enable":
|
||||||
user.Status = common.UserStatusEnabled
|
user.Status = model.UserStatusEnabled
|
||||||
case "delete":
|
case "delete":
|
||||||
if user.Role == common.RoleRootUser {
|
if user.Role == model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法删除超级管理员用户",
|
"message": "无法删除超级管理员用户",
|
||||||
@ -641,37 +648,37 @@ func ManageUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "promote":
|
case "promote":
|
||||||
if myRole != common.RoleRootUser {
|
if myRole != model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "普通管理员用户无法提升其他用户为管理员",
|
"message": "普通管理员用户无法提升其他用户为管理员",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.Role >= common.RoleAdminUser {
|
if user.Role >= model.RoleAdminUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "该用户已经是管理员",
|
"message": "该用户已经是管理员",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.Role = common.RoleAdminUser
|
user.Role = model.RoleAdminUser
|
||||||
case "demote":
|
case "demote":
|
||||||
if user.Role == common.RoleRootUser {
|
if user.Role == model.RoleRootUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法降级超级管理员用户",
|
"message": "无法降级超级管理员用户",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.Role == common.RoleCommonUser {
|
if user.Role == model.RoleCommonUser {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "该用户已经是普通用户",
|
"message": "该用户已经是普通用户",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.Role = common.RoleCommonUser
|
user.Role = model.RoleCommonUser
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
@ -725,8 +732,8 @@ func EmailBind(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.Role == common.RoleRootUser {
|
if user.Role == model.RoleRootUser {
|
||||||
common.RootUserEmail = email
|
config.RootUserEmail = email
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
@ -765,3 +772,38 @@ func TopUp(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type adminTopUpRequest struct {
|
||||||
|
UserId int `json:"user_id"`
|
||||||
|
Quota int `json:"quota"`
|
||||||
|
Remark string `json:"remark"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdminTopUp(c *gin.Context) {
|
||||||
|
req := adminTopUpRequest{}
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Remark == "" {
|
||||||
|
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
|
||||||
|
}
|
||||||
|
model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -2,7 +2,7 @@ version: '3.4'
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
one-api:
|
one-api:
|
||||||
image: justsong/one-api:latest
|
image: "${REGISTRY:-docker.io}/justsong/one-api:latest"
|
||||||
container_name: one-api
|
container_name: one-api
|
||||||
restart: always
|
restart: always
|
||||||
command: --log-dir /app/logs
|
command: --log-dir /app/logs
|
||||||
@ -29,12 +29,12 @@ services:
|
|||||||
retries: 3
|
retries: 3
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
image: redis:latest
|
image: "${REGISTRY:-docker.io}/redis:latest"
|
||||||
container_name: redis
|
container_name: redis
|
||||||
restart: always
|
restart: always
|
||||||
|
|
||||||
db:
|
db:
|
||||||
image: mysql:8.2.0
|
image: "${REGISTRY:-docker.io}/mysql:8.2.0"
|
||||||
restart: always
|
restart: always
|
||||||
container_name: mysql
|
container_name: mysql
|
||||||
volumes:
|
volumes:
|
||||||
|
53
docs/API.md
Normal file
53
docs/API.md
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# 使用 API 操控 & 扩展 One API
|
||||||
|
> 欢迎提交 PR 在此放上你的拓展项目。
|
||||||
|
|
||||||
|
例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
|
||||||
|
|
||||||
|
又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
|
||||||
|
|
||||||
|
## 鉴权
|
||||||
|
One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API:
|
||||||
|

|
||||||
|
|
||||||
|
## 请求格式与响应格式
|
||||||
|
One API 使用 JSON 格式进行请求和响应。
|
||||||
|
|
||||||
|
对于响应体,一般格式如下:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": "请求信息",
|
||||||
|
"success": true,
|
||||||
|
"data": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 列表
|
||||||
|
> 当前 API 列表不全,请自行通过浏览器抓取前端请求
|
||||||
|
|
||||||
|
如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
|
||||||
|
|
||||||
|
### 获取当前登录用户信息
|
||||||
|
**GET** `/api/user/self`
|
||||||
|
|
||||||
|
### 为给定用户充值额度
|
||||||
|
**POST** `/api/topup`
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"user_id": 1,
|
||||||
|
"quota": 100000,
|
||||||
|
"remark": "充值 100000 额度"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 其他
|
||||||
|
### 充值链接上的附加参数
|
||||||
|
One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如:
|
||||||
|
`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837`
|
||||||
|
|
||||||
|
你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。
|
||||||
|
|
||||||
|
注意,不是所有主题都支持该功能,欢迎 PR 补齐。
|
128
go.mod
128
go.mod
@ -1,65 +1,111 @@
|
|||||||
module one-api
|
module github.com/songquanpeng/one-api
|
||||||
|
|
||||||
// +heroku goVersion go1.18
|
// +heroku goVersion go1.18
|
||||||
go 1.18
|
go 1.20
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gin-contrib/cors v1.4.0
|
cloud.google.com/go/iam v1.1.10
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/aws/aws-sdk-go-v2 v1.27.0
|
||||||
github.com/gin-contrib/sessions v0.0.5
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.15
|
||||||
github.com/gin-contrib/static v0.0.1
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-contrib/cors v1.7.2
|
||||||
github.com/go-playground/validator/v10 v10.14.0
|
github.com/gin-contrib/gzip v1.0.1
|
||||||
|
github.com/gin-contrib/sessions v1.0.1
|
||||||
|
github.com/gin-contrib/static v1.1.2
|
||||||
|
github.com/gin-gonic/gin v1.10.0
|
||||||
|
github.com/go-playground/validator/v10 v10.20.0
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.1
|
||||||
github.com/pkoukk/tiktoken-go v0.1.5
|
github.com/jinzhu/copier v0.4.0
|
||||||
github.com/stretchr/testify v1.8.3
|
github.com/joho/godotenv v1.5.1
|
||||||
golang.org/x/crypto v0.17.0
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
golang.org/x/image v0.14.0
|
github.com/pkg/errors v0.9.1
|
||||||
gorm.io/driver/mysql v1.4.3
|
github.com/pkoukk/tiktoken-go v0.1.7
|
||||||
gorm.io/driver/postgres v1.5.2
|
github.com/smartystreets/goconvey v1.8.1
|
||||||
gorm.io/driver/sqlite v1.4.3
|
github.com/stretchr/testify v1.9.0
|
||||||
gorm.io/gorm v1.25.0
|
golang.org/x/crypto v0.24.0
|
||||||
|
golang.org/x/image v0.18.0
|
||||||
|
google.golang.org/api v0.187.0
|
||||||
|
gorm.io/driver/mysql v1.5.6
|
||||||
|
gorm.io/driver/postgres v1.5.7
|
||||||
|
gorm.io/driver/sqlite v1.5.5
|
||||||
|
gorm.io/gorm v1.25.10
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
cloud.google.com/go/auth v0.6.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||||
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect
|
||||||
|
github.com/aws/smithy-go v1.20.2 // indirect
|
||||||
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
github.com/dlclark/regexp2 v1.11.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
|
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.1 // indirect
|
||||||
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.6.0 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.3 // indirect
|
||||||
github.com/gorilla/context v1.1.1 // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
github.com/golang/protobuf v1.5.4 // indirect
|
||||||
github.com/gorilla/sessions v1.2.1 // indirect
|
github.com/google/s2a-go v0.1.7 // indirect
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||||
|
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
|
||||||
|
github.com/gopherjs/gopherjs v1.17.2 // indirect
|
||||||
|
github.com/gorilla/context v1.1.2 // indirect
|
||||||
|
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||||
|
github.com/gorilla/sessions v1.2.2 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
|
||||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/jtolds/gls v4.20.0+incompatible // indirect
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/smarty/assertions v1.15.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
golang.org/x/net v0.17.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
|
||||||
golang.org/x/sys v0.15.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
golang.org/x/text v0.14.0 // indirect
|
go.opentelemetry.io/otel v1.24.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/trace v1.24.0 // indirect
|
||||||
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
|
golang.org/x/net v0.26.0 // indirect
|
||||||
|
golang.org/x/oauth2 v0.21.0 // indirect
|
||||||
|
golang.org/x/sync v0.7.0 // indirect
|
||||||
|
golang.org/x/sys v0.21.0 // indirect
|
||||||
|
golang.org/x/text v0.16.0 // indirect
|
||||||
|
golang.org/x/time v0.5.0 // indirect
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
|
||||||
|
google.golang.org/grpc v1.64.1 // indirect
|
||||||
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
395
go.sum
395
go.sum
@ -1,206 +1,317 @@
|
|||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI=
|
||||||
|
cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps=
|
||||||
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
|
||||||
|
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||||
|
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||||
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
|
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||||
|
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||||
|
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||||
|
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||||
|
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||||
|
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||||
|
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
|
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||||
github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
|
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||||
github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
|
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
|
||||||
|
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
|
||||||
|
github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE=
|
||||||
|
github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4=
|
||||||
|
github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI=
|
||||||
|
github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U=
|
github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4=
|
||||||
github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs=
|
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
|
||||||
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
|
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||||
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
|
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
|
||||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
|
||||||
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
|
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
|
|
||||||
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
|
|
||||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||||
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
|
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
|
||||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
|
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||||
|
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||||
|
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||||
|
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||||
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
|
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
|
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
|
||||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
|
||||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
|
||||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
|
||||||
|
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
|
||||||
|
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||||
|
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
|
||||||
|
github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM=
|
||||||
|
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||||
|
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||||
|
github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY=
|
||||||
|
github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||||
|
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
|
||||||
|
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
|
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
||||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
|
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||||
|
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||||
|
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
|
||||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
|
||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
|
||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
|
||||||
|
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
|
||||||
|
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
|
||||||
|
github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
|
||||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
|
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
|
||||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||||
|
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||||
|
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||||
|
go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI=
|
||||||
|
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
|
||||||
|
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
|
||||||
|
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
|
||||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
|
||||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
|
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||||
|
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||||
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
|
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
|
||||||
|
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||||
|
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||||
|
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||||
|
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
|
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo=
|
||||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk=
|
||||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
|
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||||
|
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||||
|
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
|
||||||
|
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||||
|
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||||
|
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||||
|
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||||
|
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||||
|
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
|
||||||
|
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||||
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
|
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||||
|
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||||
|
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||||
|
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
|
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
|
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||||
|
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||||
|
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
|
||||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|
||||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
|
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
|
||||||
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
|
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||||
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
|
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
|
||||||
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
|
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
|
||||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
|
||||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
|
||||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
|
||||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
|
35
i18n/en.json
35
i18n/en.json
@ -8,12 +8,12 @@
|
|||||||
"确认删除": "Confirm Delete",
|
"确认删除": "Confirm Delete",
|
||||||
"确认绑定": "Confirm Binding",
|
"确认绑定": "Confirm Binding",
|
||||||
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
|
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
|
||||||
"\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
|
"\"渠道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
|
||||||
"通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s",
|
"渠道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s",
|
||||||
"测试已在运行中": "Test is already running",
|
"测试已在运行中": "Test is already running",
|
||||||
"响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs",
|
"响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs",
|
||||||
"通道测试完成": "Channel test completed",
|
"渠道测试完成": "Channel test completed",
|
||||||
"通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
|
"渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
|
||||||
"无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!",
|
"无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!",
|
||||||
"返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!",
|
"返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!",
|
||||||
"管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub",
|
"管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub",
|
||||||
@ -119,11 +119,11 @@
|
|||||||
" 个月 ": " M ",
|
" 个月 ": " M ",
|
||||||
" 年 ": " y ",
|
" 年 ": " y ",
|
||||||
"未测试": "Not tested",
|
"未测试": "Not tested",
|
||||||
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
|
"渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
|
||||||
"已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
|
"已成功开始测试所有渠道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
|
||||||
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
|
"已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
|
||||||
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
|
"渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
|
||||||
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
|
"已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!",
|
||||||
"搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...",
|
"搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...",
|
||||||
"名称": "Name",
|
"名称": "Name",
|
||||||
"分组": "Group",
|
"分组": "Group",
|
||||||
@ -141,9 +141,9 @@
|
|||||||
"启用": "Enable",
|
"启用": "Enable",
|
||||||
"编辑": "Edit",
|
"编辑": "Edit",
|
||||||
"添加新的渠道": "Add a new channel",
|
"添加新的渠道": "Add a new channel",
|
||||||
"测试所有通道": "Test all channels",
|
"测试所有渠道": "Test all channels",
|
||||||
"测试所有已启用通道": "Test all enabled channels",
|
"测试所有已启用渠道": "Test all enabled channels",
|
||||||
"更新所有已启用通道余额": "Update the balance of all enabled channels",
|
"更新所有已启用渠道余额": "Update the balance of all enabled channels",
|
||||||
"刷新": "Refresh",
|
"刷新": "Refresh",
|
||||||
"处理中...": "Processing...",
|
"处理中...": "Processing...",
|
||||||
"绑定成功!": "Binding succeeded!",
|
"绑定成功!": "Binding succeeded!",
|
||||||
@ -207,11 +207,11 @@
|
|||||||
"监控设置": "Monitoring Settings",
|
"监控设置": "Monitoring Settings",
|
||||||
"最长响应时间": "Longest Response Time",
|
"最长响应时间": "Longest Response Time",
|
||||||
"单位秒": "Unit in seconds",
|
"单位秒": "Unit in seconds",
|
||||||
"当运行通道全部测试时": "When all operating channels are tested",
|
"当运行渠道全部测试时": "When all operating channels are tested",
|
||||||
"超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded",
|
"超过此时间将自动禁用渠道": "Channels will be automatically disabled if this time is exceeded",
|
||||||
"额度提醒阈值": "Quota reminder threshold",
|
"额度提醒阈值": "Quota reminder threshold",
|
||||||
"低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this",
|
"低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this",
|
||||||
"失败时自动禁用通道": "Automatically disable the channel when it fails",
|
"失败时自动禁用渠道": "Automatically disable the channel when it fails",
|
||||||
"保存监控设置": "Save Monitoring Settings",
|
"保存监控设置": "Save Monitoring Settings",
|
||||||
"额度设置": "Quota Settings",
|
"额度设置": "Quota Settings",
|
||||||
"新用户初始额度": "Initial quota for new users",
|
"新用户初始额度": "Initial quota for new users",
|
||||||
@ -405,7 +405,7 @@
|
|||||||
"镜像": "Mirror",
|
"镜像": "Mirror",
|
||||||
"请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used",
|
"请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used",
|
||||||
"模型": "Model",
|
"模型": "Model",
|
||||||
"请选择该通道所支持的模型": "Please select the model supported by the channel",
|
"请选择该渠道所支持的模型": "Please select the model supported by the channel",
|
||||||
"填入基础模型": "Fill in the basic model",
|
"填入基础模型": "Fill in the basic model",
|
||||||
"填入所有模型": "Fill in all models",
|
"填入所有模型": "Fill in all models",
|
||||||
"清除所有模型": "Clear all models",
|
"清除所有模型": "Clear all models",
|
||||||
@ -456,6 +456,7 @@
|
|||||||
"已绑定的邮箱账户": "Email Account Bound",
|
"已绑定的邮箱账户": "Email Account Bound",
|
||||||
"用户信息更新成功!": "User information updated successfully!",
|
"用户信息更新成功!": "User information updated successfully!",
|
||||||
"模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f",
|
"模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f",
|
||||||
|
"模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f",
|
||||||
"使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
|
"使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
|
||||||
"用户名称": "User Name",
|
"用户名称": "User Name",
|
||||||
"令牌名称": "Token Name",
|
"令牌名称": "Token Name",
|
||||||
@ -514,7 +515,7 @@
|
|||||||
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
|
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
|
||||||
"Homepage URL 填": "Fill in the Homepage URL",
|
"Homepage URL 填": "Fill in the Homepage URL",
|
||||||
"Authorization callback URL 填": "Fill in the Authorization callback URL",
|
"Authorization callback URL 填": "Fill in the Authorization callback URL",
|
||||||
"请为通道命名": "Please name the channel",
|
"请为渠道命名": "Please name the channel",
|
||||||
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
|
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
|
||||||
"模型重定向": "Model redirection",
|
"模型重定向": "Model redirection",
|
||||||
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
|
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
|
||||||
|
80
main.go
80
main.go
@ -6,11 +6,16 @@ import (
|
|||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-contrib/sessions/cookie"
|
"github.com/gin-contrib/sessions/cookie"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
_ "github.com/joho/godotenv/autoload"
|
||||||
"one-api/controller"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"one-api/middleware"
|
"github.com/songquanpeng/one-api/common/client"
|
||||||
"one-api/model"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"one-api/router"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/controller"
|
||||||
|
"github.com/songquanpeng/one-api/middleware"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/router"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@ -19,68 +24,72 @@ import (
|
|||||||
var buildFS embed.FS
|
var buildFS embed.FS
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
common.SetupLogger()
|
common.Init()
|
||||||
common.SysLog(fmt.Sprintf("One API %s started", common.Version))
|
logger.SetupLogger()
|
||||||
if os.Getenv("GIN_MODE") != "debug" {
|
logger.SysLogf("One API %s started", common.Version)
|
||||||
|
|
||||||
|
if os.Getenv("GIN_MODE") != gin.DebugMode {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
if common.DebugEnabled {
|
if config.DebugEnabled {
|
||||||
common.SysLog("running in debug mode")
|
logger.SysLog("running in debug mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize SQL Database
|
// Initialize SQL Database
|
||||||
err := model.InitDB()
|
model.InitDB()
|
||||||
|
model.InitLogDB()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
err = model.CreateRootAccountIfNeed()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to initialize database: " + err.Error())
|
logger.FatalLog("database init error: " + err.Error())
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
err := model.CloseDB()
|
err := model.CloseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to close database: " + err.Error())
|
logger.FatalLog("failed to close database: " + err.Error())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Initialize Redis
|
// Initialize Redis
|
||||||
err = common.InitRedisClient()
|
err = common.InitRedisClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
logger.FatalLog("failed to initialize Redis: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize options
|
// Initialize options
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
common.SysLog(fmt.Sprintf("using theme %s", common.Theme))
|
logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
// for compatibility with old versions
|
// for compatibility with old versions
|
||||||
common.MemoryCacheEnabled = true
|
config.MemoryCacheEnabled = true
|
||||||
}
|
}
|
||||||
if common.MemoryCacheEnabled {
|
if config.MemoryCacheEnabled {
|
||||||
common.SysLog("memory cache enabled")
|
logger.SysLog("memory cache enabled")
|
||||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
|
||||||
model.InitChannelCache()
|
model.InitChannelCache()
|
||||||
}
|
}
|
||||||
if common.MemoryCacheEnabled {
|
if config.MemoryCacheEnabled {
|
||||||
go model.SyncOptions(common.SyncFrequency)
|
go model.SyncOptions(config.SyncFrequency)
|
||||||
go model.SyncChannelCache(common.SyncFrequency)
|
go model.SyncChannelCache(config.SyncFrequency)
|
||||||
}
|
|
||||||
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
|
|
||||||
}
|
|
||||||
go controller.AutomaticallyUpdateChannels(frequency)
|
|
||||||
}
|
}
|
||||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
||||||
}
|
}
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
common.BatchUpdateEnabled = true
|
config.BatchUpdateEnabled = true
|
||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
|
||||||
model.InitBatchUpdater()
|
model.InitBatchUpdater()
|
||||||
}
|
}
|
||||||
controller.InitTokenEncoders()
|
if config.EnableMetric {
|
||||||
|
logger.SysLog("metric enabled, will disable channel if too much request failed")
|
||||||
|
}
|
||||||
|
openai.InitTokenEncoders()
|
||||||
|
client.Init()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.New()
|
server := gin.New()
|
||||||
@ -90,7 +99,7 @@ func main() {
|
|||||||
server.Use(middleware.RequestId())
|
server.Use(middleware.RequestId())
|
||||||
middleware.SetUpLogger(server)
|
middleware.SetUpLogger(server)
|
||||||
// Initialize session store
|
// Initialize session store
|
||||||
store := cookie.NewStore([]byte(common.SessionSecret))
|
store := cookie.NewStore([]byte(config.SessionSecret))
|
||||||
server.Use(sessions.Sessions("session", store))
|
server.Use(sessions.Sessions("session", store))
|
||||||
|
|
||||||
router.SetRouter(server, buildFS)
|
router.SetRouter(server, buildFS)
|
||||||
@ -98,8 +107,9 @@ func main() {
|
|||||||
if port == "" {
|
if port == "" {
|
||||||
port = strconv.Itoa(*common.Port)
|
port = strconv.Itoa(*common.Port)
|
||||||
}
|
}
|
||||||
|
logger.SysLogf("server started on http://localhost:%s", port)
|
||||||
err = server.Run(":" + port)
|
err = server.Run(":" + port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
logger.FatalLog("failed to start HTTP server: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/blacklist"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/network"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -42,11 +45,14 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if status.(int) == common.UserStatusDisabled {
|
if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "用户已被封禁",
|
"message": "用户已被封禁",
|
||||||
})
|
})
|
||||||
|
session := sessions.Default(c)
|
||||||
|
session.Clear()
|
||||||
|
_ = session.Save()
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -66,24 +72,25 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
|
|
||||||
func UserAuth() func(c *gin.Context) {
|
func UserAuth() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
authHelper(c, common.RoleCommonUser)
|
authHelper(c, model.RoleCommonUser)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AdminAuth() func(c *gin.Context) {
|
func AdminAuth() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
authHelper(c, common.RoleAdminUser)
|
authHelper(c, model.RoleAdminUser)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RootAuth() func(c *gin.Context) {
|
func RootAuth() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
authHelper(c, common.RoleRootUser)
|
authHelper(c, model.RoleRootUser)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TokenAuth() func(c *gin.Context) {
|
func TokenAuth() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
key := c.Request.Header.Get("Authorization")
|
key := c.Request.Header.Get("Authorization")
|
||||||
key = strings.TrimPrefix(key, "Bearer ")
|
key = strings.TrimPrefix(key, "Bearer ")
|
||||||
key = strings.TrimPrefix(key, "sk-")
|
key = strings.TrimPrefix(key, "sk-")
|
||||||
@ -94,26 +101,67 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if token.Subnet != nil && *token.Subnet != "" {
|
||||||
|
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
|
||||||
|
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !userEnabled {
|
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
|
||||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("id", token.UserId)
|
requestModel, err := getRequestModel(c)
|
||||||
c.Set("token_id", token.Id)
|
if err != nil && shouldCheckModel(c) {
|
||||||
c.Set("token_name", token.Name)
|
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(ctxkey.RequestModel, requestModel)
|
||||||
|
if token.Models != nil && *token.Models != "" {
|
||||||
|
c.Set(ctxkey.AvailableModels, *token.Models)
|
||||||
|
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
||||||
|
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Set(ctxkey.Id, token.UserId)
|
||||||
|
c.Set(ctxkey.TokenId, token.Id)
|
||||||
|
c.Set(ctxkey.TokenName, token.Name)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set(ctxkey.SpecificChannelId, parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set channel id for proxy relay
|
||||||
|
if channelId := c.Param("channelid"); channelId != "" {
|
||||||
|
c.Set(ctxkey.SpecificChannelId, channelId)
|
||||||
|
}
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldCheckModel(c *gin.Context) bool {
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -2,26 +2,27 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelRequest struct {
|
type ModelRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model" form:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt(ctxkey.Id)
|
||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
c.Set("group", userGroup)
|
c.Set(ctxkey.Group, userGroup)
|
||||||
|
var requestModel string
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := c.Get("channelId")
|
channelId, ok := c.Get(ctxkey.SpecificChannelId)
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -33,67 +34,62 @@ func Distribute() func(c *gin.Context) {
|
|||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != model.ChannelStatusEnabled {
|
||||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Select a channel for the user
|
requestModel = c.GetString(ctxkey.RequestModel)
|
||||||
var modelRequest ModelRequest
|
var err error
|
||||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
||||||
if modelRequest.Model == "" {
|
|
||||||
modelRequest.Model = "text-moderation-stable"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
|
||||||
if modelRequest.Model == "" {
|
|
||||||
modelRequest.Model = c.Param("model")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
|
||||||
if modelRequest.Model == "" {
|
|
||||||
modelRequest.Model = "dall-e-2"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
|
||||||
if modelRequest.Model == "" {
|
|
||||||
modelRequest.Model = "whisper-1"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Set("channel", channel.Type)
|
SetupContextForSelectedChannel(c, channel, requestModel)
|
||||||
c.Set("channel_id", channel.Id)
|
|
||||||
c.Set("channel_name", channel.Name)
|
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
|
||||||
switch channel.Type {
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
c.Set("library_id", channel.Other)
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
c.Set("plugin", channel.Other)
|
|
||||||
}
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||||
|
c.Set(ctxkey.Channel, channel.Type)
|
||||||
|
c.Set(ctxkey.ChannelId, channel.Id)
|
||||||
|
c.Set(ctxkey.ChannelName, channel.Name)
|
||||||
|
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
||||||
|
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||||
|
cfg, _ := channel.LoadConfig()
|
||||||
|
// this is for backward compatibility
|
||||||
|
if channel.Other != nil {
|
||||||
|
switch channel.Type {
|
||||||
|
case channeltype.Azure:
|
||||||
|
if cfg.APIVersion == "" {
|
||||||
|
cfg.APIVersion = *channel.Other
|
||||||
|
}
|
||||||
|
case channeltype.Xunfei:
|
||||||
|
if cfg.APIVersion == "" {
|
||||||
|
cfg.APIVersion = *channel.Other
|
||||||
|
}
|
||||||
|
case channeltype.Gemini:
|
||||||
|
if cfg.APIVersion == "" {
|
||||||
|
cfg.APIVersion = *channel.Other
|
||||||
|
}
|
||||||
|
case channeltype.AIProxyLibrary:
|
||||||
|
if cfg.LibraryID == "" {
|
||||||
|
cfg.LibraryID = *channel.Other
|
||||||
|
}
|
||||||
|
case channeltype.Ali:
|
||||||
|
if cfg.Plugin == "" {
|
||||||
|
cfg.Plugin = *channel.Other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Set(ctxkey.Config, cfg)
|
||||||
|
}
|
||||||
|
@ -3,14 +3,14 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetUpLogger(server *gin.Engine) {
|
func SetUpLogger(server *gin.Engine) {
|
||||||
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||||
var requestID string
|
var requestID string
|
||||||
if param.Keys != nil {
|
if param.Keys != nil {
|
||||||
requestID = param.Keys[common.RequestIdKey].(string)
|
requestID = param.Keys[helper.RequestIdKey].(string)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||||
|
@ -3,10 +3,12 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var timeFormat = "2006-01-02T15:04:05.000Z"
|
var timeFormat = "2006-01-02T15:04:05.000Z"
|
||||||
@ -26,7 +28,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
|
|||||||
}
|
}
|
||||||
if listLength < int64(maxRequestNum) {
|
if listLength < int64(maxRequestNum) {
|
||||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||||
} else {
|
} else {
|
||||||
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
||||||
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
||||||
@ -47,14 +49,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
|
|||||||
// time.Since will return negative number!
|
// time.Since will return negative number!
|
||||||
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
|
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
|
||||||
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
||||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||||
c.Status(http.StatusTooManyRequests)
|
c.Status(http.StatusTooManyRequests)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||||
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
||||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -69,13 +71,18 @@ func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
||||||
|
if maxRequestNum == 0 {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
redisRateLimiter(c, maxRequestNum, duration, mark)
|
redisRateLimiter(c, maxRequestNum, duration, mark)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// It's safe to call multi times.
|
// It's safe to call multi times.
|
||||||
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
memoryRateLimiter(c, maxRequestNum, duration, mark)
|
memoryRateLimiter(c, maxRequestNum, duration, mark)
|
||||||
}
|
}
|
||||||
@ -83,21 +90,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GlobalWebRateLimit() func(c *gin.Context) {
|
func GlobalWebRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
|
return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
|
||||||
}
|
}
|
||||||
|
|
||||||
func GlobalAPIRateLimit() func(c *gin.Context) {
|
func GlobalAPIRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
|
return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
|
||||||
}
|
}
|
||||||
|
|
||||||
func CriticalRateLimit() func(c *gin.Context) {
|
func CriticalRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
|
return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownloadRateLimit() func(c *gin.Context) {
|
func DownloadRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
|
return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
|
||||||
}
|
}
|
||||||
|
|
||||||
func UploadRateLimit() func(c *gin.Context) {
|
func UploadRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
|
return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,9 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
ctx := c.Request.Context()
|
||||||
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err))
|
||||||
|
logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||||
|
logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path))
|
||||||
|
body, _ := common.GetRequestBody(c)
|
||||||
|
logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body)))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err),
|
||||||
"type": "one_api_panic",
|
"type": "one_api_panic",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -3,16 +3,16 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RequestId() func(c *gin.Context) {
|
func RequestId() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := common.GetTimeString() + common.GetRandomString(8)
|
id := helper.GenRequestID()
|
||||||
c.Set(common.RequestIdKey, id)
|
c.Set(helper.RequestIdKey, id)
|
||||||
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id)
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
c.Header(common.RequestIdKey, id)
|
c.Header(helper.RequestIdKey, id)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,9 +4,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type turnstileCheckResponse struct {
|
type turnstileCheckResponse struct {
|
||||||
@ -15,7 +16,7 @@ type turnstileCheckResponse struct {
|
|||||||
|
|
||||||
func TurnstileCheck() gin.HandlerFunc {
|
func TurnstileCheck() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if common.TurnstileCheckEnabled {
|
if config.TurnstileCheckEnabled {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
turnstileChecked := session.Get("turnstile")
|
turnstileChecked := session.Get("turnstile")
|
||||||
if turnstileChecked != nil {
|
if turnstileChecked != nil {
|
||||||
@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
|
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
|
||||||
"secret": {common.TurnstileSecretKey},
|
"secret": {config.TurnstileSecretKey},
|
||||||
"response": {response},
|
"response": {response},
|
||||||
"remoteip": {c.ClientIP()},
|
"remoteip": {c.ClientIP()},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
logger.SysError(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
var res turnstileCheckResponse
|
var res turnstileCheckResponse
|
||||||
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
logger.SysError(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
@ -1,17 +1,60 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)),
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), message)
|
logger.Error(c.Request.Context(), message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRequestModel(c *gin.Context) (string, error) {
|
||||||
|
var modelRequest ModelRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = "text-moderation-stable"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = c.Param("model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = "dall-e-2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = "whisper-1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelRequest.Model, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isModelInList(modelName string, models string) bool {
|
||||||
|
modelList := strings.Split(models, ",")
|
||||||
|
for _, model := range modelList {
|
||||||
|
if modelName == model {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"context"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,7 +16,7 @@ type Ability struct {
|
|||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
|
||||||
ability := Ability{}
|
ability := Ability{}
|
||||||
groupCol := "`group`"
|
groupCol := "`group`"
|
||||||
trueVal := "1"
|
trueVal := "1"
|
||||||
@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error = nil
|
var err error = nil
|
||||||
|
var channelQuery *gorm.DB
|
||||||
|
if ignoreFirstPriority {
|
||||||
|
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||||
|
} else {
|
||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||||
|
}
|
||||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||||
} else {
|
} else {
|
||||||
@ -49,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
Group: group,
|
Group: group,
|
||||||
Model: model,
|
Model: model,
|
||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
Enabled: channel.Status == ChannelStatusEnabled,
|
||||||
Priority: channel.Priority,
|
Priority: channel.Priority,
|
||||||
}
|
}
|
||||||
abilities = append(abilities, ability)
|
abilities = append(abilities, ability)
|
||||||
@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
|
|||||||
func UpdateAbilityStatus(channelId int, status bool) error {
|
func UpdateAbilityStatus(channelId int, status bool) error {
|
||||||
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
|
||||||
|
groupCol := "`group`"
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
var models []string
|
||||||
|
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sort.Strings(models)
|
||||||
|
return models, err
|
||||||
|
}
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -14,10 +18,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
TokenCacheSeconds = common.SyncFrequency
|
TokenCacheSeconds = config.SyncFrequency
|
||||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
UserId2GroupCacheSeconds = config.SyncFrequency
|
||||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
UserId2QuotaCacheSeconds = config.SyncFrequency
|
||||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
UserId2StatusCacheSeconds = config.SyncFrequency
|
||||||
|
GroupModelsCacheSeconds = config.SyncFrequency
|
||||||
)
|
)
|
||||||
|
|
||||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||||
@ -42,7 +47,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set token error: " + err.Error())
|
logger.SysError("Redis set token error: " + err.Error())
|
||||||
}
|
}
|
||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
@ -62,37 +67,48 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user group error: " + err.Error())
|
logger.SysError("Redis set user group error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetUserQuota(id int) (quota int, err error) {
|
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
|
||||||
if !common.RedisEnabled {
|
|
||||||
return GetUserQuota(id)
|
|
||||||
}
|
|
||||||
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
|
||||||
if err != nil {
|
|
||||||
quota, err = GetUserQuota(id)
|
quota, err = GetUserQuota(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user quota error: " + err.Error())
|
logger.Error(ctx, "Redis set user quota error: "+err.Error())
|
||||||
}
|
}
|
||||||
return quota, err
|
return
|
||||||
}
|
|
||||||
quota, err = strconv.Atoi(quotaString)
|
|
||||||
return quota, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheUpdateUserQuota(id int) error {
|
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return GetUserQuota(id)
|
||||||
|
}
|
||||||
|
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||||
|
if err != nil {
|
||||||
|
return fetchAndUpdateUserQuota(ctx, id)
|
||||||
|
}
|
||||||
|
quota, err = strconv.ParseInt(quotaString, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
|
||||||
|
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
|
||||||
|
return fetchAndUpdateUserQuota(ctx, id)
|
||||||
|
}
|
||||||
|
return quota, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CacheUpdateUserQuota(ctx context.Context, id int) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
quota, err := GetUserQuota(id)
|
quota, err := CacheGetUserQuota(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -100,7 +116,7 @@ func CacheUpdateUserQuota(id int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheDecreaseUserQuota(id int, quota int) error {
|
func CacheDecreaseUserQuota(id int, quota int64) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -127,18 +143,37 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user enabled error: " + err.Error())
|
logger.SysError("Redis set user enabled error: " + err.Error())
|
||||||
}
|
}
|
||||||
return userEnabled, err
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return GetGroupModels(ctx, group)
|
||||||
|
}
|
||||||
|
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
|
||||||
|
if err == nil {
|
||||||
|
return strings.Split(modelsStr, ","), nil
|
||||||
|
}
|
||||||
|
models, err := GetGroupModels(ctx, group)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("Redis set group models error: " + err.Error())
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
var channelSyncLock sync.RWMutex
|
var channelSyncLock sync.RWMutex
|
||||||
|
|
||||||
func InitChannelCache() {
|
func InitChannelCache() {
|
||||||
newChannelId2channel := make(map[int]*Channel)
|
newChannelId2channel := make(map[int]*Channel)
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
newChannelId2channel[channel.Id] = channel
|
newChannelId2channel[channel.Id] = channel
|
||||||
}
|
}
|
||||||
@ -178,20 +213,20 @@ func InitChannelCache() {
|
|||||||
channelSyncLock.Lock()
|
channelSyncLock.Lock()
|
||||||
group2model2channels = newGroup2model2channels
|
group2model2channels = newGroup2model2channels
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
common.SysLog("channels synced from database")
|
logger.SysLog("channels synced from database")
|
||||||
}
|
}
|
||||||
|
|
||||||
func SyncChannelCache(frequency int) {
|
func SyncChannelCache(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
common.SysLog("syncing channels from database")
|
logger.SysLog("syncing channels from database")
|
||||||
InitChannelCache()
|
InitChannelCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
|
||||||
if !common.MemoryCacheEnabled {
|
if !config.MemoryCacheEnabled {
|
||||||
return GetRandomSatisfiedChannel(group, model)
|
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
|
||||||
}
|
}
|
||||||
channelSyncLock.RLock()
|
channelSyncLock.RLock()
|
||||||
defer channelSyncLock.RUnlock()
|
defer channelSyncLock.RUnlock()
|
||||||
@ -211,5 +246,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
idx := rand.Intn(endIdx)
|
idx := rand.Intn(endIdx)
|
||||||
|
if ignoreFirstPriority {
|
||||||
|
if endIdx < len(channels) { // which means there are more than one priority
|
||||||
|
idx = random.RandRange(endIdx, len(channels))
|
||||||
|
}
|
||||||
|
}
|
||||||
return channels[idx], nil
|
return channels[idx], nil
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,26 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChannelStatusUnknown = 0
|
||||||
|
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||||
|
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||||
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Type int `json:"type" gorm:"default:0"`
|
Type int `json:"type" gorm:"default:0"`
|
||||||
Key string `json:"key" gorm:"not null;index"`
|
Key string `json:"key" gorm:"type:text"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Weight *uint `json:"weight" gorm:"default:0"`
|
Weight *uint `json:"weight" gorm:"default:0"`
|
||||||
@ -16,7 +28,7 @@ type Channel struct {
|
|||||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||||
ResponseTime int `json:"response_time"` // in milliseconds
|
ResponseTime int `json:"response_time"` // in milliseconds
|
||||||
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
||||||
Other string `json:"other"`
|
Other *string `json:"other"` // DEPRECATED: please save config to field Config
|
||||||
Balance float64 `json:"balance"` // in USD
|
Balance float64 `json:"balance"` // in USD
|
||||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||||
Models string `json:"models"`
|
Models string `json:"models"`
|
||||||
@ -24,25 +36,37 @@ type Channel struct {
|
|||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
|
Config string `json:"config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
type ChannelConfig struct {
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
SK string `json:"sk,omitempty"`
|
||||||
|
AK string `json:"ak,omitempty"`
|
||||||
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
APIVersion string `json:"api_version,omitempty"`
|
||||||
|
LibraryID string `json:"library_id,omitempty"`
|
||||||
|
Plugin string `json:"plugin,omitempty"`
|
||||||
|
VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
|
||||||
|
VertexAIADC string `json:"vertex_ai_adc,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
var err error
|
var err error
|
||||||
if selectAll {
|
switch scope {
|
||||||
|
case "all":
|
||||||
err = DB.Order("id desc").Find(&channels).Error
|
err = DB.Order("id desc").Find(&channels).Error
|
||||||
} else {
|
case "disabled":
|
||||||
|
err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
|
||||||
|
default:
|
||||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||||
}
|
}
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||||
keyCol := "`key`"
|
err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error
|
||||||
if common.UsingPostgreSQL {
|
|
||||||
keyCol = `"key"`
|
|
||||||
}
|
|
||||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
|
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,11 +110,17 @@ func (channel *Channel) GetBaseURL() string {
|
|||||||
return *channel.BaseURL
|
return *channel.BaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetModelMapping() string {
|
func (channel *Channel) GetModelMapping() map[string]string {
|
||||||
if channel.ModelMapping == nil {
|
if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
|
||||||
return ""
|
return nil
|
||||||
}
|
}
|
||||||
return *channel.ModelMapping
|
modelMapping := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return modelMapping
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) Insert() error {
|
func (channel *Channel) Insert() error {
|
||||||
@ -116,21 +146,21 @@ func (channel *Channel) Update() error {
|
|||||||
|
|
||||||
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||||||
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
||||||
TestTime: common.GetTimestamp(),
|
TestTime: helper.GetTimestamp(),
|
||||||
ResponseTime: int(responseTime),
|
ResponseTime: int(responseTime),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update response time: " + err.Error())
|
logger.SysError("failed to update response time: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) UpdateBalance(balance float64) {
|
func (channel *Channel) UpdateBalance(balance float64) {
|
||||||
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
||||||
BalanceUpdatedTime: common.GetTimestamp(),
|
BalanceUpdatedTime: helper.GetTimestamp(),
|
||||||
Balance: balance,
|
Balance: balance,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update balance: " + err.Error())
|
logger.SysError("failed to update balance: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,29 +174,41 @@ func (channel *Channel) Delete() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelStatusById(id int, status int) {
|
func (channel *Channel) LoadConfig() (ChannelConfig, error) {
|
||||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
var cfg ChannelConfig
|
||||||
|
if channel.Config == "" {
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(channel.Config), &cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update ability status: " + err.Error())
|
return cfg, err
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateChannelStatusById(id int, status int) {
|
||||||
|
err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("failed to update ability status: " + err.Error())
|
||||||
}
|
}
|
||||||
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel status: " + err.Error())
|
logger.SysError("failed to update channel status: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelUsedQuota(id int, quota int) {
|
func UpdateChannelUsedQuota(id int, quota int64) {
|
||||||
if common.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
updateChannelUsedQuota(id, quota)
|
updateChannelUsedQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateChannelUsedQuota(id int, quota int) {
|
func updateChannelUsedQuota(id int, quota int64) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
logger.SysError("failed to update channel used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,6 +218,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DeleteDisabledChannel() (int64, error) {
|
func DeleteDisabledChannel() (int64, error) {
|
||||||
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||||||
return result.RowsAffected, result.Error
|
return result.RowsAffected, result.Error
|
||||||
}
|
}
|
||||||
|
72
model/log.go
72
model/log.go
@ -3,8 +3,11 @@ package model
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,52 +35,67 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func RecordLog(userId int, logType int, content string) {
|
func RecordLog(userId int, logType int, content string) {
|
||||||
if logType == LogTypeConsume && !common.LogConsumeEnabled {
|
if logType == LogTypeConsume && !config.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log := &Log{
|
log := &Log{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Username: GetUsernameById(userId),
|
Username: GetUsernameById(userId),
|
||||||
CreatedAt: common.GetTimestamp(),
|
CreatedAt: helper.GetTimestamp(),
|
||||||
Type: logType,
|
Type: logType,
|
||||||
Content: content,
|
Content: content,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to record log: " + err.Error())
|
logger.SysError("failed to record log: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
func RecordTopupLog(userId int, content string, quota int) {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
log := &Log{
|
||||||
if !common.LogConsumeEnabled {
|
UserId: userId,
|
||||||
|
Username: GetUsernameById(userId),
|
||||||
|
CreatedAt: helper.GetTimestamp(),
|
||||||
|
Type: LogTypeTopup,
|
||||||
|
Content: content,
|
||||||
|
Quota: quota,
|
||||||
|
}
|
||||||
|
err := LOG_DB.Create(log).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("failed to record log: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
|
||||||
|
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
|
if !config.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log := &Log{
|
log := &Log{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Username: GetUsernameById(userId),
|
Username: GetUsernameById(userId),
|
||||||
CreatedAt: common.GetTimestamp(),
|
CreatedAt: helper.GetTimestamp(),
|
||||||
Type: LogTypeConsume,
|
Type: LogTypeConsume,
|
||||||
Content: content,
|
Content: content,
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
Quota: quota,
|
Quota: int(quota),
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
logger.Error(ctx, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB
|
tx = LOG_DB
|
||||||
} else {
|
} else {
|
||||||
tx = DB.Where("type = ?", logType)
|
tx = LOG_DB.Where("type = ?", logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
@ -104,9 +122,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB.Where("user_id = ?", userId)
|
tx = LOG_DB.Where("user_id = ?", userId)
|
||||||
} else {
|
} else {
|
||||||
tx = DB.Where("user_id = ? and type = ?", userId, logType)
|
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
@ -125,17 +143,21 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||||
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||||
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
|
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
|
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
|
||||||
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
|
ifnull := "ifnull"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
ifnull = "COALESCE"
|
||||||
|
}
|
||||||
|
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
@ -159,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
||||||
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
|
ifnull := "ifnull"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
ifnull = "COALESCE"
|
||||||
|
}
|
||||||
|
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
@ -180,7 +206,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||||
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||||
return result.RowsAffected, result.Error
|
return result.RowsAffected, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,7 +230,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis
|
|||||||
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DB.Raw(`
|
err = LOG_DB.Raw(`
|
||||||
SELECT `+groupSelect+`,
|
SELECT `+groupSelect+`,
|
||||||
model_name, count(1) as request_count,
|
model_name, count(1) as request_count,
|
||||||
sum(quota) as quota,
|
sum(quota) as quota,
|
||||||
|
245
model/main.go
245
model/main.go
@ -1,48 +1,87 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
|
var LOG_DB *gorm.DB
|
||||||
|
|
||||||
func createRootAccountIfNeed() error {
|
func CreateRootAccountIfNeed() error {
|
||||||
var user User
|
var user User
|
||||||
//if user.Status != common.UserStatusEnabled {
|
//if user.Status != util.UserStatusEnabled {
|
||||||
if err := DB.First(&user).Error; err != nil {
|
if err := DB.First(&user).Error; err != nil {
|
||||||
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456")
|
||||||
hashedPassword, err := common.Password2Hash("123456")
|
hashedPassword, err := common.Password2Hash("123456")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
accessToken := random.GetUUID()
|
||||||
|
if config.InitialRootAccessToken != "" {
|
||||||
|
accessToken = config.InitialRootAccessToken
|
||||||
|
}
|
||||||
rootUser := User{
|
rootUser := User{
|
||||||
Username: "root",
|
Username: "root",
|
||||||
Password: hashedPassword,
|
Password: hashedPassword,
|
||||||
Role: common.RoleRootUser,
|
Role: RoleRootUser,
|
||||||
Status: common.UserStatusEnabled,
|
Status: UserStatusEnabled,
|
||||||
DisplayName: "Root User",
|
DisplayName: "Root User",
|
||||||
AccessToken: common.GetUUID(),
|
AccessToken: accessToken,
|
||||||
Quota: 100000000,
|
Quota: 500000000000000,
|
||||||
}
|
}
|
||||||
DB.Create(&rootUser)
|
DB.Create(&rootUser)
|
||||||
|
if config.InitialRootToken != "" {
|
||||||
|
logger.SysLog("creating initial root token as requested")
|
||||||
|
token := Token{
|
||||||
|
Id: 1,
|
||||||
|
UserId: rootUser.Id,
|
||||||
|
Key: config.InitialRootToken,
|
||||||
|
Status: TokenStatusEnabled,
|
||||||
|
Name: "Initial Root Token",
|
||||||
|
CreatedTime: helper.GetTimestamp(),
|
||||||
|
AccessedTime: helper.GetTimestamp(),
|
||||||
|
ExpiredTime: -1,
|
||||||
|
RemainQuota: 500000000000000,
|
||||||
|
UnlimitedQuota: true,
|
||||||
|
}
|
||||||
|
DB.Create(&token)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func chooseDB() (*gorm.DB, error) {
|
func chooseDB(envName string) (*gorm.DB, error) {
|
||||||
if os.Getenv("SQL_DSN") != "" {
|
dsn := os.Getenv(envName)
|
||||||
dsn := os.Getenv("SQL_DSN")
|
|
||||||
if strings.HasPrefix(dsn, "postgres://") {
|
switch {
|
||||||
|
case strings.HasPrefix(dsn, "postgres://"):
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
common.SysLog("using PostgreSQL as database")
|
return openPostgreSQL(dsn)
|
||||||
|
case dsn != "":
|
||||||
|
// Use MySQL
|
||||||
|
return openMySQL(dsn)
|
||||||
|
default:
|
||||||
|
// Use SQLite
|
||||||
|
return openSQLite()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func openPostgreSQL(dsn string) (*gorm.DB, error) {
|
||||||
|
logger.SysLog("using PostgreSQL as database")
|
||||||
common.UsingPostgreSQL = true
|
common.UsingPostgreSQL = true
|
||||||
return gorm.Open(postgres.New(postgres.Config{
|
return gorm.Open(postgres.New(postgres.Config{
|
||||||
DSN: dsn,
|
DSN: dsn,
|
||||||
@ -51,82 +90,148 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Use MySQL
|
|
||||||
common.SysLog("using MySQL as database")
|
func openMySQL(dsn string) (*gorm.DB, error) {
|
||||||
|
logger.SysLog("using MySQL as database")
|
||||||
|
common.UsingMySQL = true
|
||||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Use SQLite
|
|
||||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
func openSQLite() (*gorm.DB, error) {
|
||||||
|
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
||||||
common.UsingSQLite = true
|
common.UsingSQLite = true
|
||||||
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
|
dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
|
||||||
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
|
return gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitDB() (err error) {
|
func InitDB() {
|
||||||
db, err := chooseDB()
|
var err error
|
||||||
if err == nil {
|
DB, err = chooseDB("SQL_DSN")
|
||||||
if common.DebugEnabled {
|
|
||||||
db = db.Debug()
|
|
||||||
}
|
|
||||||
DB = db
|
|
||||||
sqlDB, err := DB.DB()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.FatalLog("failed to initialize database: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB := setDBConns(DB)
|
||||||
|
|
||||||
|
if !config.IsMasterNode {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.UsingMySQL {
|
||||||
|
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.SysLog("database migration started")
|
||||||
|
if err = migrateDB(); err != nil {
|
||||||
|
logger.FatalLog("failed to migrate database: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.SysLog("database migrated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateDB() error {
|
||||||
|
var err error
|
||||||
|
if err = DB.AutoMigrate(&Channel{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&Token{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&User{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&Option{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&Redemption{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&Ability{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&Log{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = DB.AutoMigrate(&Channel{}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
|
||||||
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
|
||||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
|
||||||
|
|
||||||
if !common.IsMasterNode {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
common.SysLog("database migration started")
|
|
||||||
err = db.AutoMigrate(&Channel{})
|
func InitLogDB() {
|
||||||
if err != nil {
|
if os.Getenv("LOG_SQL_DSN") == "" {
|
||||||
return err
|
LOG_DB = DB
|
||||||
}
|
return
|
||||||
err = db.AutoMigrate(&Token{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = db.AutoMigrate(&User{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = db.AutoMigrate(&Option{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = db.AutoMigrate(&Redemption{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = db.AutoMigrate(&Ability{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = db.AutoMigrate(&Log{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
common.SysLog("database migrated")
|
|
||||||
err = createRootAccountIfNeed()
|
|
||||||
return err
|
|
||||||
} else {
|
|
||||||
common.FatalLog(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CloseDB() error {
|
logger.SysLog("using secondary database for table logs")
|
||||||
sqlDB, err := DB.DB()
|
var err error
|
||||||
|
LOG_DB, err = chooseDB("LOG_SQL_DSN")
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("failed to initialize secondary database: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setDBConns(LOG_DB)
|
||||||
|
|
||||||
|
if !config.IsMasterNode {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.SysLog("secondary database migration started")
|
||||||
|
err = migrateLOGDB()
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("failed to migrate secondary database: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.SysLog("secondary database migrated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateLOGDB() error {
|
||||||
|
var err error
|
||||||
|
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDBConns(db *gorm.DB) *sql.DB {
|
||||||
|
if config.DebugSQLEnabled {
|
||||||
|
db = db.Debug()
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("failed to connect database: " + err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
|
||||||
|
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
|
||||||
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
|
||||||
|
return sqlDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeDB(db *gorm.DB) error {
|
||||||
|
sqlDB, err := db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = sqlDB.Close()
|
err = sqlDB.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CloseDB() error {
|
||||||
|
if LOG_DB != DB {
|
||||||
|
err := closeDB(LOG_DB)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return closeDB(DB)
|
||||||
|
}
|
||||||
|
250
model/option.go
250
model/option.go
@ -1,7 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -20,69 +22,72 @@ func AllOption() ([]*Option, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func InitOptionMap() {
|
func InitOptionMap() {
|
||||||
common.OptionMapRWMutex.Lock()
|
config.OptionMapRWMutex.Lock()
|
||||||
common.OptionMap = make(map[string]string)
|
config.OptionMap = make(map[string]string)
|
||||||
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
|
config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
|
||||||
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
|
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
|
||||||
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
|
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
|
||||||
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
|
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
|
||||||
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
|
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
|
||||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
|
||||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
|
||||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
|
||||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
|
||||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
|
||||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
|
||||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
|
||||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
|
||||||
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
|
config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
|
||||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
|
||||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
|
||||||
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
|
||||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
config.OptionMap["SMTPServer"] = ""
|
||||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
config.OptionMap["SMTPFrom"] = ""
|
||||||
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
|
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
|
||||||
common.OptionMap["SMTPServer"] = ""
|
config.OptionMap["SMTPAccount"] = ""
|
||||||
common.OptionMap["SMTPFrom"] = ""
|
config.OptionMap["SMTPToken"] = ""
|
||||||
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
|
config.OptionMap["Notice"] = ""
|
||||||
common.OptionMap["SMTPAccount"] = ""
|
config.OptionMap["About"] = ""
|
||||||
common.OptionMap["SMTPToken"] = ""
|
config.OptionMap["HomePageContent"] = ""
|
||||||
common.OptionMap["Notice"] = ""
|
config.OptionMap["Footer"] = config.Footer
|
||||||
common.OptionMap["About"] = ""
|
config.OptionMap["SystemName"] = config.SystemName
|
||||||
common.OptionMap["HomePageContent"] = ""
|
config.OptionMap["Logo"] = config.Logo
|
||||||
common.OptionMap["Footer"] = common.Footer
|
config.OptionMap["ServerAddress"] = ""
|
||||||
common.OptionMap["SystemName"] = common.SystemName
|
config.OptionMap["GitHubClientId"] = ""
|
||||||
common.OptionMap["Logo"] = common.Logo
|
config.OptionMap["GitHubClientSecret"] = ""
|
||||||
common.OptionMap["ServerAddress"] = ""
|
config.OptionMap["WeChatServerAddress"] = ""
|
||||||
common.OptionMap["GitHubClientId"] = ""
|
config.OptionMap["WeChatServerToken"] = ""
|
||||||
common.OptionMap["GitHubClientSecret"] = ""
|
config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
|
||||||
common.OptionMap["WeChatServerAddress"] = ""
|
config.OptionMap["MessagePusherAddress"] = ""
|
||||||
common.OptionMap["WeChatServerToken"] = ""
|
config.OptionMap["MessagePusherToken"] = ""
|
||||||
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
|
config.OptionMap["TurnstileSiteKey"] = ""
|
||||||
common.OptionMap["TurnstileSiteKey"] = ""
|
config.OptionMap["TurnstileSecretKey"] = ""
|
||||||
common.OptionMap["TurnstileSecretKey"] = ""
|
config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
|
||||||
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
|
config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
|
||||||
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
|
||||||
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
|
||||||
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
|
||||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
|
||||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
|
||||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
|
||||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
config.OptionMap["TopUpLink"] = config.TopUpLink
|
||||||
common.OptionMap["ChatLink"] = common.ChatLink
|
config.OptionMap["ChatLink"] = config.ChatLink
|
||||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
|
||||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
|
||||||
common.OptionMap["Theme"] = common.Theme
|
config.OptionMap["Theme"] = config.Theme
|
||||||
common.OptionMapRWMutex.Unlock()
|
config.OptionMapRWMutex.Unlock()
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadOptionsFromDatabase() {
|
func loadOptionsFromDatabase() {
|
||||||
options, _ := AllOption()
|
options, _ := AllOption()
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
|
if option.Key == "ModelRatio" {
|
||||||
|
option.Value = billingratio.AddNewMissingRatio(option.Value)
|
||||||
|
}
|
||||||
err := updateOptionMap(option.Key, option.Value)
|
err := updateOptionMap(option.Key, option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update option map: " + err.Error())
|
logger.SysError("failed to update option map: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -90,7 +95,7 @@ func loadOptionsFromDatabase() {
|
|||||||
func SyncOptions(frequency int) {
|
func SyncOptions(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
common.SysLog("syncing options from database")
|
logger.SysLog("syncing options from database")
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -112,117 +117,128 @@ func UpdateOption(key string, value string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateOptionMap(key string, value string) (err error) {
|
func updateOptionMap(key string, value string) (err error) {
|
||||||
common.OptionMapRWMutex.Lock()
|
config.OptionMapRWMutex.Lock()
|
||||||
defer common.OptionMapRWMutex.Unlock()
|
defer config.OptionMapRWMutex.Unlock()
|
||||||
common.OptionMap[key] = value
|
config.OptionMap[key] = value
|
||||||
if strings.HasSuffix(key, "Permission") {
|
|
||||||
intValue, _ := strconv.Atoi(value)
|
|
||||||
switch key {
|
|
||||||
case "FileUploadPermission":
|
|
||||||
common.FileUploadPermission = intValue
|
|
||||||
case "FileDownloadPermission":
|
|
||||||
common.FileDownloadPermission = intValue
|
|
||||||
case "ImageUploadPermission":
|
|
||||||
common.ImageUploadPermission = intValue
|
|
||||||
case "ImageDownloadPermission":
|
|
||||||
common.ImageDownloadPermission = intValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(key, "Enabled") {
|
if strings.HasSuffix(key, "Enabled") {
|
||||||
boolValue := value == "true"
|
boolValue := value == "true"
|
||||||
switch key {
|
switch key {
|
||||||
case "PasswordRegisterEnabled":
|
case "PasswordRegisterEnabled":
|
||||||
common.PasswordRegisterEnabled = boolValue
|
config.PasswordRegisterEnabled = boolValue
|
||||||
case "PasswordLoginEnabled":
|
case "PasswordLoginEnabled":
|
||||||
common.PasswordLoginEnabled = boolValue
|
config.PasswordLoginEnabled = boolValue
|
||||||
case "EmailVerificationEnabled":
|
case "EmailVerificationEnabled":
|
||||||
common.EmailVerificationEnabled = boolValue
|
config.EmailVerificationEnabled = boolValue
|
||||||
case "GitHubOAuthEnabled":
|
case "GitHubOAuthEnabled":
|
||||||
common.GitHubOAuthEnabled = boolValue
|
config.GitHubOAuthEnabled = boolValue
|
||||||
|
case "OidcEnabled":
|
||||||
|
config.OidcEnabled = boolValue
|
||||||
case "WeChatAuthEnabled":
|
case "WeChatAuthEnabled":
|
||||||
common.WeChatAuthEnabled = boolValue
|
config.WeChatAuthEnabled = boolValue
|
||||||
case "TurnstileCheckEnabled":
|
case "TurnstileCheckEnabled":
|
||||||
common.TurnstileCheckEnabled = boolValue
|
config.TurnstileCheckEnabled = boolValue
|
||||||
case "RegisterEnabled":
|
case "RegisterEnabled":
|
||||||
common.RegisterEnabled = boolValue
|
config.RegisterEnabled = boolValue
|
||||||
case "EmailDomainRestrictionEnabled":
|
case "EmailDomainRestrictionEnabled":
|
||||||
common.EmailDomainRestrictionEnabled = boolValue
|
config.EmailDomainRestrictionEnabled = boolValue
|
||||||
case "AutomaticDisableChannelEnabled":
|
case "AutomaticDisableChannelEnabled":
|
||||||
common.AutomaticDisableChannelEnabled = boolValue
|
config.AutomaticDisableChannelEnabled = boolValue
|
||||||
case "AutomaticEnableChannelEnabled":
|
case "AutomaticEnableChannelEnabled":
|
||||||
common.AutomaticEnableChannelEnabled = boolValue
|
config.AutomaticEnableChannelEnabled = boolValue
|
||||||
case "ApproximateTokenEnabled":
|
case "ApproximateTokenEnabled":
|
||||||
common.ApproximateTokenEnabled = boolValue
|
config.ApproximateTokenEnabled = boolValue
|
||||||
case "LogConsumeEnabled":
|
case "LogConsumeEnabled":
|
||||||
common.LogConsumeEnabled = boolValue
|
config.LogConsumeEnabled = boolValue
|
||||||
case "DisplayInCurrencyEnabled":
|
case "DisplayInCurrencyEnabled":
|
||||||
common.DisplayInCurrencyEnabled = boolValue
|
config.DisplayInCurrencyEnabled = boolValue
|
||||||
case "DisplayTokenStatEnabled":
|
case "DisplayTokenStatEnabled":
|
||||||
common.DisplayTokenStatEnabled = boolValue
|
config.DisplayTokenStatEnabled = boolValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch key {
|
switch key {
|
||||||
case "EmailDomainWhitelist":
|
case "EmailDomainWhitelist":
|
||||||
common.EmailDomainWhitelist = strings.Split(value, ",")
|
config.EmailDomainWhitelist = strings.Split(value, ",")
|
||||||
case "SMTPServer":
|
case "SMTPServer":
|
||||||
common.SMTPServer = value
|
config.SMTPServer = value
|
||||||
case "SMTPPort":
|
case "SMTPPort":
|
||||||
intValue, _ := strconv.Atoi(value)
|
intValue, _ := strconv.Atoi(value)
|
||||||
common.SMTPPort = intValue
|
config.SMTPPort = intValue
|
||||||
case "SMTPAccount":
|
case "SMTPAccount":
|
||||||
common.SMTPAccount = value
|
config.SMTPAccount = value
|
||||||
case "SMTPFrom":
|
case "SMTPFrom":
|
||||||
common.SMTPFrom = value
|
config.SMTPFrom = value
|
||||||
case "SMTPToken":
|
case "SMTPToken":
|
||||||
common.SMTPToken = value
|
config.SMTPToken = value
|
||||||
case "ServerAddress":
|
case "ServerAddress":
|
||||||
common.ServerAddress = value
|
config.ServerAddress = value
|
||||||
case "GitHubClientId":
|
case "GitHubClientId":
|
||||||
common.GitHubClientId = value
|
config.GitHubClientId = value
|
||||||
case "GitHubClientSecret":
|
case "GitHubClientSecret":
|
||||||
common.GitHubClientSecret = value
|
config.GitHubClientSecret = value
|
||||||
|
case "LarkClientId":
|
||||||
|
config.LarkClientId = value
|
||||||
|
case "LarkClientSecret":
|
||||||
|
config.LarkClientSecret = value
|
||||||
|
case "OidcClientId":
|
||||||
|
config.OidcClientId = value
|
||||||
|
case "OidcClientSecret":
|
||||||
|
config.OidcClientSecret = value
|
||||||
|
case "OidcWellKnown":
|
||||||
|
config.OidcWellKnown = value
|
||||||
|
case "OidcAuthorizationEndpoint":
|
||||||
|
config.OidcAuthorizationEndpoint = value
|
||||||
|
case "OidcTokenEndpoint":
|
||||||
|
config.OidcTokenEndpoint = value
|
||||||
|
case "OidcUserinfoEndpoint":
|
||||||
|
config.OidcUserinfoEndpoint = value
|
||||||
case "Footer":
|
case "Footer":
|
||||||
common.Footer = value
|
config.Footer = value
|
||||||
case "SystemName":
|
case "SystemName":
|
||||||
common.SystemName = value
|
config.SystemName = value
|
||||||
case "Logo":
|
case "Logo":
|
||||||
common.Logo = value
|
config.Logo = value
|
||||||
case "WeChatServerAddress":
|
case "WeChatServerAddress":
|
||||||
common.WeChatServerAddress = value
|
config.WeChatServerAddress = value
|
||||||
case "WeChatServerToken":
|
case "WeChatServerToken":
|
||||||
common.WeChatServerToken = value
|
config.WeChatServerToken = value
|
||||||
case "WeChatAccountQRCodeImageURL":
|
case "WeChatAccountQRCodeImageURL":
|
||||||
common.WeChatAccountQRCodeImageURL = value
|
config.WeChatAccountQRCodeImageURL = value
|
||||||
|
case "MessagePusherAddress":
|
||||||
|
config.MessagePusherAddress = value
|
||||||
|
case "MessagePusherToken":
|
||||||
|
config.MessagePusherToken = value
|
||||||
case "TurnstileSiteKey":
|
case "TurnstileSiteKey":
|
||||||
common.TurnstileSiteKey = value
|
config.TurnstileSiteKey = value
|
||||||
case "TurnstileSecretKey":
|
case "TurnstileSecretKey":
|
||||||
common.TurnstileSecretKey = value
|
config.TurnstileSecretKey = value
|
||||||
case "QuotaForNewUser":
|
case "QuotaForNewUser":
|
||||||
common.QuotaForNewUser, _ = strconv.Atoi(value)
|
config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "QuotaForInviter":
|
case "QuotaForInviter":
|
||||||
common.QuotaForInviter, _ = strconv.Atoi(value)
|
config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "QuotaForInvitee":
|
case "QuotaForInvitee":
|
||||||
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "QuotaRemindThreshold":
|
case "QuotaRemindThreshold":
|
||||||
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "PreConsumedQuota":
|
case "PreConsumedQuota":
|
||||||
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "RetryTimes":
|
case "RetryTimes":
|
||||||
common.RetryTimes, _ = strconv.Atoi(value)
|
config.RetryTimes, _ = strconv.Atoi(value)
|
||||||
case "ModelRatio":
|
case "ModelRatio":
|
||||||
err = common.UpdateModelRatioByJSONString(value)
|
err = billingratio.UpdateModelRatioByJSONString(value)
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = common.UpdateGroupRatioByJSONString(value)
|
err = billingratio.UpdateGroupRatioByJSONString(value)
|
||||||
|
case "CompletionRatio":
|
||||||
|
err = billingratio.UpdateCompletionRatioByJSONString(value)
|
||||||
case "TopUpLink":
|
case "TopUpLink":
|
||||||
common.TopUpLink = value
|
config.TopUpLink = value
|
||||||
case "ChatLink":
|
case "ChatLink":
|
||||||
common.ChatLink = value
|
config.ChatLink = value
|
||||||
case "ChannelDisableThreshold":
|
case "ChannelDisableThreshold":
|
||||||
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
|
config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
|
||||||
case "QuotaPerUnit":
|
case "QuotaPerUnit":
|
||||||
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
||||||
case "Theme":
|
case "Theme":
|
||||||
common.Theme = value
|
config.Theme = value
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,15 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||||
|
RedemptionCodeStatusDisabled = 2 // also don't use 0
|
||||||
|
RedemptionCodeStatusUsed = 3 // also don't use 0
|
||||||
)
|
)
|
||||||
|
|
||||||
type Redemption struct {
|
type Redemption struct {
|
||||||
@ -13,7 +20,7 @@ type Redemption struct {
|
|||||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Quota int `json:"quota" gorm:"default:100"`
|
Quota int64 `json:"quota" gorm:"bigint;default:100"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||||
@ -41,7 +48,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
|
|||||||
return &redemption, err
|
return &redemption, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Redeem(key string, userId int) (quota int, err error) {
|
func Redeem(key string, userId int) (quota int64, err error) {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return 0, errors.New("未提供兑换码")
|
return 0, errors.New("未提供兑换码")
|
||||||
}
|
}
|
||||||
@ -60,15 +67,15 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("无效的兑换码")
|
return errors.New("无效的兑换码")
|
||||||
}
|
}
|
||||||
if redemption.Status != common.RedemptionCodeStatusEnabled {
|
if redemption.Status != RedemptionCodeStatusEnabled {
|
||||||
return errors.New("该兑换码已被使用")
|
return errors.New("该兑换码已被使用")
|
||||||
}
|
}
|
||||||
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
redemption.RedeemedTime = common.GetTimestamp()
|
redemption.RedeemedTime = helper.GetTimestamp()
|
||||||
redemption.Status = common.RedemptionCodeStatusUsed
|
redemption.Status = RedemptionCodeStatusUsed
|
||||||
err = tx.Save(redemption).Error
|
err = tx.Save(redemption).Error
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
116
model/token.go
116
model/token.go
@ -3,8 +3,19 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/message"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||||
|
TokenStatusDisabled = 2 // also don't use 0
|
||||||
|
TokenStatusExpired = 3
|
||||||
|
TokenStatusExhausted = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
||||||
@ -16,15 +27,28 @@ type Token struct {
|
|||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
|
||||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
|
||||||
|
Models *string `json:"models" gorm:"type:text"` // allowed models
|
||||||
|
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
|
||||||
var tokens []*Token
|
var tokens []*Token
|
||||||
var err error
|
var err error
|
||||||
err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
|
query := DB.Where("user_id = ?", userId)
|
||||||
|
|
||||||
|
switch order {
|
||||||
|
case "remain_quota":
|
||||||
|
query = query.Order("unlimited_quota desc, remain_quota desc")
|
||||||
|
case "used_quota":
|
||||||
|
query = query.Order("used_quota desc")
|
||||||
|
default:
|
||||||
|
query = query.Order("id desc")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
|
||||||
return tokens, err
|
return tokens, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,26 +63,26 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
}
|
}
|
||||||
token, err = CacheGetTokenByKey(key)
|
token, err = CacheGetTokenByKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("CacheGetTokenByKey failed: " + err.Error())
|
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("无效的令牌")
|
return nil, errors.New("无效的令牌")
|
||||||
}
|
}
|
||||||
return nil, errors.New("令牌验证失败")
|
return nil, errors.New("令牌验证失败")
|
||||||
}
|
}
|
||||||
if token.Status == common.TokenStatusExhausted {
|
if token.Status == TokenStatusExhausted {
|
||||||
return nil, errors.New("该令牌额度已用尽")
|
return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id)
|
||||||
} else if token.Status == common.TokenStatusExpired {
|
} else if token.Status == TokenStatusExpired {
|
||||||
return nil, errors.New("该令牌已过期")
|
return nil, errors.New("该令牌已过期")
|
||||||
}
|
}
|
||||||
if token.Status != common.TokenStatusEnabled {
|
if token.Status != TokenStatusEnabled {
|
||||||
return nil, errors.New("该令牌状态不可用")
|
return nil, errors.New("该令牌状态不可用")
|
||||||
}
|
}
|
||||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
token.Status = common.TokenStatusExpired
|
token.Status = TokenStatusExpired
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token status" + err.Error())
|
logger.SysError("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌已过期")
|
return nil, errors.New("该令牌已过期")
|
||||||
@ -66,10 +90,10 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
// in this case, we can make sure the token is exhausted
|
// in this case, we can make sure the token is exhausted
|
||||||
token.Status = common.TokenStatusExhausted
|
token.Status = TokenStatusExhausted
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token status" + err.Error())
|
logger.SysError("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌额度已用尽")
|
return nil, errors.New("该令牌额度已用尽")
|
||||||
@ -97,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
|
|||||||
return &token, err
|
return &token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (token *Token) Insert() error {
|
func (t *Token) Insert() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(token).Error
|
err = DB.Create(t).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||||
func (token *Token) Update() error {
|
func (t *Token) Update() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
|
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (token *Token) SelectUpdate() error {
|
func (t *Token) SelectUpdate() error {
|
||||||
// This can update zero values
|
// This can update zero values
|
||||||
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (token *Token) Delete() error {
|
func (t *Token) Delete() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Delete(token).Error
|
err = DB.Delete(t).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Token) GetModels() string {
|
||||||
|
if t == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if t.Models == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *t.Models
|
||||||
|
}
|
||||||
|
|
||||||
func DeleteTokenById(id int, userId int) (err error) {
|
func DeleteTokenById(id int, userId int) (err error) {
|
||||||
// Why we need userId here? In case user want to delete other's token.
|
// Why we need userId here? In case user want to delete other's token.
|
||||||
if id == 0 || userId == 0 {
|
if id == 0 || userId == 0 {
|
||||||
@ -134,51 +168,51 @@ func DeleteTokenById(id int, userId int) (err error) {
|
|||||||
return token.Delete()
|
return token.Delete()
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseTokenQuota(id int, quota int) (err error) {
|
func IncreaseTokenQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
if common.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return increaseTokenQuota(id, quota)
|
return increaseTokenQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func increaseTokenQuota(id int, quota int) (err error) {
|
func increaseTokenQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||||
"accessed_time": common.GetTimestamp(),
|
"accessed_time": helper.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecreaseTokenQuota(id int, quota int) (err error) {
|
func DecreaseTokenQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
if common.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return decreaseTokenQuota(id, quota)
|
return decreaseTokenQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decreaseTokenQuota(id int, quota int) (err error) {
|
func decreaseTokenQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
"accessed_time": common.GetTimestamp(),
|
"accessed_time": helper.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@ -196,24 +230,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
|||||||
if userQuota < quota {
|
if userQuota < quota {
|
||||||
return errors.New("用户额度不足")
|
return errors.New("用户额度不足")
|
||||||
}
|
}
|
||||||
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
|
quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
|
||||||
noMoreQuota := userQuota-quota <= 0
|
noMoreQuota := userQuota-quota <= 0
|
||||||
if quotaTooLow || noMoreQuota {
|
if quotaTooLow || noMoreQuota {
|
||||||
go func() {
|
go func() {
|
||||||
email, err := GetUserEmail(token.UserId)
|
email, err := GetUserEmail(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to fetch user email: " + err.Error())
|
logger.SysError("failed to fetch user email: " + err.Error())
|
||||||
}
|
}
|
||||||
prompt := "您的额度即将用尽"
|
prompt := "您的额度即将用尽"
|
||||||
if noMoreQuota {
|
if noMoreQuota {
|
||||||
prompt = "您的额度已用尽"
|
prompt = "您的额度已用尽"
|
||||||
}
|
}
|
||||||
if email != "" {
|
if email != "" {
|
||||||
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
|
topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
|
||||||
err = common.SendEmail(prompt, email,
|
err = message.SendEmail(prompt, email,
|
||||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to send email" + err.Error())
|
logger.SysError("failed to send email" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -228,16 +262,16 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
|
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||||
token, err := GetTokenById(tokenId)
|
token, err := GetTokenById(tokenId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if quota > 0 {
|
if quota > 0 {
|
||||||
err = DecreaseUserQuota(token.UserId, quota)
|
err = DecreaseUserQuota(token.UserId, quota)
|
||||||
} else {
|
} else {
|
||||||
err = IncreaseUserQuota(token.UserId, -quota)
|
err = IncreaseUserQuota(token.UserId, -quota)
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !token.UnlimitedQuota {
|
if !token.UnlimitedQuota {
|
||||||
if quota > 0 {
|
if quota > 0 {
|
||||||
err = DecreaseTokenQuota(tokenId, quota)
|
err = DecreaseTokenQuota(tokenId, quota)
|
||||||
|
157
model/user.go
157
model/user.go
@ -3,11 +3,29 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/blacklist"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RoleGuestUser = 0
|
||||||
|
RoleCommonUser = 1
|
||||||
|
RoleAdminUser = 10
|
||||||
|
RoleRootUser = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||||
|
UserStatusDisabled = 2 // also don't use 0
|
||||||
|
UserStatusDeleted = 3
|
||||||
|
)
|
||||||
|
|
||||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||||
// Otherwise, the sensitive information will be saved on local storage in plain text!
|
// Otherwise, the sensitive information will be saved on local storage in plain text!
|
||||||
type User struct {
|
type User struct {
|
||||||
@ -15,15 +33,17 @@ type User struct {
|
|||||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
|
Role int `json:"role" gorm:"type:int;default:1"` // admin, util
|
||||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||||
|
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
|
||||||
|
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
Quota int64 `json:"quota" gorm:"bigint;default:0"`
|
||||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota
|
||||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||||
@ -36,8 +56,21 @@ func GetMaxUserId() int {
|
|||||||
return user.Id
|
return user.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
|
||||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
|
||||||
|
|
||||||
|
switch order {
|
||||||
|
case "quota":
|
||||||
|
query = query.Order("quota desc")
|
||||||
|
case "used_quota":
|
||||||
|
query = query.Order("used_quota desc")
|
||||||
|
case "request_count":
|
||||||
|
query = query.Order("request_count desc")
|
||||||
|
default:
|
||||||
|
query = query.Order("id desc")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = query.Find(&users).Error
|
||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,26 +122,42 @@ func (user *User) Insert(inviterId int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user.Quota = common.QuotaForNewUser
|
user.Quota = config.QuotaForNewUser
|
||||||
user.AccessToken = common.GetUUID()
|
user.AccessToken = random.GetUUID()
|
||||||
user.AffCode = common.GetRandomString(4)
|
user.AffCode = random.GetRandomString(4)
|
||||||
result := DB.Create(user)
|
result := DB.Create(user)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
if common.QuotaForNewUser > 0 {
|
if config.QuotaForNewUser > 0 {
|
||||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
|
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
|
||||||
}
|
}
|
||||||
if inviterId != 0 {
|
if inviterId != 0 {
|
||||||
if common.QuotaForInvitee > 0 {
|
if config.QuotaForInvitee > 0 {
|
||||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
|
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
|
||||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
|
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
|
||||||
}
|
}
|
||||||
if common.QuotaForInviter > 0 {
|
if config.QuotaForInviter > 0 {
|
||||||
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
|
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
|
||||||
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
|
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// create default token
|
||||||
|
cleanToken := Token{
|
||||||
|
UserId: user.Id,
|
||||||
|
Name: "default",
|
||||||
|
Key: random.GenerateKey(),
|
||||||
|
CreatedTime: helper.GetTimestamp(),
|
||||||
|
AccessedTime: helper.GetTimestamp(),
|
||||||
|
ExpiredTime: -1,
|
||||||
|
RemainQuota: -1,
|
||||||
|
UnlimitedQuota: true,
|
||||||
|
}
|
||||||
|
result.Error = cleanToken.Insert()
|
||||||
|
if result.Error != nil {
|
||||||
|
// do not block
|
||||||
|
logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,6 +169,11 @@ func (user *User) Update(updatePassword bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if user.Status == UserStatusDisabled {
|
||||||
|
blacklist.BanUser(user.Id)
|
||||||
|
} else if user.Status == UserStatusEnabled {
|
||||||
|
blacklist.UnbanUser(user.Id)
|
||||||
|
}
|
||||||
err = DB.Model(user).Updates(user).Error
|
err = DB.Model(user).Updates(user).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -128,7 +182,10 @@ func (user *User) Delete() error {
|
|||||||
if user.Id == 0 {
|
if user.Id == 0 {
|
||||||
return errors.New("id 为空!")
|
return errors.New("id 为空!")
|
||||||
}
|
}
|
||||||
err := DB.Delete(user).Error
|
blacklist.BanUser(user.Id)
|
||||||
|
user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
|
||||||
|
user.Status = UserStatusDeleted
|
||||||
|
err := DB.Model(user).Updates(user).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,7 +208,7 @@ func (user *User) ValidateAndFill() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
okay := common.ValidatePasswordAndHash(password, user.Password)
|
okay := common.ValidatePasswordAndHash(password, user.Password)
|
||||||
if !okay || user.Status != common.UserStatusEnabled {
|
if !okay || user.Status != UserStatusEnabled {
|
||||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -181,6 +238,22 @@ func (user *User) FillUserByGitHubId() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) FillUserByLarkId() error {
|
||||||
|
if user.LarkId == "" {
|
||||||
|
return errors.New("lark id 为空!")
|
||||||
|
}
|
||||||
|
DB.Where(User{LarkId: user.LarkId}).First(user)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) FillUserByOidcId() error {
|
||||||
|
if user.OidcId == "" {
|
||||||
|
return errors.New("oidc id 为空!")
|
||||||
|
}
|
||||||
|
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) FillUserByWeChatId() error {
|
func (user *User) FillUserByWeChatId() error {
|
||||||
if user.WeChatId == "" {
|
if user.WeChatId == "" {
|
||||||
return errors.New("WeChat id 为空!")
|
return errors.New("WeChat id 为空!")
|
||||||
@ -209,6 +282,14 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
|||||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsLarkIdAlreadyTaken(githubId string) bool {
|
||||||
|
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||||
|
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||||
|
}
|
||||||
|
|
||||||
func IsUsernameAlreadyTaken(username string) bool {
|
func IsUsernameAlreadyTaken(username string) bool {
|
||||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||||
}
|
}
|
||||||
@ -232,10 +313,10 @@ func IsAdmin(userId int) bool {
|
|||||||
var user User
|
var user User
|
||||||
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("no such user " + err.Error())
|
logger.SysError("no such user " + err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return user.Role >= common.RoleAdminUser
|
return user.Role >= RoleAdminUser
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsUserEnabled(userId int) (bool, error) {
|
func IsUserEnabled(userId int) (bool, error) {
|
||||||
@ -247,7 +328,7 @@ func IsUserEnabled(userId int) (bool, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
return user.Status == common.UserStatusEnabled, nil
|
return user.Status == UserStatusEnabled, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAccessToken(token string) (user *User) {
|
func ValidateAccessToken(token string) (user *User) {
|
||||||
@ -262,12 +343,12 @@ func ValidateAccessToken(token string) (user *User) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserQuota(id int) (quota int, err error) {
|
func GetUserQuota(id int) (quota int64, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||||
return quota, err
|
return quota, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserUsedQuota(id int) (quota int, err error) {
|
func GetUserUsedQuota(id int) (quota int64, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
|
||||||
return quota, err
|
return quota, err
|
||||||
}
|
}
|
||||||
@ -287,45 +368,45 @@ func GetUserGroup(id int) (group string, err error) {
|
|||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
func IncreaseUserQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
if common.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return increaseUserQuota(id, quota)
|
return increaseUserQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func increaseUserQuota(id int, quota int) (err error) {
|
func increaseUserQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecreaseUserQuota(id int, quota int) (err error) {
|
func DecreaseUserQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
if common.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return decreaseUserQuota(id, quota)
|
return decreaseUserQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decreaseUserQuota(id int, quota int) (err error) {
|
func decreaseUserQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRootUserEmail() (email string) {
|
func GetRootUserEmail() (email string) {
|
||||||
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
|
||||||
return email
|
return email
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
|
||||||
if common.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||||
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
||||||
return
|
return
|
||||||
@ -333,7 +414,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
|||||||
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
@ -341,25 +422,25 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
|||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update user used quota and request count: " + err.Error())
|
logger.SysError("failed to update user used quota and request count: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserUsedQuota(id int, quota int) {
|
func updateUserUsedQuota(id int, quota int64) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update user used quota: " + err.Error())
|
logger.SysError("failed to update user used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserRequestCount(id int, count int) {
|
func updateUserRequestCount(id int, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update user request count: " + err.Error())
|
logger.SysError("failed to update user request count: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -15,12 +16,12 @@ const (
|
|||||||
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
||||||
)
|
)
|
||||||
|
|
||||||
var batchUpdateStores []map[int]int
|
var batchUpdateStores []map[int]int64
|
||||||
var batchUpdateLocks []sync.Mutex
|
var batchUpdateLocks []sync.Mutex
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
|
batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
|
||||||
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -28,13 +29,13 @@ func init() {
|
|||||||
func InitBatchUpdater() {
|
func InitBatchUpdater() {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
|
time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
|
||||||
batchUpdate()
|
batchUpdate()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func addNewRecord(type_ int, id int, value int) {
|
func addNewRecord(type_ int, id int, value int64) {
|
||||||
batchUpdateLocks[type_].Lock()
|
batchUpdateLocks[type_].Lock()
|
||||||
defer batchUpdateLocks[type_].Unlock()
|
defer batchUpdateLocks[type_].Unlock()
|
||||||
if _, ok := batchUpdateStores[type_][id]; !ok {
|
if _, ok := batchUpdateStores[type_][id]; !ok {
|
||||||
@ -45,11 +46,11 @@ func addNewRecord(type_ int, id int, value int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func batchUpdate() {
|
func batchUpdate() {
|
||||||
common.SysLog("batch update started")
|
logger.SysLog("batch update started")
|
||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateLocks[i].Lock()
|
batchUpdateLocks[i].Lock()
|
||||||
store := batchUpdateStores[i]
|
store := batchUpdateStores[i]
|
||||||
batchUpdateStores[i] = make(map[int]int)
|
batchUpdateStores[i] = make(map[int]int64)
|
||||||
batchUpdateLocks[i].Unlock()
|
batchUpdateLocks[i].Unlock()
|
||||||
// TODO: maybe we can combine updates with same key?
|
// TODO: maybe we can combine updates with same key?
|
||||||
for key, value := range store {
|
for key, value := range store {
|
||||||
@ -57,21 +58,21 @@ func batchUpdate() {
|
|||||||
case BatchUpdateTypeUserQuota:
|
case BatchUpdateTypeUserQuota:
|
||||||
err := increaseUserQuota(key, value)
|
err := increaseUserQuota(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to batch update user quota: " + err.Error())
|
logger.SysError("failed to batch update user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeTokenQuota:
|
case BatchUpdateTypeTokenQuota:
|
||||||
err := increaseTokenQuota(key, value)
|
err := increaseTokenQuota(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to batch update token quota: " + err.Error())
|
logger.SysError("failed to batch update token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeUsedQuota:
|
case BatchUpdateTypeUsedQuota:
|
||||||
updateUserUsedQuota(key, value)
|
updateUserUsedQuota(key, value)
|
||||||
case BatchUpdateTypeRequestCount:
|
case BatchUpdateTypeRequestCount:
|
||||||
updateUserRequestCount(key, value)
|
updateUserRequestCount(key, int(value))
|
||||||
case BatchUpdateTypeChannelUsedQuota:
|
case BatchUpdateTypeChannelUsedQuota:
|
||||||
updateChannelUsedQuota(key, value)
|
updateChannelUsedQuota(key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common.SysLog("batch update finished")
|
logger.SysLog("batch update finished")
|
||||||
}
|
}
|
||||||
|
54
monitor/channel.go
Normal file
54
monitor/channel.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package monitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/message"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func notifyRootUser(subject string, content string) {
|
||||||
|
if config.MessagePusherAddress != "" {
|
||||||
|
err := message.SendMessage(subject, content, content)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error()))
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if config.RootUserEmail == "" {
|
||||||
|
config.RootUserEmail = model.GetRootUserEmail()
|
||||||
|
}
|
||||||
|
err := message.SendEmail(subject, config.RootUserEmail, content)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableChannel disable & notify
|
||||||
|
func DisableChannel(channelId int, channelName string, reason string) {
|
||||||
|
model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
|
||||||
|
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
|
||||||
|
subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
|
content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
|
notifyRootUser(subject, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MetricDisableChannel(channelId int, successRate float64) {
|
||||||
|
model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
|
||||||
|
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
|
||||||
|
subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
|
||||||
|
content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
|
||||||
|
channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
|
||||||
|
notifyRootUser(subject, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableChannel enable & notify
|
||||||
|
func EnableChannel(channelId int, channelName string) {
|
||||||
|
model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled)
|
||||||
|
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
|
||||||
|
subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
|
content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
|
notifyRootUser(subject, content)
|
||||||
|
}
|
55
monitor/manage.go
Normal file
55
monitor/manage.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package monitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
|
||||||
|
if !config.AutomaticDisableChannelEnabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusUnauthorized {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch err.Type {
|
||||||
|
case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerMessage := strings.ToLower(err.Message)
|
||||||
|
if strings.Contains(lowerMessage, "your access was terminated") ||
|
||||||
|
strings.Contains(lowerMessage, "violation of our policies") ||
|
||||||
|
strings.Contains(lowerMessage, "your credit balance is too low") ||
|
||||||
|
strings.Contains(lowerMessage, "organization has been disabled") ||
|
||||||
|
strings.Contains(lowerMessage, "credit") ||
|
||||||
|
strings.Contains(lowerMessage, "balance") ||
|
||||||
|
strings.Contains(lowerMessage, "permission denied") ||
|
||||||
|
strings.Contains(lowerMessage, "organization has been restricted") || // groq
|
||||||
|
strings.Contains(lowerMessage, "已欠费") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ShouldEnableChannel(err error, openAIErr *model.Error) bool {
|
||||||
|
if !config.AutomaticEnableChannelEnabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if openAIErr != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user