Merge remote-tracking branch 'origin/upstream/main' into patch/images-edits

This commit is contained in:
Laisky.Cai 2024-07-13 13:42:49 +00:00
commit 02da017791
65 changed files with 1957 additions and 1200 deletions

View File

@ -36,23 +36,9 @@ jobs:
# in the next step as well as the next job. # in the next step as well as the next job.
- name: Test - name: Test
run: go test -cover -coverprofile=coverage.txt ./... run: go test -cover -coverprofile=coverage.txt ./...
- uses: codecov/codecov-action@v4
- name: Archive code coverage results
uses: actions/upload-artifact@v4
with: with:
name: code-coverage token: ${{ secrets.CODECOV_TOKEN }}
path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step
code_coverage:
name: "Code coverage report"
if: github.event_name == 'pull_request' # Do not run when workflow is triggered by push to main branch
runs-on: ubuntu-latest
needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job
steps:
- uses: fgrosse/go-coverage-report@v1.0.2 # Consider using a Git revision for maximum security
with:
coverage-artifact-name: "code-coverage" # can be omitted if you used this default value
coverage-file-name: "coverage.txt" # can be omitted if you used this default value
commit_lint: commit_lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -16,7 +16,9 @@ WORKDIR /web/air
RUN npm install RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2 FROM golang:alpine AS builder2
RUN apk add --no-cache g++
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=1 \ CGO_ENABLED=1 \
@ -27,7 +29,7 @@ ADD go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
COPY --from=builder /web/build ./web/build COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine FROM alpine

View File

@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. 6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'.
7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60` + Example: `SYNC_FREQUENCY=60`
7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. 8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave` + Example: `NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. 9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440` + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. 10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440` + Example: `CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. 11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5` + Example: `POLLING_INTERVAL=5`
12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'.
+Example: ` BATCH_UPDATE_ENABLED=true`
+If you encounter an issue with too many database connections, you can try enabling this option.
13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'.
+Example: ` BATCH_UPDATE_INTERVAL=5`
14. Request frequency limit:
+ `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180.
+ `GLOBAL_WEL_RATE_LIMIT`: Global web speed limit, the maximum number of requests within three minutes per IP, default to 60.
15. Encoder cache settings:
+`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the encoding of some common word elements online, such as' gpt-3.5 turbo '. In some unstable network environments or offline situations, it may cause startup problems. This directory can be configured to cache data and can be migrated to an offline environment.
+`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as' TIKTOKEN-CACHE-DIR ', but its priority is not as high as it.
16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout time set.
17. `RELAY_PROXY`: After setting up, use this proxy to request APIs.
18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds.
19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this agent to request content uploaded by users, such as images.
20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'.
21. `GEMINI_SAFETY_SETTING`: Gemini's security settings are set to 'BLOCK-NONE' by default.
22. `GEMINI_VERSION`: The Gemini version used by the One API, which defaults to 'v1'.
23. `THE`: The system's theme setting, default to 'default', specific optional values refer to [here] (./web/README. md).
24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'.
25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'.
26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'.
27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time.
28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time.
### Command Line Parameters ### Command Line Parameters
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.
@ -287,7 +312,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Double-check that your interface address and API Key are correct. + Double-check that your interface address and API Key are correct.
## Related Projects ## Related Projects
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM * [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller.
* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization.
## Note ## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.

View File

@ -88,6 +88,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+ [x] [DeepL](https://www.deepl.com/) + [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/) + [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。 3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@ -370,32 +371,33 @@ graph LR
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440` + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
11. 例子:`CHANNEL_TEST_FREQUENCY=1440` +例子:`CHANNEL_TEST_FREQUENCY=1440`
12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5` + 例子:`POLLING_INTERVAL=5`
13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true``false`,未设置则默认为 `false` 12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`BATCH_UPDATE_ENABLED=true` + 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5` 13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`
+ 例子:`BATCH_UPDATE_INTERVAL=5` + 例子:`BATCH_UPDATE_INTERVAL=5`
15. 请求频率限制: 14. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180` + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60` + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`
16. 编码器缓存设置: 15. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
18. `RELAY_PROXY`:设置后使用该代理来请求 API。 17. `RELAY_PROXY`:设置后使用该代理来请求 API。
19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
21. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000` 20. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`
22. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE` 21. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`
23. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1` 22. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`
24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true``false` 24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true``false`
26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10` 25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`
27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8` 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`
28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000` 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`
@ -448,6 +450,8 @@ https://openai.justsong.cn
## 相关项目 ## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。
* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。
## 注意 ## 注意

View File

@ -143,8 +143,12 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1") var GeminiVersion = env.String("GEMINI_VERSION", "v1")
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "") var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

View File

@ -21,4 +21,5 @@ const (
TokenName = "token_name" TokenName = "token_name"
BaseURL = "base_url" BaseURL = "base_url"
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
) )

View File

@ -6,12 +6,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
) )
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) { func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody) requestBody, _ := c.Get(ctxkey.KeyRequestBody)
if requestBody != nil { if requestBody != nil {
return requestBody.([]byte), nil return requestBody.([]byte), nil
} }
@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
return nil, err return nil, err
} }
_ = c.Request.Body.Close() _ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody) c.Set(ctxkey.KeyRequestBody, requestBody)
return requestBody.([]byte), nil return requestBody.([]byte), nil
} }

View File

@ -27,7 +27,12 @@ var setupLogOnce sync.Once
func SetupLogger() { func SetupLogger() {
setupLogOnce.Do(func() { setupLogOnce.Do(func() {
if LogDir != "" { if LogDir != "" {
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) var logPath string
if config.OnlyOneLogFile {
logPath = filepath.Join(LogDir, "oneapi.log")
} else {
logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
}
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")

View File

@ -6,11 +6,16 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"net"
"net/smtp" "net/smtp"
"strings" "strings"
"time" "time"
) )
func shouldAuth() bool {
return config.SMTPAccount != "" || config.SMTPToken != ""
}
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if receiver == "" { if receiver == "" {
return fmt.Errorf("receiver is empty") return fmt.Errorf("receiver is empty")
@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error {
"Date: %s\r\n"+ "Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
if config.SMTPPort == 465 || !shouldAuth() {
// need advanced client
var conn net.Conn
var err error
if config.SMTPPort == 465 { if config.SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: config.SMTPServer, ServerName: config.SMTPServer,
} }
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
} else {
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
}
if err != nil { if err != nil {
return err return err
} }
@ -59,9 +72,11 @@ func SendEmail(subject string, receiver string, content string) error {
return err return err
} }
defer client.Close() defer client.Close()
if shouldAuth() {
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
return err return err
} }
}
if err = client.Mail(config.SMTPFrom); err != nil { if err = client.Mail(config.SMTPFrom); err != nil {
return err return err
} }

29
common/render/render.go Normal file
View File

@ -0,0 +1,29 @@
package render
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"strings"
)
func StringData(c *gin.Context, str string) {
str = strings.TrimPrefix(str, "data: ")
str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
c.Writer.Flush()
}
func ObjectData(c *gin.Context, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, string(jsonData))
return nil
}
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}

View File

@ -14,6 +14,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -27,15 +28,15 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
) )
func buildTestRequest() *relaymodel.GeneralOpenAIRequest { func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
if model == "" {
model = "gpt-3.5-turbo"
}
testRequest := &relaymodel.GeneralOpenAIRequest{ testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2, MaxTokens: 2,
Stream: false, Model: model,
Model: "gpt-3.5-turbo",
} }
testMessage := relaymodel.Message{ testMessage := relaymodel.Message{
Role: "user", Role: "user",
@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
return testRequest return testRequest
} }
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
adaptor.Init(meta) adaptor.Init(meta)
var modelName string modelName := request.Model
modelList := adaptor.GetModelList()
modelMap := channel.GetModelMapping() modelMap := channel.GetModelMapping()
if len(modelList) != 0 {
modelName = modelList[0]
}
if modelName == "" || !strings.Contains(channel.Models, modelName) { if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",") modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 { if len(modelNames) > 0 {
@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
modelName = modelMap[modelName] modelName = modelMap[modelName]
} }
} }
request := buildTestRequest() meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
model := c.Query("model")
testRequest := buildTestRequest(model)
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel) err, _ = testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
milliseconds = 0
}
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if err != nil { if err != nil {
@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) {
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
"time": consumedTime, "time": consumedTime,
"model": model,
}) })
return return
} }
@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"time": consumedTime, "time": consumedTime,
"model": model,
}) })
return return
} }
@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel) testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold { if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
if config.AutomaticDisableChannelEnabled { if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error()) monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else { } else {

View File

@ -45,7 +45,7 @@ func Relay(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
relayMode := relaymode.GetByPath(c.Request.URL.Path) relayMode := relaymode.GetByPath(c.Request.URL.Path)
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
userId := c.GetInt("id") userId := c.GetInt(ctxkey.Id)
bizErr := relayHelper(c, relayMode) bizErr := relayHelper(c, relayMode)
if bizErr == nil { if bizErr == nil {
monitor.Emit(channelId, true) monitor.Emit(channelId, true)

6
go.mod
View File

@ -24,7 +24,7 @@ require (
github.com/smartystreets/goconvey v1.8.1 github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.23.0 golang.org/x/crypto v0.23.0
golang.org/x/image v0.16.0 golang.org/x/image v0.18.0
gorm.io/driver/mysql v1.5.6 gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlite v1.5.5
@ -68,7 +68,7 @@ require (
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
@ -80,7 +80,7 @@ require (
golang.org/x/net v0.25.0 // indirect golang.org/x/net v0.25.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect golang.org/x/text v0.16.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

12
go.sum
View File

@ -110,8 +110,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -154,8 +154,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
@ -164,8 +164,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

22
main.go
View File

@ -27,27 +27,19 @@ func main() {
common.Init() common.Init()
logger.SetupLogger() logger.SetupLogger()
logger.SysLogf("One API %s started", common.Version) logger.SysLogf("One API %s started", common.Version)
if os.Getenv("GIN_MODE") != "debug" {
if os.Getenv("GIN_MODE") != gin.DebugMode {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if config.DebugEnabled { if config.DebugEnabled {
logger.SysLog("running in debug mode") logger.SysLog("running in debug mode")
} }
var err error
// Initialize SQL Database // Initialize SQL Database
model.DB, err = model.InitDB("SQL_DSN") model.InitDB()
if err != nil { model.InitLogDB()
logger.FatalLog("failed to initialize database: " + err.Error())
} var err error
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
err = model.CreateRootAccountIfNeed() err = model.CreateRootAccountIfNeed()
if err != nil { if err != nil {
logger.FatalLog("database init error: " + err.Error()) logger.FatalLog("database init error: " + err.Error())

View File

@ -1,6 +1,7 @@
package model package model
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error {
if err != nil { if err != nil {
return err return err
} }
accessToken := random.GetUUID()
if config.InitialRootAccessToken != "" {
accessToken = config.InitialRootAccessToken
}
rootUser := User{ rootUser := User{
Username: "root", Username: "root",
Password: hashedPassword, Password: hashedPassword,
Role: RoleRootUser, Role: RoleRootUser,
Status: UserStatusEnabled, Status: UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: random.GetUUID(), AccessToken: accessToken,
Quota: 500000000000000, Quota: 500000000000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@ -60,10 +65,22 @@ func CreateRootAccountIfNeed() error {
} }
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName) dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
switch {
case strings.HasPrefix(dsn, "postgres://"):
// Use PostgreSQL // Use PostgreSQL
return openPostgreSQL(dsn)
case dsn != "":
// Use MySQL
return openMySQL(dsn)
default:
// Use SQLite
return openSQLite()
}
}
func openPostgreSQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using PostgreSQL as database") logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{ return gorm.Open(postgres.New(postgres.Config{
@ -72,78 +89,132 @@ func chooseDB(envName string) (*gorm.DB, error) {
}), &gorm.Config{ }), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
// Use MySQL
func openMySQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using MySQL as database") logger.SysLog("using MySQL as database")
common.UsingMySQL = true common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{ return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
// Use SQLite
func openSQLite() (*gorm.DB, error) {
logger.SysLog("SQL_DSN not set, using SQLite as database") logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ return gorm.Open(sqlite.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
func InitDB(envName string) (db *gorm.DB, err error) { func InitDB() {
db, err = chooseDB(envName) var err error
if err == nil { DB, err = chooseDB("SQL_DSN")
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil { if err != nil {
return nil, err logger.FatalLog("failed to initialize database: " + err.Error())
return
} }
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) sqlDB := setDBConns(DB)
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode { if !config.IsMasterNode {
return db, err return
} }
if common.UsingMySQL { if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
} }
logger.SysLog("database migration started") logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{}) if err = migrateDB(); err != nil {
if err != nil { logger.FatalLog("failed to migrate database: " + err.Error())
return nil, err return
}
err = db.AutoMigrate(&Token{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return nil, err
} }
logger.SysLog("database migrated") logger.SysLog("database migrated")
return db, err }
} else {
logger.FatalLog(err) func migrateDB() error {
var err error
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
} }
return db, err if err = DB.AutoMigrate(&Token{}); err != nil {
return err
}
if err = DB.AutoMigrate(&User{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Option{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Redemption{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Ability{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Log{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
return nil
}
func InitLogDB() {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
logger.SysLog("using secondary database for table logs")
var err error
LOG_DB, err = chooseDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
return
}
setDBConns(LOG_DB)
if !config.IsMasterNode {
return
}
logger.SysLog("secondary database migration started")
err = migrateLOGDB()
if err != nil {
logger.FatalLog("failed to migrate secondary database: " + err.Error())
return
}
logger.SysLog("secondary database migrated")
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
logger.FatalLog("failed to connect database: " + err.Error())
return nil
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
return sqlDB
} }
func closeDB(db *gorm.DB) error { func closeDB(db *gorm.DB) error {

View File

@ -4,6 +4,12 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -12,10 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
) )
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage var usage model.Usage
var documents []LibraryDocument
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool) common.SetEventStreamHeaders(c)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format if len(data) < 5 || data[:5] != "data:" {
continue
}
if data[:5] != "data:" {
continue continue
} }
data = data[5:] data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var documents []LibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse LibraryStreamResponse var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
if len(AIProxyLibraryResponse.Documents) != 0 { if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents documents = AIProxyLibraryResponse.Documents
} }
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) }
return true
case <-stopChan: if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
response := documentsAIProxyLibrary(documents) response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response) err := render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) render.Done(c)
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false err = resp.Body.Close()
}
})
err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }

View File

@ -3,15 +3,17 @@ package ali
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@ -181,32 +183,21 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool) common.SetEventStreamHeaders(c)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format if len(data) < 5 || data[:5] != "data:" {
continue
}
if data[:5] != "data:" {
continue continue
} }
data = data[5:] data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse ChatResponse var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse) err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
if aliResponse.Usage.OutputTokens != 0 { if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens usage.PromptTokens = aliResponse.Usage.InputTokens
@ -215,22 +206,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
response := streamResponseAli2OpenAI(&aliResponse) response := streamResponseAli2OpenAI(&aliResponse)
if response == nil { if response == nil {
return true continue
} }
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) err = render.ObjectData(c, response)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@ -5,4 +5,5 @@ var ModelList = []string{
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
} }

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -28,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
return "stop" return "stop"
case "max_tokens": case "max_tokens":
return "length" return "length"
case "tool_use":
return "tool_calls"
default: default:
return *reason return *reason
} }
} }
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := Request{ claudeRequest := Request{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
@ -41,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
TopP: textRequest.TopP, TopP: textRequest.TopP,
TopK: textRequest.TopK, TopK: textRequest.TopK,
Stream: textRequest.Stream, Stream: textRequest.Stream,
Tools: claudeTools,
}
if len(claudeTools) > 0 {
claudeToolChoice := struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
if function, ok := choice["function"].(map[string]any); ok {
claudeToolChoice.Type = "tool"
claudeToolChoice.Name = function["name"].(string)
}
} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
if toolChoiceType == "any" {
claudeToolChoice.Type = toolChoiceType
}
}
claudeRequest.ToolChoice = claudeToolChoice
} }
if claudeRequest.MaxTokens == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096 claudeRequest.MaxTokens = 4096
@ -63,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
if message.IsStringContent() { if message.IsStringContent() {
content.Type = "text" content.Type = "text"
content.Text = message.StringContent() content.Text = message.StringContent()
if message.Role == "tool" {
claudeMessage.Role = "user"
content.Type = "tool_result"
content.Content = content.Text
content.Text = ""
content.ToolUseId = message.ToolCallId
}
claudeMessage.Content = append(claudeMessage.Content, content) claudeMessage.Content = append(claudeMessage.Content, content)
for i := range message.ToolCalls {
inputParam := make(map[string]any)
_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
claudeMessage.Content = append(claudeMessage.Content, Content{
Type: "tool_use",
Id: message.ToolCalls[i].Id,
Name: message.ToolCalls[i].Function.Name,
Input: inputParam,
})
}
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue continue
} }
@ -96,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
var response *Response var response *Response
var responseText string var responseText string
var stopReason string var stopReason string
tools := make([]model.Tool, 0)
switch claudeResponse.Type { switch claudeResponse.Type {
case "message_start": case "message_start":
return nil, claudeResponse.Message return nil, claudeResponse.Message
case "content_block_start": case "content_block_start":
if claudeResponse.ContentBlock != nil { if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text responseText = claudeResponse.ContentBlock.Text
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, model.Tool{
Id: claudeResponse.ContentBlock.Id,
Type: "function",
Function: model.Function{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} }
case "content_block_delta": case "content_block_delta":
if claudeResponse.Delta != nil { if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text responseText = claudeResponse.Delta.Text
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, model.Tool{
Function: model.Function{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
} }
case "message_delta": case "message_delta":
if claudeResponse.Usage != nil { if claudeResponse.Usage != nil {
@ -119,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
} }
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText choice.Delta.Content = responseText
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
choice.Delta.Role = "assistant" choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason) finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" { if finishReason != "null" {
@ -135,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
if len(claudeResponse.Content) > 0 { if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text responseText = claudeResponse.Content[0].Text
} }
tools := make([]model.Tool, 0)
for _, v := range claudeResponse.Content {
if v.Type == "tool_use" {
args, _ := json.Marshal(v.Input)
tools = append(tools, model.Tool{
Id: v.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: model.Function{
Name: v.Name,
Arguments: string(args),
},
})
}
}
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
Index: 0, Index: 0,
Message: model.Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: responseText, Content: responseText,
Name: nil, Name: nil,
ToolCalls: tools,
}, },
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
@ -169,64 +261,77 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var usage model.Usage var usage model.Usage
var modelName string var modelName string
var id string var id string
c.Stream(func(w io.Writer) bool { var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
select {
case data := <-dataChan: for scanner.Scan() {
// some implementations may add \r at the end of data data := scanner.Text()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data) data = strings.TrimSpace(data)
var claudeResponse StreamResponse var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse) err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response, meta := StreamResponseClaude2OpenAI(&claudeResponse) response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if meta != nil { if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens usage.CompletionTokens += meta.Usage.OutputTokens
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
modelName = meta.Model modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id) id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true continue
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
} }
if response == nil { if response == nil {
return true continue
} }
response.Id = id response.Id = id
response.Model = modelName response.Model = modelName
response.Created = createdTime response.Created = createdTime
jsonStr, err := json.Marshal(response)
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
_ = resp.Body.Close() if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage return nil, &usage
} }

View File

@ -16,6 +16,12 @@ type Content struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"` Source *ImageSource `json:"source,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
} }
type Message struct { type Message struct {
@ -23,6 +29,18 @@ type Message struct {
Content []Content `json:"content"` Content []Content `json:"content"`
} }
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type Request struct { type Request struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@ -33,6 +51,8 @@ type Request struct {
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
//Metadata `json:"metadata,omitempty"` //Metadata `json:"metadata,omitempty"`
} }
@ -61,6 +81,7 @@ type Response struct {
type Delta struct { type Delta struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`
PartialJson string `json:"partial_json,omitempty"`
StopReason *string `json:"stop_reason"` StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"` StopSequence *string `json:"stop_sequence"`
} }

View File

@ -1,17 +1,16 @@
package aws package aws
import ( import (
"github.com/aws/aws-sdk-go-v2/aws" "errors"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"io" "io"
"net/http" "net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
@ -19,18 +18,52 @@ import (
var _ adaptor.Adaptor = new(Adaptor) var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct { type Adaptor struct {
meta *meta.Meta awsAdapter utils.AwsAdapter
awsClient *bedrockruntime.Client
Meta *meta.Meta
AwsClient *bedrockruntime.Client
} }
func (a *Adaptor) Init(meta *meta.Meta) { func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta a.Meta = meta
a.awsClient = bedrockruntime.New(bedrockruntime.Options{ a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region, Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
}) })
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
a.awsAdapter = adaptor
return adaptor.ConvertRequest(c, relayMode, request)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if a.awsAdapter == nil {
return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
}
return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}
func (a *Adaptor) GetModelList() (models []string) {
for model := range adaptors {
models = append(models, model)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil return "", nil
} }
@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil return nil, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, a.awsClient)
} else {
err, usage = Handler(c, a.awsClient, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}

View File

@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"io" "io"
"net/http" "net/http"
@ -16,33 +15,28 @@ import (
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{ var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1", "claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2", "claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1", "claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
} }
func awsModelID(requestModel string) (string, error) { func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok { if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil return awsModelID, nil
} }
@ -52,7 +46,7 @@ func awsModelID(requestModel string) (string, error) {
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsReq := &bedrockruntime.InvokeModelInput{ awsReq := &bedrockruntime.InvokeModelInput{
@ -63,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok { if !ok {
return wrapErr(errors.New("request not found")), nil return utils.WrapErr(errors.New("request not found")), nil
} }
claudeReq := claudeReq_.(*anthropic.Request) claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{ awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31", AnthropicVersion: "bedrock-2023-05-31",
} }
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil return utils.WrapErr(errors.Wrap(err, "copy request")), nil
} }
awsReq.Body, err = json.Marshal(awsClaudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
} }
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
} }
claudeResponse := new(anthropic.Response) claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse) err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
} }
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
@ -106,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
@ -117,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok { if !ok {
return wrapErr(errors.New("request not found")), nil return utils.WrapErr(errors.New("request not found")), nil
} }
claudeReq := claudeReq_.(*anthropic.Request) claudeReq := claudeReq_.(*anthropic.Request)
@ -125,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
AnthropicVersion: "bedrock-2023-05-31", AnthropicVersion: "bedrock-2023-05-31",
} }
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil return utils.WrapErr(errors.Wrap(err, "copy request")), nil
} }
awsReq.Body, err = json.Marshal(awsClaudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
} }
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
} }
stream := awsResp.GetStream() stream := awsResp.GetStream()
defer stream.Close() defer stream.Close()
@ -142,6 +136,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage var usage relaymodel.Usage
var id string var id string
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events() event, ok := <-stream.Events()
if !ok { if !ok {
@ -162,8 +158,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
if meta != nil { if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens usage.CompletionTokens += meta.Usage.OutputTokens
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
id = fmt.Sprintf("chatcmpl-%s", meta.Id) id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true return true
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
} }
if response == nil { if response == nil {
return true return true
@ -171,6 +178,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
response.Id = id response.Id = id
response.Model = c.GetString(ctxkey.OriginalModel) response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())

View File

@ -9,9 +9,12 @@ type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31" // AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"` AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"` Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
} }

View File

@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
llamaReq := ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, llamaReq)
return llamaReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@ -0,0 +1,231 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"text/template"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/random"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// Only support llama-3-8b and llama-3-70b instruction models
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
"llama3-8b-8192": "meta.llama3-8b-instruct-v1:0",
"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
// promptTemplate with range
const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|>
`
var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate))
func RenderPrompt(messages []relaymodel.Message) string {
var buf bytes.Buffer
err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages})
if err != nil {
logger.SysError("error rendering prompt messages: " + err.Error())
}
return buf.String()
}
func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
llamaRequest := Request{
MaxGenLen: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
}
if llamaRequest.MaxGenLen == 0 {
llamaRequest.MaxGenLen = 2048
}
prompt := RenderPrompt(textRequest.Messages)
llamaRequest.Prompt = prompt
return &llamaRequest
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
var llamaResponse Response
err = json.Unmarshal(awsResp.Body, &llamaResponse)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := ResponseLlama2OpenAI(&llamaResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: llamaResponse.PromptTokenCount,
CompletionTokens: llamaResponse.GenerationTokenCount,
TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
var responseText string
if len(llamaResponse.Generation) > 0 {
responseText = llamaResponse.Generation
}
choice := openai.TextResponseChoice{
Index: 0,
Message: relaymodel.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: llamaResponse.StopReason,
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
var llamaResp StreamResponse
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
if llamaResp.PromptTokenCount > 0 {
usage.PromptTokens = llamaResp.PromptTokenCount
}
if llamaResp.StopReason == "stop" {
usage.CompletionTokens = llamaResp.GenerationTokenCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
response := StreamResponseLlama2OpenAI(&llamaResp)
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}
func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = llamaResponse.Generation
choice.Delta.Role = "assistant"
finishReason := llamaResponse.StopReason
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse
}

View File

@ -0,0 +1,45 @@
package aws_test
import (
"testing"
aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/stretchr/testify/assert"
)
func TestRenderPrompt(t *testing.T) {
messages := []relaymodel.Message{
{
Role: "user",
Content: "What's your name?",
},
}
prompt := aws.RenderPrompt(messages)
expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
messages = []relaymodel.Message{
{
Role: "system",
Content: "Your name is Kat. You are a detective.",
},
{
Role: "user",
Content: "What's your name?",
},
{
Role: "assistant",
Content: "Kat",
},
{
Role: "user",
Content: "What's your job?",
},
}
prompt = aws.RenderPrompt(messages)
expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
}

View File

@ -0,0 +1,29 @@
package aws
// Request is the request to AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
}
// Response is the response from AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Response struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}
// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
type StreamResponse struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}

View File

@ -0,0 +1,39 @@
package aws
import (
claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)
type AwsModelType int
const (
AwsClaude AwsModelType = iota + 1
AwsLlama3
)
var (
adaptors = map[string]AwsModelType{}
)
func init() {
for model := range claude.AwsModelIDMap {
adaptors[model] = AwsClaude
}
for model := range llama3.AwsModelIDMap {
adaptors[model] = AwsLlama3
}
}
func GetAdaptor(model string) utils.AwsAdapter {
adaptorType := adaptors[model]
switch adaptorType {
case AwsClaude:
return &claude.Adaptor{}
case AwsLlama3:
return &llama3.Adaptor{}
default:
return nil
}
}

View File

@ -0,0 +1,51 @@
package utils
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type AwsAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
type Adaptor struct {
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}

View File

@ -0,0 +1,16 @@
package utils
import (
"net/http"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: err.Error(),
},
}
}

View File

@ -5,6 +5,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
@ -12,11 +19,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"sync"
"time"
) )
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@ -137,40 +139,22 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage var usage model.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil common.SetEventStreamHeaders(c)
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format if len(data) < 6 {
continue continue
} }
data = data[6:] data = data[6:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse ChatStreamResponse var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse) err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
if baiduResponse.Usage.TotalTokens != 0 { if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens usage.TotalTokens = baiduResponse.Usage.TotalTokens
@ -178,18 +162,18 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
} }
response := streamResponseBaidu2OpenAI(&baiduResponse) response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@ -5,11 +5,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
) )
type Adaptor struct { type Adaptor struct {
@ -27,8 +29,33 @@ func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta a.meta = meta
} }
// WorkerAI cannot be used across accounts with AIGateWay
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
func (a *Adaptor) isAIGateWay(baseURL string) bool {
return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil isAIGateWay := a.isAIGateWay(meta.BaseURL)
var urlPrefix string
if isAIGateWay {
urlPrefix = meta.BaseURL
} else {
urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
}
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
default:
if isAIGateWay {
return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
}
return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
}
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@ -41,7 +68,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return ConvertRequest(*request), nil switch relayMode {
case relaymode.Completions:
return ConvertCompletionsRequest(*request), nil
case relaymode.ChatCompletions, relaymode.Embeddings:
return request, nil
default:
return nil, errors.New("not implemented")
}
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {

View File

@ -2,12 +2,14 @@ package cloudflare
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -16,114 +18,66 @@ import (
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
var promptBuilder strings.Builder p, _ := textRequest.Prompt.(string)
for _, message := range textRequest.Messages {
promptBuilder.WriteString(message.StringContent())
promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
}
return &Request{ return &Request{
Prompt: p,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
Prompt: promptBuilder.String(),
Stream: textRequest.Stream, Stream: textRequest.Stream,
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
} }
} }
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cloudflareResponse.Result.Response,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = cloudflareResponse.Response
choice.Delta.Role = "assistant"
openaiResponse := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
Created: helper.GetTimestamp(),
}
return &openaiResponse
}
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil common.SetEventStreamHeaders(c)
} id := helper.GetResponseID(c)
if i := bytes.IndexByte(data, '\n'); i >= 0 { responseModel := c.GetString(ctxkey.OriginalModel)
return i + 1, data[0:i], nil var responseText string
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < len("data: ") { if len(data) < len("data: ") {
continue continue
} }
data = strings.TrimPrefix(data, "data: ") data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
responseModel := c.GetString("original_model")
var responseText string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var cloudflareResponse StreamResponse
err := json.Unmarshal([]byte(data), &cloudflareResponse) if data == "[DONE]" {
break
}
var response openai.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) for _, v := range response.Choices {
if response == nil { v.Delta.Role = "assistant"
return true responseText += v.Delta.StringContent()
} }
responseText += cloudflareResponse.Response
response.Id = id response.Id = id
response.Model = responseModel response.Model = modelName
jsonStr, err := json.Marshal(response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
_ = resp.Body.Close() if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
return nil, usage return nil, usage
} }
@ -137,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
var cloudflareResponse Response var response openai.TextResponse
err = json.Unmarshal(responseBody, &cloudflareResponse) err = json.Unmarshal(responseBody, &response)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) response.Model = modelName
fullTextResponse.Model = modelName var responseText string
usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) for _, v := range response.Choices {
fullTextResponse.Usage = *usage responseText += v.Message.Content.(string)
fullTextResponse.Id = helper.GetResponseID(c) }
jsonResponse, err := json.Marshal(fullTextResponse) usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, _ = c.Writer.Write(jsonResponse)
return nil, usage return nil, usage
} }

View File

@ -1,6 +1,9 @@
package cloudflare package cloudflare
import "github.com/songquanpeng/one-api/relay/model"
type Request struct { type Request struct {
Messages []model.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"` Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@ -8,18 +11,3 @@ type Request struct {
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
} }
type Result struct {
Response string `json:"response"`
}
type Response struct {
Result Result `json:"result"`
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
}
type StreamResponse struct {
Response string `json:"response"`
}

View File

@ -2,9 +2,9 @@ package cohere
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var usage model.Usage var usage model.Usage
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var cohereResponse StreamResponse var cohereResponse StreamResponse
err := json.Unmarshal([]byte(data), &cohereResponse) err := json.Unmarshal([]byte(data), &cohereResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response, meta := StreamResponseCohere2OpenAI(&cohereResponse) response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil { if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
return true continue
} }
if response == nil { if response == nil {
return true continue
} }
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model") response.Model = c.GetString("original_model")
response.Created = createdTime response.Created = createdTime
jsonStr, err := json.Marshal(response)
err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
_ = resp.Body.Close() if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage return nil, &usage
} }

View File

@ -4,6 +4,11 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@ -12,9 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
// https://www.coze.com/open // https://www.coze.com/open
@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
var responseText string var responseText string
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil common.SetEventStreamHeaders(c)
} var modelName string
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 5 { if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
if !strings.HasPrefix(data, "data:") {
continue continue
} }
data = strings.TrimPrefix(data, "data:") data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var modelName string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var cozeResponse StreamResponse var cozeResponse StreamResponse
err := json.Unmarshal([]byte(data), &cozeResponse) err := json.Unmarshal([]byte(data), &cozeResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response, _ := StreamResponseCoze2OpenAI(&cozeResponse) response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
if response == nil { if response == nil {
return true continue
} }
for _, choice := range response.Choices { for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content) responseText += conv.AsString(choice.Delta.Content)
} }
response.Model = modelName response.Model = modelName
response.Created = createdTime response.Created = createdTime
jsonStr, err := json.Marshal(response)
err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
_ = resp.Body.Close() if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &responseText return nil, &responseText
} }

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -275,21 +276,10 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := "" responseText := ""
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil common.SetEventStreamHeaders(c)
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
data = strings.TrimSpace(data) data = strings.TrimSpace(data)
@ -298,41 +288,38 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
data = strings.TrimPrefix(data, "data: ") data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"") data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var geminiResponse ChatResponse var geminiResponse ChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse) err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response := streamResponseGeminiChat2OpenAI(&geminiResponse) response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil { if response == nil {
return true continue
} }
responseText += response.Choices[0].Delta.StringContent() responseText += response.Choices[0].Delta.StringContent()
jsonResponse, err := json.Marshal(response)
err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText return nil, responseText
} }

View File

@ -0,0 +1,19 @@
package novita
// https://novita.ai/llm-api
var ModelList = []string{
"meta-llama/llama-3-8b-instruct",
"meta-llama/llama-3-70b-instruct",
"nousresearch/hermes-2-pro-llama-3-8b",
"nousresearch/nous-hermes-llama2-13b",
"mistralai/mistral-7b-instruct",
"cognitivecomputations/dolphin-mixtral-8x22b",
"sao10k/l3-70b-euryale-v2.1",
"sophosympatheia/midnight-rose-70b",
"gryphe/mythomax-l2-13b",
"Nous-Hermes-2-Mixtral-8x7B-DPO",
"lzlv_70b",
"teknium/openhermes-2.5-mistral-7b",
"microsoft/wizardlm-2-8x22b",
}

View File

@ -0,0 +1,15 @@
package novita
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
}

View File

@ -5,12 +5,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/common/random"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/image"
@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "}\n"); i >= 0 { if i := strings.Index(string(data), "}\n"); i >= 0 {
return i + 2, data[0:i], nil return i + 2, data[0 : i+1], nil
} }
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool) common.SetEventStreamHeaders(c)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := strings.TrimPrefix(scanner.Text(), "}") data := strings.TrimPrefix(scanner.Text(), "}")
dataChan <- data + "}" data = data + "}"
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var ollamaResponse ChatResponse var ollamaResponse ChatResponse
err := json.Unmarshal([]byte(data), &ollamaResponse) err := json.Unmarshal([]byte(data), &ollamaResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
if ollamaResponse.EvalCount != 0 { if ollamaResponse.EvalCount != 0 {
usage.PromptTokens = ollamaResponse.PromptEvalCount usage.PromptTokens = ollamaResponse.PromptEvalCount
usage.CompletionTokens = ollamaResponse.EvalCount usage.CompletionTokens = ollamaResponse.EvalCount
usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
} }
response := streamResponseOllama2OpenAI(&ollamaResponse) response := streamResponseOllama2OpenAI(&ollamaResponse)
jsonResponse, err := json.Marshal(response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }

View File

@ -3,17 +3,19 @@ package openai
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
) )
type Adaptor struct { type Adaptor struct {
@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return minimax.GetRequestURL(meta) return minimax.GetRequestURL(meta)
case channeltype.Doubao: case channeltype.Doubao:
return doubao.GetRequestURL(meta) return doubao.GetRequestURL(meta)
case channeltype.Novita:
return novita.GetRequestURL(meta)
default: default:
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot" "github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai" "github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
@ -28,6 +29,7 @@ var CompatibleChannels = []int{
channeltype.StepFun, channeltype.StepFun,
channeltype.DeepSeek, channeltype.DeepSeek,
channeltype.TogetherAI, channeltype.TogetherAI,
channeltype.Novita,
} }
func GetCompatibleChannelMeta(channelType int) (string, []string) { func GetCompatibleChannelMeta(channelType int) (string, []string) {
@ -56,6 +58,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "together.ai", togetherai.ModelList return "together.ai", togetherai.ModelList
case channeltype.Doubao: case channeltype.Doubao:
return "doubao", doubao.ModelList return "doubao", doubao.ModelList
case channeltype.Novita:
return "novita", novita.ModelList
default: default:
return "openai", ModelList return "openai", ModelList
} }

View File

@ -4,15 +4,18 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
) )
const ( const (
@ -24,22 +27,12 @@ const (
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
responseText := "" responseText := ""
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
var usage *model.Usage var usage *model.Usage
go func() {
common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format if len(data) < dataPrefixLength { // ignore blank line or wrong format
@ -49,7 +42,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue continue
} }
if strings.HasPrefix(data[dataPrefixLength:], done) { if strings.HasPrefix(data[dataPrefixLength:], done) {
dataChan <- data render.StringData(c, data)
doneRendered = true
continue continue
} }
switch relayMode { switch relayMode {
@ -58,14 +52,14 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
dataChan <- data // if error happened, pass the data to client render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error continue // just ignore the error
} }
if len(streamResponse.Choices) == 0 { if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure // but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice continue // just ignore empty choice
} }
dataChan <- data render.StringData(c, data)
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content) responseText += conv.AsString(choice.Delta.Content)
} }
@ -73,7 +67,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
usage = streamResponse.Usage usage = streamResponse.Usage
} }
case relaymode.Completions: case relaymode.Completions:
dataChan <- data render.StringData(c, data)
var streamResponse CompletionsStreamResponse var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil { if err != nil {
@ -85,27 +79,20 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
} }
} }
} }
stopChan <- true
}() if err := scanner.Err(); err != nil {
common.SetEventStreamHeaders(c) logger.SysError("error reading stream: " + err.Error())
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
} }
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") if !doneRendered {
c.Render(-1, common.CustomEvent{Data: data}) render.Done(c)
return true
case <-stopChan:
return false
} }
})
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
} }
return nil, responseText, usage return nil, responseText, usage
} }
@ -149,7 +136,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if textResponse.Usage.TotalTokens == 0 { if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range textResponse.Choices { for _, choice := range textResponse.Choices {
completionTokens += CountTokenText(choice.Message.StringContent(), modelName) completionTokens += CountTokenText(choice.Message.StringContent(), modelName)

View File

@ -3,6 +3,10 @@ package palm
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -11,8 +15,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool) common.SetEventStreamHeaders(c)
go func() {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
logger.SysError("error reading stream response: " + err.Error()) logger.SysError("error reading stream response: " + err.Error())
stopChan <- true err := resp.Body.Close()
return if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
}
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
logger.SysError("error closing stream response: " + err.Error()) return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
stopChan <- true
return
} }
var palmResponse ChatResponse var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse) err = json.Unmarshal(responseBody, &palmResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
return
} }
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.Id = responseId fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 { if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content responseText = palmResponse.Candidates[0].Content
} }
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
return
} }
dataChan <- string(jsonResponse)
stopChan <- true err = render.ObjectData(c, string(jsonResponse))
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" logger.SysError(err.Error())
} }
render.Done(c)
return nil, responseText return nil, responseText
} }

View File

@ -8,6 +8,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@ -17,11 +24,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
"time"
) )
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string var responseText string
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil common.SetEventStreamHeaders(c)
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue continue
} }
if data[:5] != "data:" { data = strings.TrimPrefix(data, "data:")
continue
} var tencentResponse ChatResponse
data = data[5:] err := json.Unmarshal([]byte(data), &tencentResponse)
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response := streamResponseTencent2OpenAI(&TencentResponse)
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 { if len(response.Choices) != 0 {
responseText += conv.AsString(response.Choices[0].Delta.Content) responseText += conv.AsString(response.Choices[0].Delta.Content)
} }
jsonResponse, err := json.Marshal(response)
err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError(err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText return nil, responseText
} }

View File

@ -6,4 +6,5 @@ var ModelList = []string{
"SparkDesk-v2.1", "SparkDesk-v2.1",
"SparkDesk-v3.1", "SparkDesk-v3.1",
"SparkDesk-v3.5", "SparkDesk-v3.5",
"SparkDesk-v4.0",
} }

View File

@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages xunfeiRequest.Payload.Message.Text = messages
if strings.HasPrefix(domain, "generalv3") { if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
functions := make([]model.Function, len(request.Tools)) functions := make([]model.Function, len(request.Tools))
for i, tool := range request.Tools { for i, tool := range request.Tools {
functions[i] = tool.Function functions[i] = tool.Function
@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3" return "generalv3"
case "v3.5": case "v3.5":
return "generalv3.5" return "generalv3.5"
case "v4.0":
return "4.0Ultra"
} }
return "general" + apiVersion return "general" + apiVersion
} }

View File

@ -3,6 +3,13 @@ package zhipu
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
@ -11,11 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"sync"
"time"
) )
// https://open.bigmodel.cn/doc/api#chatglm_std // https://open.bigmodel.cn/doc/api#chatglm_std
@ -155,10 +157,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
metaChan := make(chan string) common.SetEventStreamHeaders(c)
stopChan := make(chan bool)
go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
lines := strings.Split(data, "\n") lines := strings.Split(data, "\n")
@ -166,55 +167,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
if len(line) < 5 { if len(line) < 5 {
continue continue
} }
if line[:5] == "data:" { if strings.HasPrefix(line, "data:") {
dataChan <- line[5:] dataSegment := line[5:]
if i != len(lines)-1 { if i != len(lines)-1 {
dataChan <- "\n" dataSegment += "\n"
} }
} else if line[:5] == "meta:" { response := streamResponseZhipu2OpenAI(dataSegment)
metaChan <- line[5:] err := render.ObjectData(c, response)
}
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) } else if strings.HasPrefix(line, "meta:") {
return true metaSegment := line[5:]
case data := <-metaChan:
var zhipuResponse StreamMetaResponse var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse) err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
} }
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
return true
} }
usage = zhipuUsage usage = zhipuUsage
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) }
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, usage return nil, usage
} }

View File

@ -2,6 +2,7 @@ package ratio
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -75,6 +76,7 @@ var ModelRatio = map[string]float64{
"claude-2.1": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD, "claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
@ -124,6 +126,7 @@ var ModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
@ -167,6 +170,9 @@ var ModelRatio = map[string]float64{
"step-1v-32k": 0.024 * RMB, "step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB, "step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB, "step-1-200k": 0.15 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
// https://cohere.com/pricing // https://cohere.com/pricing
"command": 0.5, "command": 0.5,
"command-nightly": 0.5, "command-nightly": 0.5,
@ -183,7 +189,11 @@ var ModelRatio = map[string]float64{
"deepl-ja": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD,
} }
var CompletionRatio = map[string]float64{} var CompletionRatio = map[string]float64{
// aws llama3
"llama3-8b-8192(33)": 0.0006 / 0.0003,
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var DefaultModelRatio map[string]float64 var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64 var DefaultCompletionRatio map[string]float64
@ -232,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &ModelRatio) return json.Unmarshal([]byte(jsonStr), &ModelRatio)
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet") name = strings.TrimSuffix(name, "-internet")
} }
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet") name = strings.TrimSuffix(name, "-internet")
} }
ratio, ok := ModelRatio[name] model := fmt.Sprintf("%s(%d)", name, channelType)
if !ok { if ratio, ok := ModelRatio[model]; ok {
ratio, ok = DefaultModelRatio[name] return ratio
}
if ratio, ok := DefaultModelRatio[model]; ok {
return ratio
}
if ratio, ok := ModelRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[name]; ok {
return ratio
} }
if !ok {
logger.SysError("model ratio not found: " + name) logger.SysError("model ratio not found: " + name)
return 30 return 30
}
return ratio
} }
func CompletionRatio2JSONString() string { func CompletionRatio2JSONString() string {
@ -263,7 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &CompletionRatio) return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
} }
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := CompletionRatio[model]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[model]; ok {
return ratio
}
if ratio, ok := CompletionRatio[name]; ok { if ratio, ok := CompletionRatio[name]; ok {
return ratio return ratio
} }

View File

@ -42,5 +42,6 @@ const (
DeepL DeepL
TogetherAI TogetherAI
Doubao Doubao
Novita
Dummy Dummy
) )

View File

@ -42,6 +42,7 @@ var ChannelBaseURLs = []string{
"https://api-free.deepl.com", // 38 "https://api-free.deepl.com", // 38
"https://api.together.xyz", // 39 "https://api.together.xyz", // 39
"https://ark.cn-beijing.volces.com", // 40 "https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
} }
func init() { func init() {

View File

@ -7,6 +7,10 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
@ -21,9 +25,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
) )
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
} }
modelRatio := billingratio.GetModelRatio(audioModel) modelRatio := billingratio.GetModelRatio(audioModel, channelType)
groupRatio := billingratio.GetGroupRatio(group) groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
var quota int64 var quota int64

View File

@ -4,6 +4,10 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
@ -16,9 +20,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
"strings"
) )
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
@ -40,78 +41,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil return textRequest, nil
} }
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// model validation
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
}
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode { switch relayMode {
case relaymode.ChatCompletions: case relaymode.ChatCompletions:
@ -167,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
return return
} }
var quota int64 var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model) completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
@ -22,13 +23,84 @@ import (
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func isWithinRange(element string, value int) bool { func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { imageRequest := &relaymodel.ImageRequest{}
return false err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
} }
min := billingratio.ImageGenerationAmounts[element][0] if imageRequest.N == 0 {
max := billingratio.ImageGenerationAmounts[element][1] imageRequest.N = 1
return value >= min && value <= max }
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func isValidImagePromptLength(model string, promptLength int) bool {
maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
return !ok || promptLength <= maxPromptLength
}
func isWithinRange(element string, value int) bool {
amounts, ok := billingratio.ImageGenerationAmounts[element]
return !ok || (value >= amounts[0] && value <= amounts[1])
}
func getImageSizeRatio(model string, size string) float64 {
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
return ratio
}
return 1
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// model validation
if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
} }
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@ -97,7 +169,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
} }
modelRatio := billingratio.GetModelRatio(imageModel) modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

View File

@ -4,6 +4,9 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
@ -14,8 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
// pre-consume quota // pre-consume quota

View File

@ -5,6 +5,7 @@ type Message struct {
Content any `json:"content,omitempty"` Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"` ToolCalls []Tool `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
} }
func (m Message) IsStringContent() bool { func (m Message) IsStringContent() bool {

View File

@ -2,13 +2,13 @@ package model
type Tool struct { type Tool struct {
Id string `json:"id,omitempty"` Id string `json:"id,omitempty"`
Type string `json:"type"` Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
Function Function `json:"function"` Function Function `json:"function"`
} }
type Function struct { type Function struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name"` Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response Arguments any `json:"arguments,omitempty"` // response
} }

View File

@ -47,7 +47,7 @@ const PersonalSetting = () => {
const [countdown, setCountdown] = useState(30); const [countdown, setCountdown] = useState(30);
const [affLink, setAffLink] = useState(''); const [affLink, setAffLink] = useState('');
const [systemToken, setSystemToken] = useState(''); const [systemToken, setSystemToken] = useState('');
// const [models, setModels] = useState([]); const [models, setModels] = useState([]);
const [openTransfer, setOpenTransfer] = useState(false); const [openTransfer, setOpenTransfer] = useState(false);
const [transferAmount, setTransferAmount] = useState(0); const [transferAmount, setTransferAmount] = useState(0);
@ -72,7 +72,7 @@ const PersonalSetting = () => {
console.log(userState); console.log(userState);
} }
); );
// loadModels().then(); loadModels().then();
getAffLink().then(); getAffLink().then();
setTransferAmount(getQuotaPerUnit()); setTransferAmount(getQuotaPerUnit());
}, []); }, []);
@ -127,16 +127,16 @@ const PersonalSetting = () => {
} }
}; };
// const loadModels = async () => { const loadModels = async () => {
// let res = await API.get(`/api/user/models`); let res = await API.get(`/api/user/available_models`);
// const { success, message, data } = res.data; const { success, message, data } = res.data;
// if (success) { if (success) {
// setModels(data); setModels(data);
// console.log(data); console.log(data);
// } else { } else {
// showError(message); showError(message);
// } }
// }; };
const handleAffLinkClick = async (e) => { const handleAffLinkClick = async (e) => {
e.target.select(); e.target.select();
@ -344,7 +344,7 @@ const PersonalSetting = () => {
} }
> >
<Typography.Title heading={6}>调用信息</Typography.Title> <Typography.Title heading={6}>调用信息</Typography.Title>
{/* <Typography.Title heading={6}></Typography.Title> <p>可用模型可点击复制</p>
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Space wrap> <Space wrap>
{models.map((model) => ( {models.map((model) => (
@ -355,7 +355,7 @@ const PersonalSetting = () => {
</Tag> </Tag>
))} ))}
</Space> </Space>
</div> */} </div>
</Card> </Card>
{/* <Card {/* <Card
footer={ footer={

View File

@ -63,7 +63,7 @@ const EditChannel = (props) => {
let localModels = []; let localModels = [];
switch (value) { switch (value) {
case 14: case 14:
localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
break; break;
case 11: case 11:
localModels = ['PaLM-2']; localModels = ['PaLM-2'];
@ -78,7 +78,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break; break;
case 18: case 18:
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']; localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
break; break;
case 19: case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];

View File

@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = {
}, },
33: { 33: {
key: 33, key: 33,
text: 'AWS Claude', text: 'AWS',
value: 33, value: 33,
color: 'primary' color: 'primary'
}, },
@ -161,6 +161,12 @@ export const CHANNEL_OPTIONS = {
value: 39, value: 39,
color: 'primary' color: 'primary'
}, },
41: {
key: 41,
text: 'Novita',
value: 41,
color: 'purple'
},
8: { 8: {
key: 8, key: 8,
text: '自定义渠道', text: '自定义渠道',

View File

@ -91,7 +91,7 @@ const typeConfig = {
other: '版本号' other: '版本号'
}, },
input: { input: {
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'] models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
}, },
prompt: { prompt: {
key: '按照如下格式输入APPID|APISecret|APIKey', key: '按照如下格式输入APPID|APISecret|APIKey',

View File

@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { import {
API, API,
@ -70,11 +70,31 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/?p=${startIdx}`); const res = await API.get(`/api/channel/?p=${startIdx}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
let localChannels = data.map((channel) => {
if (channel.models === '') {
channel.models = [];
channel.test_model = "";
} else {
channel.models = channel.models.split(',');
if (channel.models.length > 0) {
channel.test_model = channel.models[0];
}
channel.model_options = channel.models.map((model) => {
return {
key: model,
text: model,
value: model,
}
})
console.log('channel', channel)
}
return channel;
});
if (startIdx === 0) { if (startIdx === 0) {
setChannels(data); setChannels(localChannels);
} else { } else {
let newChannels = [...channels]; let newChannels = [...channels];
newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels);
setChannels(newChannels); setChannels(newChannels);
} }
} else { } else {
@ -225,19 +245,31 @@ const ChannelsTable = () => {
setSearching(false); setSearching(false);
}; };
const testChannel = async (id, name, idx) => { const switchTestModel = async (idx, model) => {
const res = await API.get(`/api/channel/test/${id}/`); let newChannels = [...channels];
const { success, message, time } = res.data; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].test_model = model;
setChannels(newChannels);
};
const testChannel = async (id, name, idx, m) => {
const res = await API.get(`/api/channel/test/${id}?model=${m}`);
const { success, message, time, model } = res.data;
if (success) { if (success) {
let newChannels = [...channels]; let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000; newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels); setChannels(newChannels);
showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); showInfo(`渠道 ${name} 测试成功,模型 ${model}耗时 ${time.toFixed(2)} 秒。`);
} else { } else {
showError(message); showError(message);
} }
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
}; };
const testChannels = async (scope) => { const testChannels = async (scope) => {
@ -405,6 +437,7 @@ const ChannelsTable = () => {
> >
优先级 优先级
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell>测试模型</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell> <Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row> </Table.Row>
</Table.Header> </Table.Header>
@ -459,13 +492,24 @@ const ChannelsTable = () => {
basic basic
/> />
</Table.Cell> </Table.Cell>
<Table.Cell>
<Dropdown
placeholder='请选择测试模型'
selection
options={channel.model_options}
defaultValue={channel.test_model}
onChange={(event, data) => {
switchTestModel(idx, data.value);
}}
/>
</Table.Cell>
<Table.Cell> <Table.Cell>
<div> <div>
<Button <Button
size={'small'} size={'small'}
positive positive
onClick={() => { onClick={() => {
testChannel(channel.id, channel.name, idx); testChannel(channel.id, channel.name, idx, channel.test_model);
}} }}
> >
测试 测试

View File

@ -1,11 +1,12 @@
export const CHANNEL_OPTIONS = [ export const CHANNEL_OPTIONS = [
{key: 1, text: 'OpenAI', value: 1, color: 'green'}, {key: 1, text: 'OpenAI', value: 1, color: 'green'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'black'}, {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'},
{key: 33, text: 'AWS Claude', value: 33, color: 'black'}, {key: 33, text: 'AWS', value: 33, color: 'black'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'}, {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'},
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange'}, {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'},
{key: 24, text: 'Google Gemini', value: 24, color: 'orange'}, {key: 24, text: 'Google Gemini', value: 24, color: 'orange'},
{key: 28, text: 'Mistral AI', value: 28, color: 'orange'}, {key: 28, text: 'Mistral AI', value: 28, color: 'orange'},
{key: 41, text: 'Novita', value: 41, color: 'purple'},
{key: 40, text: '字节跳动豆包', value: 40, color: 'blue'}, {key: 40, text: '字节跳动豆包', value: 40, color: 'blue'},
{key: 15, text: '百度文心千帆', value: 15, color: 'blue'}, {key: 15, text: '百度文心千帆', value: 15, color: 'blue'},
{key: 17, text: '阿里通义千问', value: 17, color: 'orange'}, {key: 17, text: '阿里通义千问', value: 17, color: 'orange'},