diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 89ba75cd..36798711 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,26 +36,12 @@ jobs: # in the next step as well as the next job. - name: Test run: go test -cover -coverprofile=coverage.txt ./... - - - name: Archive code coverage results - uses: actions/upload-artifact@v4 + - uses: codecov/codecov-action@v4 with: - name: code-coverage - path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step - - code_coverage: - name: "Code coverage report" - if: github.event_name == 'pull_request' # Do not run when workflow is triggered by push to main branch - runs-on: ubuntu-latest - needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job - steps: - - uses: fgrosse/go-coverage-report@v1.0.2 # Consider using a Git revision for maximum security - with: - coverage-artifact-name: "code-coverage" # can be omitted if you used this default value - coverage-file-name: "coverage.txt" # can be omitted if you used this default value + token: ${{ secrets.CODECOV_TOKEN }} commit_lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: wagoid/commitlint-github-action@v6 \ No newline at end of file + - uses: wagoid/commitlint-github-action@v6 diff --git a/Dockerfile b/Dockerfile index 6743b139..29b4ca71 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,9 @@ WORKDIR /web/air RUN npm install RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build -FROM golang AS builder2 +FROM golang:alpine AS builder2 + +RUN apk add --no-cache g++ ENV GO111MODULE=on \ CGO_ENABLED=1 \ @@ -27,7 +29,7 @@ ADD go.mod go.sum ./ RUN go mod download COPY . . COPY --from=builder /web/build ./web/build -RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api +RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api FROM alpine diff --git a/README.en.md b/README.en.md index bce47353..c9fdbbc8 100644 --- a/README.en.md +++ b/README.en.md @@ -101,7 +101,7 @@ Nginx reference configuration: ``` server{ server_name openai.justsong.cn; # Modify your domain name accordingly - + location / { client_max_body_size 64m; proxy_http_version 1.1; @@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`. 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: ```shell git clone https://github.com/songquanpeng/one-api.git - + # Build the frontend cd one-api/web/default npm install npm run build - + # Build the backend cd ../.. go mod download @@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` -6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. +6. `MEMORY_CACHE_ENABLED`: Enables in-memory caching, which may slightly delay user quota updates. Valid values are `true` and `false`. If not set, it defaults to `false`.
+7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60`
-7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave`
-8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
-9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
-10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5`
+12. `BATCH_UPDATE_ENABLED`: Enables batched aggregation of database updates, which may slightly delay user quota updates. Valid values are `true` and `false`; if not set, it defaults to `false`. The sketch after this list illustrates the idea.
+ Example: `BATCH_UPDATE_ENABLED=true`
+ If you encounter an issue with too many database connections, you can try enabling this option.
+13. `BATCH_UPDATE_INTERVAL`: The interval between batched update flushes, in seconds. Defaults to `5`.
+ Example: `BATCH_UPDATE_INTERVAL=5`
+14. Request rate limits:
+ + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests per IP within three minutes. Defaults to `180`.
+ + `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests per IP within three minutes. Defaults to `60`.
+15. Encoder cache settings:
+ + `TIKTOKEN_CACHE_DIR`: By default, the program downloads the token encodings of some common models (such as `gpt-3.5-turbo`) from the network at startup. In unstable network environments, or offline, this can break startup; configure this directory to cache the data, and it can be migrated to an offline environment.
+ + `DATA_GYM_CACHE_DIR`: Currently serves the same purpose as `TIKTOKEN_CACHE_DIR`, but with lower priority.
+16. `RELAY_TIMEOUT`: Relay timeout, in seconds. No timeout is set by default.
+17. `RELAY_PROXY`: When set, this proxy is used for API requests.
+18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
+19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to fetch user-uploaded content, such as images.
+20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds. Defaults to `3000`.
+21. `GEMINI_SAFETY_SETTING`: Gemini safety setting. Defaults to `BLOCK_NONE`.
+22. `GEMINI_VERSION`: The Gemini API version used by One API. Defaults to `v1`.
+23. `THEME`: The system's theme. Defaults to `default`; see [here](./web/README.md) for the available values.
+24. `ENABLE_METRIC`: Whether to disable channels based on request success rate. Valid values are `true` and `false`; not enabled by default.
+25. `METRIC_QUEUE_SIZE`: Size of the request success rate statistics queue. Defaults to `10`.
+26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold. Defaults to `0.8`.
+27. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable will be created automatically when the system starts for the first time.
+28. `INITIAL_ROOT_ACCESS_TOKEN`: If set, a system management token with the value of this environment variable will be created automatically for the root user when the system starts for the first time.
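The batching behind `BATCH_UPDATE_ENABLED` and `BATCH_UPDATE_INTERVAL` trades a small delay in quota visibility for far fewer database round trips. Below is a minimal sketch of that aggregation idea; the names (`quotaBatcher`, `newQuotaBatcher`) are illustrative and are not one-api's actual implementation:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// quotaBatcher accumulates per-user quota deltas in memory and flushes
// them in one pass per interval, so each request touches a map instead
// of opening a database connection.
type quotaBatcher struct {
	mu     sync.Mutex
	deltas map[int]int64 // userID -> pending quota delta
}

func newQuotaBatcher(interval time.Duration, flush func(map[int]int64)) *quotaBatcher {
	b := &quotaBatcher{deltas: make(map[int]int64)}
	go func() {
		// time.Tick is fine here because the batcher lives for the
		// whole process lifetime.
		for range time.Tick(interval) {
			b.mu.Lock()
			pending := b.deltas
			b.deltas = make(map[int]int64)
			b.mu.Unlock()
			if len(pending) > 0 {
				flush(pending) // e.g. one UPDATE per user inside a single transaction
			}
		}
	}()
	return b
}

func (b *quotaBatcher) Add(userID int, delta int64) {
	b.mu.Lock()
	b.deltas[userID] += delta
	b.mu.Unlock()
}

func main() {
	// BATCH_UPDATE_INTERVAL=5 corresponds to a 5-second flush window.
	b := newQuotaBatcher(5*time.Second, func(p map[int]int64) {
		fmt.Println("flushing", len(p), "aggregated quota updates")
	})
	b.Add(1, -42) // consume quota; persisted on the next flush
	time.Sleep(6 * time.Second)
}
```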
### Command Line Parameters
1. `--port <port>`: Specifies the port number on which the server listens. Defaults to `3000`.
@@ -287,7 +312,9 @@ If the channel ID is not provided, load balancing will be used to distribute the + Double-check that your interface address and API Key are correct. ## Related Projects -[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM +* [FastGPT](https://github.com/labring/FastGPT): Knowledge base question answering system based on LLMs +* [VChart](https://github.com/VisActor/VChart): Not just an out-of-the-box cross-platform charting library, but also an expressive data storyteller. +* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also smart: an open-source solution for intelligent visualization. ## Note This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
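The usage sections of both READMEs amount to pointing any OpenAI-style client at a one-api deployment's `/v1` endpoint and authenticating with a one-api token. A minimal, standard-library-only Go sketch follows; the base URL, token, and model name are placeholders for your own deployment:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Placeholders: substitute your deployment's address, a one-api
	// token, and a model one of your channels actually serves.
	base := "http://localhost:3000/v1"
	token := "sk-xxxxxx"

	body, _ := json.Marshal(map[string]any{
		"model":    "gpt-3.5-turbo",
		"messages": []map[string]string{{"role": "user", "content": "hi"}},
	})

	req, err := http.NewRequest(http.MethodPost, base+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	// one-api tokens are passed exactly like an OpenAI API key.
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out map[string]any
	_ = json.NewDecoder(resp.Body).Decode(&out)
	fmt.Println(resp.Status, out["model"])
}
```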
diff --git a/README.md b/README.md index 8f59a14a..987fde7d 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 > [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 -> +> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > [!WARNING] @@ -88,6 +88,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) + [x] [DeepL](https://www.deepl.com/) + [x] [together.ai](https://www.together.ai/) + + [x] [novita.ai](https://www.novita.ai/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 @@ -144,7 +145,7 @@ Nginx 的参考配置: ``` server{ server_name openai.justsong.cn; # 请根据实际情况修改你的域名 - + location / { client_max_body_size 64m; proxy_http_version 1.1; @@ -189,12 +190,12 @@ docker-compose ps 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: ```shell git clone https://github.com/songquanpeng/one-api.git - + # 构建前端 cd one-api/web/default npm install npm run build - + # 构建后端 cd ../.. go mod download @@ -321,7 +322,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo 例如对于 OpenAI 的官方库: ```bash OPENAI_API_KEY="sk-xxxxxx" -OPENAI_API_BASE="https://<HOST>:<PORT>/v1" +OPENAI_API_BASE="https://<HOST>:<PORT>/v1" ``` ```mermaid @@ -369,33 +370,34 @@ graph LR + 例子:`NODE_TYPE=slave` 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
-10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
-11. 例子:`CHANNEL_TEST_FREQUENCY=1440`
-12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
+11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
-13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
-14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+13. `BATCH_UPDATE_INTERVAL`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
-15. 请求频率限制:
+14. 请求频率限制:
 + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
-16. 编码器缓存设置:
+15. 编码器缓存设置:
 + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
-17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
-18. `RELAY_PROXY`:设置后使用该代理来请求 API。
-19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
-20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
-21. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
-22. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
-23. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。
-24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
-25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
-26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
-27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
-28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
+16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
+17. `RELAY_PROXY`:设置后使用该代理来请求 API。
+18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
+19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
+20. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
+21. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
+22. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。
+23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
+24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
+25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
+26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
+27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
+28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动为 root 用户创建一个值为该环境变量值的系统管理令牌。
### 命令行参数
1. 
`--port `: 指定服务器监听的端口号,默认为 `3000`。 @@ -448,6 +450,8 @@ https://openai.justsong.cn ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 +* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 +* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。 ## 注意 diff --git a/common/config/config.go b/common/config/config.go index 4f1c25b6..11da0b96 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -143,8 +143,12 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") +var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN") + var GeminiVersion = env.String("GEMINI_VERSION", "v1") +var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) + var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index d497e2bd..e655680f 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( TokenName = "token_name" BaseURL = "base_url" AvailableModels = "available_models" + KeyRequestBody = "key_request_body" ) diff --git a/common/gin.go b/common/gin.go index 92e42474..4b68eb06 100644 --- a/common/gin.go +++ b/common/gin.go @@ -6,12 +6,11 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" ) -const KeyRequestBody = "key_request_body" - func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(KeyRequestBody) + requestBody, _ := c.Get(ctxkey.KeyRequestBody) if requestBody != nil { return requestBody.([]byte), nil } @@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) { return nil, err } _ = c.Request.Body.Close() - c.Set(KeyRequestBody, requestBody) + c.Set(ctxkey.KeyRequestBody, requestBody) return requestBody.([]byte), nil } diff --git a/common/logger/logger.go b/common/logger/logger.go index f725c619..d1022932 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -27,7 +27,12 @@ var setupLogOnce sync.Once func SetupLogger() { setupLogOnce.Do(func() { if LogDir != "" { - logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + var logPath string + if config.OnlyOneLogFile { + logPath = filepath.Join(LogDir, "oneapi.log") + } else { + logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + } fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") diff --git a/common/message/email.go b/common/message/email.go index b06782db..187ac8c3 100644 --- a/common/message/email.go +++ b/common/message/email.go @@ -6,11 +6,16 @@ import ( "encoding/base64" "fmt" "github.com/songquanpeng/one-api/common/config" + "net" "net/smtp" "strings" "time" ) +func shouldAuth() bool { + return config.SMTPAccount != "" || config.SMTPToken != "" +} + func SendEmail(subject string, receiver string, content string) error { if receiver == "" { return fmt.Errorf("receiver is empty") @@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error { "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, 
time.Now().Format(time.RFC1123Z), content)) + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) to := strings.Split(receiver, ";") - if config.SMTPPort == 465 { - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - ServerName: config.SMTPServer, + if config.SMTPPort == 465 || !shouldAuth() { + // need advanced client + var conn net.Conn + var err error + if config.SMTPPort == 465 { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: config.SMTPServer, + } + conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) + } else { + conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) if err != nil { return err } @@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error { return err } defer client.Close() - if err = client.Auth(auth); err != nil { - return err + if shouldAuth() { + if err = client.Auth(auth); err != nil { + return err + } } if err = client.Mail(config.SMTPFrom); err != nil { return err diff --git a/common/render/render.go b/common/render/render.go new file mode 100644 index 00000000..646b3777 --- /dev/null +++ b/common/render/render.go @@ -0,0 +1,29 @@ +package render + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "strings" +) + +func StringData(c *gin.Context, str string) { + str = strings.TrimPrefix(str, "data: ") + str = strings.TrimSuffix(str, "\r") + c.Render(-1, common.CustomEvent{Data: "data: " + str}) + c.Writer.Flush() +} + +func ObjectData(c *gin.Context, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + StringData(c, string(jsonData)) + return nil +} + +func Done(c *gin.Context) { + StringData(c, "[DONE]") +} diff --git a/controller/channel-test.go b/controller/channel-test.go index b8c41819..f8327284 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -27,15 +28,15 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - - "github.com/gin-gonic/gin" ) -func buildTestRequest() *relaymodel.GeneralOpenAIRequest { +func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { + if model == "" { + model = "gpt-3.5-turbo" + } testRequest := &relaymodel.GeneralOpenAIRequest{ MaxTokens: 2, - Stream: false, - Model: "gpt-3.5-turbo", + Model: model, } testMessage := relaymodel.Message{ Role: "user", @@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest { return testRequest } -func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { +func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ @@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error return fmt.Errorf("invalid api type: %d, adaptor is nil", 
apiType), nil } adaptor.Init(meta) - var modelName string - modelList := adaptor.GetModelList() + modelName := request.Model modelMap := channel.GetModelMapping() - if len(modelList) != 0 { - modelName = modelList[0] - } if modelName == "" || !strings.Contains(channel.Models, modelName) { modelNames := strings.Split(channel.Models, ",") if len(modelNames) > 0 { @@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error modelName = modelMap[modelName] } } - request := buildTestRequest() + meta.OriginModelName, meta.ActualModelName = request.Model, modelName request.Model = modelName - meta.OriginModelName, meta.ActualModelName = modelName, modelName convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil @@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) { }) return } + model := c.Query("model") + testRequest := buildTestRequest(model) tik := time.Now() - err, _ = testChannel(channel) + err, _ = testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() + if err != nil { + milliseconds = 0 + } go channel.UpdateResponseTime(milliseconds) consumedTime := float64(milliseconds) / 1000.0 if err != nil { @@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) { "success": false, "message": err.Error(), "time": consumedTime, + "model": model, }) return } @@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) { "success": true, "message": "", "time": consumedTime, + "model": model, }) return } @@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error { for _, channel := range channels { isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel) + testRequest := buildTestRequest("") + err, openaiErr := testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) if config.AutomaticDisableChannelEnabled { monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } else { diff --git a/controller/relay.go b/controller/relay.go index 016c7517..9a2a7f3a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -45,7 +45,7 @@ func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) channelId := c.GetInt(ctxkey.ChannelId) - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) bizErr := relayHelper(c, relayMode) if bizErr == nil { monitor.Emit(channelId, true) @@ -55,8 +55,8 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - - // BUG: bizErr is shared, should not run this function in goroutine to avoid race + + // BUG: bizErr is shared, should not run this function in goroutine to avoid race go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes @@ -85,7 +85,7 @@ func Relay(c *gin.Context) { lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - // BUG: bizErr is shared, should not run this function in goroutine to avoid race + // BUG: bizErr is shared, should not run this function 
in goroutine to avoid race go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) } diff --git a/go.mod b/go.mod index 7a396314..9d9ce35a 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.23.0 - golang.org/x/image v0.16.0 + golang.org/x/image v0.18.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 @@ -68,7 +68,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect @@ -80,7 +80,7 @@ require ( golang.org/x/net v0.25.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/text v0.16.0 // indirect google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4c1aac95..11810f25 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= -github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -154,8 +154,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= -golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= +golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= +golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= @@ -164,8 +164,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys 
v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= diff --git a/main.go b/main.go index 4afbe5dd..67a3cd95 100644 --- a/main.go +++ b/main.go @@ -27,27 +27,19 @@ func main() { common.Init() logger.SetupLogger() logger.SysLogf("One API %s started", common.Version) - if os.Getenv("GIN_MODE") != "debug" { + + if os.Getenv("GIN_MODE") != gin.DebugMode { gin.SetMode(gin.ReleaseMode) } if config.DebugEnabled { logger.SysLog("running in debug mode") } - var err error + // Initialize SQL Database - model.DB, err = model.InitDB("SQL_DSN") - if err != nil { - logger.FatalLog("failed to initialize database: " + err.Error()) - } - if os.Getenv("LOG_SQL_DSN") != "" { - logger.SysLog("using secondary database for table logs") - model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") - if err != nil { - logger.FatalLog("failed to initialize secondary database: " + err.Error()) - } - } else { - model.LOG_DB = model.DB - } + model.InitDB() + model.InitLogDB() + + var err error err = model.CreateRootAccountIfNeed() if err != nil { logger.FatalLog("database init error: " + err.Error()) diff --git a/model/main.go b/model/main.go index 4b5323c4..72e271a0 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "database/sql" "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error { if err != nil { return err } + accessToken := random.GetUUID() + if config.InitialRootAccessToken != "" { + accessToken = config.InitialRootAccessToken + } rootUser := User{ Username: "root", Password: hashedPassword, Role: RoleRootUser, Status: UserStatusEnabled, DisplayName: "Root User", - AccessToken: random.GetUUID(), + AccessToken: accessToken, Quota: 500000000000000, } DB.Create(&rootUser) @@ -60,90 +65,156 @@ func CreateRootAccountIfNeed() error { } func chooseDB(envName string) (*gorm.DB, error) { - if os.Getenv(envName) != "" { - dsn := os.Getenv(envName) - if strings.HasPrefix(dsn, "postgres://") { - // Use PostgreSQL - logger.SysLog("using PostgreSQL as database") - common.UsingPostgreSQL = true - return gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - PreferSimpleProtocol: true, // disables implicit prepared statement usage - }), &gorm.Config{ - PrepareStmt: true, // precompile SQL - }) - } + dsn := os.Getenv(envName) + + switch { + case strings.HasPrefix(dsn, "postgres://"): + // Use PostgreSQL + return openPostgreSQL(dsn) + case dsn != "": // Use MySQL - logger.SysLog("using MySQL as database") - common.UsingMySQL = true - return gorm.Open(mysql.Open(dsn), &gorm.Config{ - PrepareStmt: true, // precompile SQL - }) + return openMySQL(dsn) + default: + // Use SQLite + return openSQLite() } - // Use SQLite - logger.SysLog("SQL_DSN not set, using SQLite as database") - common.UsingSQLite = true - config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) - return 
gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ +} + +func openPostgreSQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using PostgreSQL as database") + common.UsingPostgreSQL = true + return gorm.Open(postgres.New(postgres.Config{ + DSN: dsn, + PreferSimpleProtocol: true, // disables implicit prepared statement usage + }), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } -func InitDB(envName string) (db *gorm.DB, err error) { - db, err = chooseDB(envName) - if err == nil { - if config.DebugSQLEnabled { - db = db.Debug() - } - sqlDB, err := db.DB() - if err != nil { - return nil, err - } - sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) +func openMySQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using MySQL as database") + common.UsingMySQL = true + return gorm.Open(mysql.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} - if !config.IsMasterNode { - return db, err - } - if common.UsingMySQL { - _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded - } - logger.SysLog("database migration started") - err = db.AutoMigrate(&Channel{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Token{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&User{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Option{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Redemption{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Ability{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Log{}) - if err != nil { - return nil, err - } - logger.SysLog("database migrated") - return db, err - } else { - logger.FatalLog(err) +func openSQLite() (*gorm.DB, error) { + logger.SysLog("SQL_DSN not set, using SQLite as database") + common.UsingSQLite = true + dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func InitDB() { + var err error + DB, err = chooseDB("SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize database: " + err.Error()) + return } - return db, err + + sqlDB := setDBConns(DB) + + if !config.IsMasterNode { + return + } + + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + } + + logger.SysLog("database migration started") + if err = migrateDB(); err != nil { + logger.FatalLog("failed to migrate database: " + err.Error()) + return + } + logger.SysLog("database migrated") +} + +func migrateDB() error { + var err error + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Token{}); err != nil { + return err + } + if err = DB.AutoMigrate(&User{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Option{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Redemption{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Ability{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Log{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + return nil +} + +func InitLogDB() { + if os.Getenv("LOG_SQL_DSN") == "" { + LOG_DB = DB + return + } + + logger.SysLog("using secondary database for 
table logs") + var err error + LOG_DB, err = chooseDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + return + } + + setDBConns(LOG_DB) + + if !config.IsMasterNode { + return + } + + logger.SysLog("secondary database migration started") + err = migrateLOGDB() + if err != nil { + logger.FatalLog("failed to migrate secondary database: " + err.Error()) + return + } + logger.SysLog("secondary database migrated") +} + +func migrateLOGDB() error { + var err error + if err = LOG_DB.AutoMigrate(&Log{}); err != nil { + return err + } + return nil +} + +func setDBConns(db *gorm.DB) *sql.DB { + if config.DebugSQLEnabled { + db = db.Debug() + } + + sqlDB, err := db.DB() + if err != nil { + logger.FatalLog("failed to connect database: " + err.Error()) + return nil + } + + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) + return sqlDB } func closeDB(db *gorm.DB) error { diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go index 01a568f6..d64b6809 100644 --- a/relay/adaptor/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -4,6 +4,12 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strconv" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -12,10 +18,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" ) // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 @@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var usage model.Usage + var documents []LibraryDocument scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - var documents []LibraryDocument - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var AIProxyLibraryResponse LibraryStreamResponse - err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if len(AIProxyLibraryResponse.Documents) != 0 { - documents = AIProxyLibraryResponse.Documents - } - response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - response := documentsAIProxyLibrary(documents) - jsonResponse, err := 
json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || data[:5] != "data:" { + continue } - }) - err := resp.Body.Close() + data = data[5:] + + var AIProxyLibraryResponse LibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + response := documentsAIProxyLibrary(documents) + err := render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + render.Done(c) + + err = resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, &usage } diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index 0462c26b..f9039dbe 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -3,15 +3,17 @@ package ali import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r @@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - //lastResponseText := "" - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var aliResponse ChatResponse - err := json.Unmarshal([]byte(data), &aliResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if aliResponse.Usage.OutputTokens != 0 { - usage.PromptTokens = aliResponse.Usage.InputTokens - usage.CompletionTokens = aliResponse.Usage.OutputTokens - usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens - } - response := streamResponseAli2OpenAI(&aliResponse) - if response == nil { - return true - } - //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) - //lastResponseText = aliResponse.Output.Text - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + 
string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || data[:5] != "data:" { + continue } - }) + data = data[5:] + + var aliResponse ChatResponse + err := json.Unmarshal([]byte(data), &aliResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if aliResponse.Usage.OutputTokens != 0 { + usage.PromptTokens = aliResponse.Usage.InputTokens + usage.CompletionTokens = aliResponse.Usage.OutputTokens + usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens + } + response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + continue + } + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go index cadcedc8..143d1efc 100644 --- a/relay/adaptor/anthropic/constants.go +++ b/relay/adaptor/anthropic/constants.go @@ -5,4 +5,5 @@ var ModelList = []string{ "claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229", + "claude-3-5-sonnet-20240620", } diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index a8de185c..d3e306c8 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -28,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string { return "stop" case "max_tokens": return "length" + case "tool_use": + return "tool_calls" default: return *reason } } func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + claudeTools := make([]Tool, 0, len(textRequest.Tools)) + + for _, tool := range textRequest.Tools { + if params, ok := tool.Function.Parameters.(map[string]any); ok { + claudeTools = append(claudeTools, Tool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: InputSchema{ + Type: params["type"].(string), + Properties: params["properties"], + Required: params["required"], + }, + }) + } + } + claudeRequest := Request{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -41,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { TopP: textRequest.TopP, TopK: textRequest.TopK, Stream: textRequest.Stream, + Tools: claudeTools, + } + if len(claudeTools) > 0 { + claudeToolChoice := struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + }{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output + if choice, ok := textRequest.ToolChoice.(map[string]any); ok { + if function, ok := choice["function"].(map[string]any); ok { + claudeToolChoice.Type = "tool" + claudeToolChoice.Name = function["name"].(string) + } + } else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok { + if toolChoiceType == "any" { + claudeToolChoice.Type = toolChoiceType + } + } + claudeRequest.ToolChoice = claudeToolChoice } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 @@ -63,7 +100,24 @@ func 
ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { if message.IsStringContent() { content.Type = "text" content.Text = message.StringContent() + if message.Role == "tool" { + claudeMessage.Role = "user" + content.Type = "tool_result" + content.Content = content.Text + content.Text = "" + content.ToolUseId = message.ToolCallId + } claudeMessage.Content = append(claudeMessage.Content, content) + for i := range message.ToolCalls { + inputParam := make(map[string]any) + _ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam) + claudeMessage.Content = append(claudeMessage.Content, Content{ + Type: "tool_use", + Id: message.ToolCalls[i].Id, + Name: message.ToolCalls[i].Function.Name, + Input: inputParam, + }) + } claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) continue } @@ -96,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo var response *Response var responseText string var stopReason string + tools := make([]model.Tool, 0) + switch claudeResponse.Type { case "message_start": return nil, claudeResponse.Message case "content_block_start": if claudeResponse.ContentBlock != nil { responseText = claudeResponse.ContentBlock.Text + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, model.Tool{ + Id: claudeResponse.ContentBlock.Id, + Type: "function", + Function: model.Function{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } } case "content_block_delta": if claudeResponse.Delta != nil { responseText = claudeResponse.Delta.Text + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, model.Tool{ + Function: model.Function{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } } case "message_delta": if claudeResponse.Usage != nil { @@ -119,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = responseText + if len(tools) > 0 { + choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... 
+ choice.Delta.ToolCalls = tools + } choice.Delta.Role = "assistant" finishReason := stopReasonClaude2OpenAI(&stopReason) if finishReason != "null" { @@ -135,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { if len(claudeResponse.Content) > 0 { responseText = claudeResponse.Content[0].Text } + tools := make([]model.Tool, 0) + for _, v := range claudeResponse.Content { + if v.Type == "tool_use" { + args, _ := json.Marshal(v.Input) + tools = append(tools, model.Tool{ + Id: v.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: model.Function{ + Name: v.Name, + Arguments: string(args), + }, + }) + } + } choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: responseText, - Name: nil, + Role: "assistant", + Content: responseText, + Name: nil, + ToolCalls: tools, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } @@ -169,64 +261,77 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { - continue - } - if !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) + var usage model.Usage var modelName string var id string - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSpace(data) - var claudeResponse StreamResponse - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, meta := StreamResponseClaude2OpenAI(&claudeResponse) - if meta != nil { - usage.PromptTokens += meta.Usage.InputTokens - usage.CompletionTokens += meta.Usage.OutputTokens + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 || !strings.HasPrefix(data, "data:") { + continue + } + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSpace(data) + + var claudeResponse StreamResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, meta := StreamResponseClaude2OpenAI(&claudeResponse) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. modelName = meta.Model id = fmt.Sprintf("chatcmpl-%s", meta.Id) - return true + continue + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. 
+ lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } } - if response == nil { - return true - } - response.Id = id - response.Model = modelName - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false } - }) - _ = resp.Body.Close() + if response == nil { + continue + } + + response.Id = id + response.Model = modelName + response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } return nil, &usage } diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 32b187cd..47f76629 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -16,6 +16,12 @@ type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` Source *ImageSource `json:"source,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` } type Message struct { @@ -23,6 +29,18 @@ type Message struct { Content []Content `json:"content"` } +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + type Request struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -33,6 +51,8 @@ type Request struct { Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` //Metadata `json:"metadata,omitempty"` } @@ -61,6 +81,7 @@ type Response struct { type Delta struct { Type string `json:"type"` Text string `json:"text"` + PartialJson string `json:"partial_json,omitempty"` StopReason *string `json:"stop_reason"` StopSequence *string `json:"stop_sequence"` } diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adaptor.go similarity index 71% rename from relay/adaptor/aws/adapter.go rename to relay/adaptor/aws/adaptor.go index 7245d3d9..62221346 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adaptor.go @@ -1,17 +1,16 @@ package aws import ( - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" - "github.com/songquanpeng/one-api/common/ctxkey" + "errors" "io" "net/http" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + 
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/gin-gonic/gin" - "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" - "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" ) @@ -19,18 +18,52 @@ import ( var _ adaptor.Adaptor = new(Adaptor) type Adaptor struct { - meta *meta.Meta - awsClient *bedrockruntime.Client + awsAdapter utils.AwsAdapter + + Meta *meta.Meta + AwsClient *bedrockruntime.Client } func (a *Adaptor) Init(meta *meta.Meta) { - a.meta = meta - a.awsClient = bedrockruntime.New(bedrockruntime.Options{ + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ Region: meta.Config.Region, Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), }) } +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + adaptor := GetAdaptor(request.Model) + if adaptor == nil { + return nil, errors.New("adaptor not found") + } + + a.awsAdapter = adaptor + return adaptor.ConvertRequest(c, relayMode, request) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if a.awsAdapter == nil { + return nil, utils.WrapErr(errors.New("awsAdapter is nil")) + } + return a.awsAdapter.DoResponse(c, a.AwsClient, meta) +} + +func (a *Adaptor) GetModelList() (models []string) { + for model := range adaptors { + models = append(models, model) + } + return +} + +func (a *Adaptor) GetChannelName() string { + return "aws" +} + func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return "", nil } @@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - - claudeReq := anthropic.ConvertRequest(*request) - c.Set(ctxkey.RequestModel, request.Model) - c.Set(ctxkey.ConvertedRequest, claudeReq) - return claudeReq, nil -} - func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") @@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return nil, nil } - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, a.awsClient) - } else { - err, usage = Handler(c, a.awsClient, meta.ActualModelName) - } - return -} - -func (a *Adaptor) GetModelList() (models []string) { - for n := range awsModelIDMap { - models = append(models, n) - } - return -} - -func (a *Adaptor) GetChannelName() string { - return "aws" -} diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go new file mode 100644 index 00000000..eb3c9fb8 --- /dev/null +++ b/relay/adaptor/aws/claude/adapter.go @@ -0,0 +1,37 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + 
"github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ utils.AwsAdapter = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + claudeReq := anthropic.ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, claudeReq) + return claudeReq, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, awsCli) + } else { + err, usage = Handler(c, awsCli, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/claude/main.go similarity index 66% rename from relay/adaptor/aws/main.go rename to relay/adaptor/aws/claude/main.go index 0776f985..7142e46f 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -5,7 +5,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/ctxkey" "io" "net/http" @@ -16,33 +15,28 @@ import ( "github.com/jinzhu/copier" "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func wrapErr(err error) *relaymodel.ErrorWithStatusCode { - return &relaymodel.ErrorWithStatusCode{ - StatusCode: http.StatusInternalServerError, - Error: relaymodel.Error{ - Message: fmt.Sprintf("%s", err.Error()), - }, - } -} - // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -var awsModelIDMap = map[string]string{ - "claude-instant-1.2": "anthropic.claude-instant-v1", - "claude-2.0": "anthropic.claude-v2", - "claude-2.1": "anthropic.claude-v2:1", - "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", - "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", +var AwsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", } func awsModelID(requestModel string) (string, error) { - if awsModelID, ok := awsModelIDMap[requestModel]; ok { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { return awsModelID, nil } @@ -52,7 +46,7 @@ func awsModelID(requestModel string) (string, error) { func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { - 
return wrapErr(errors.Wrap(err, "awsModelID")), nil + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } awsReq := &bedrockruntime.InvokeModelInput{ @@ -63,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) if !ok { - return wrapErr(errors.New("request not found")), nil + return utils.WrapErr(errors.New("request not found")), nil } claudeReq := claudeReq_.(*anthropic.Request) awsClaudeReq := &Request{ AnthropicVersion: "bedrock-2023-05-31", } if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { - return wrapErr(errors.Wrap(err, "copy request")), nil + return utils.WrapErr(errors.Wrap(err, "copy request")), nil } awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { - return wrapErr(errors.Wrap(err, "marshal request")), nil + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil } awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) if err != nil { - return wrapErr(errors.Wrap(err, "InvokeModel")), nil + return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil } claudeResponse := new(anthropic.Response) err = json.Unmarshal(awsResp.Body, claudeResponse) if err != nil { - return wrapErr(errors.Wrap(err, "unmarshal response")), nil + return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil } openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) @@ -106,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E createdTime := helper.GetTimestamp() awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { - return wrapErr(errors.Wrap(err, "awsModelID")), nil + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ @@ -117,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) if !ok { - return wrapErr(errors.New("request not found")), nil + return utils.WrapErr(errors.New("request not found")), nil } claudeReq := claudeReq_.(*anthropic.Request) @@ -125,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E AnthropicVersion: "bedrock-2023-05-31", } if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { - return wrapErr(errors.Wrap(err, "copy request")), nil + return utils.WrapErr(errors.Wrap(err, "copy request")), nil } awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { - return wrapErr(errors.Wrap(err, "marshal request")), nil + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil } awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) if err != nil { - return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil } stream := awsResp.GetStream() defer stream.Close() @@ -142,6 +136,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E c.Writer.Header().Set("Content-Type", "text/event-stream") var usage relaymodel.Usage var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { @@ -162,8 +158,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - 
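// Illustrative sketch (not from this patch): the finish_reason branch that
// follows mirrors an OpenAI streaming convention: a completed tool call with
// no parameters still carries "arguments": "{}" rather than an empty string.
// A defensive variant of the same normalization (the comma-ok assertion is an
// addition; Arguments is declared as any):
//
//	if args, ok := lastArgs.Arguments.(string); ok && args == "" {
//		lastArgs.Arguments = "{}" // keep OpenAI clients' json.Unmarshal happy
//	}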
return true
+ if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
+ id = fmt.Sprintf("chatcmpl-%s", meta.Id)
+ return true
+ } else { // finish_reason case
+ if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
+ lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
+ if len(lastArgs.Arguments.(string)) == 0 { // match OpenAI, which sends an empty object `{}` when a tool call has no arguments.
+ lastArgs.Arguments = "{}"
+ response.Choices[len(response.Choices)-1].Delta.Content = nil
+ response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
+ }
+ }
+ }
 }
 if response == nil {
 return true
@@ -171,6 +178,12 @@
 response.Id = id
 response.Model = c.GetString(ctxkey.OriginalModel)
 response.Created = createdTime
+
+ for _, choice := range response.Choices {
+ if len(choice.Delta.ToolCalls) > 0 {
+ lastToolCallChoice = choice
+ }
+ }
 jsonStr, err := json.Marshal(response)
 if err != nil {
 logger.SysError("error marshalling stream response: " + err.Error())
diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/claude/model.go
similarity index 79%
rename from relay/adaptor/aws/model.go
rename to relay/adaptor/aws/claude/model.go
index bcbfb584..6d00b688 100644
--- a/relay/adaptor/aws/model.go
+++ b/relay/adaptor/aws/claude/model.go
@@ -9,9 +9,12 @@ type Request struct {
 // AnthropicVersion should be "bedrock-2023-05-31"
 AnthropicVersion string `json:"anthropic_version"`
 Messages []anthropic.Message `json:"messages"`
+ System string `json:"system,omitempty"`
 MaxTokens int `json:"max_tokens,omitempty"`
 Temperature float64 `json:"temperature,omitempty"`
 TopP float64 `json:"top_p,omitempty"`
 TopK int `json:"top_k,omitempty"`
 StopSequences []string `json:"stop_sequences,omitempty"`
+ Tools []anthropic.Tool `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
 }
diff --git a/relay/adaptor/aws/llama3/adapter.go b/relay/adaptor/aws/llama3/adapter.go
new file mode 100644
index 00000000..83edbc9d
--- /dev/null
+++ b/relay/adaptor/aws/llama3/adapter.go
@@ -0,0 +1,37 @@
+package aws
+
+import (
+ "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+ "github.com/songquanpeng/one-api/common/ctxkey"
+
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+ "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
+ "github.com/songquanpeng/one-api/relay/meta"
+ "github.com/songquanpeng/one-api/relay/model"
+)
+
+var _ utils.AwsAdapter = new(Adaptor)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+
+ llamaReq := ConvertRequest(*request)
+ c.Set(ctxkey.RequestModel, request.Model)
+ c.Set(ctxkey.ConvertedRequest, llamaReq)
+ return llamaReq, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, awsCli)
+ } else {
+ err, usage = Handler(c, awsCli, meta.ActualModelName)
+ }
+ return
+}
diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go
new file mode 100644
index 00000000..e5fcd89f
--- /dev/null
+++ b/relay/adaptor/aws/llama3/main.go
@@ -0,0 +1,231 @@
+// Package aws provides the AWS adaptor for the relay service.
+package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "text/template" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/random" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// Only support llama-3-8b and llama-3-70b instruction models +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var AwsModelIDMap = map[string]string{ + "llama3-8b-8192": "meta.llama3-8b-instruct-v1:0", + "llama3-70b-8192": "meta.llama3-70b-instruct-v1:0", +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +// promptTemplate with range +const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|> +` + +var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate)) + +func RenderPrompt(messages []relaymodel.Message) string { + var buf bytes.Buffer + err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages}) + if err != nil { + logger.SysError("error rendering prompt messages: " + err.Error()) + } + return buf.String() +} + +func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { + llamaRequest := Request{ + MaxGenLen: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + } + if llamaRequest.MaxGenLen == 0 { + llamaRequest.MaxGenLen = 2048 + } + prompt := RenderPrompt(textRequest.Messages) + llamaRequest.Prompt = prompt + return &llamaRequest +} + +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + llamaReq, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + + awsReq.Body, err = json.Marshal(llamaReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + var llamaResponse Response + err = json.Unmarshal(awsResp.Body, &llamaResponse) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := ResponseLlama2OpenAI(&llamaResponse) + openaiResp.Model = modelName + usage := relaymodel.Usage{ + PromptTokens: llamaResponse.PromptTokenCount, + CompletionTokens: llamaResponse.GenerationTokenCount, + TotalTokens: llamaResponse.PromptTokenCount + 
llamaResponse.GenerationTokenCount, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { + var responseText string + if len(llamaResponse.Generation) > 0 { + responseText = llamaResponse.Generation + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: relaymodel.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + }, + FinishReason: llamaResponse.StopReason, + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + createdTime := helper.GetTimestamp() + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + llamaReq, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + + awsReq.Body, err = json.Marshal(llamaReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + var llamaResp StreamResponse + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + if llamaResp.PromptTokenCount > 0 { + usage.PromptTokens = llamaResp.PromptTokenCount + } + if llamaResp.StopReason == "stop" { + usage.CompletionTokens = llamaResp.GenerationTokenCount + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + response := StreamResponseLlama2OpenAI(&llamaResp) + response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) + response.Model = c.GetString(ctxkey.OriginalModel) + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} + +func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = llamaResponse.Generation + choice.Delta.Role = "assistant" + finishReason := llamaResponse.StopReason + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var 
openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse +} diff --git a/relay/adaptor/aws/llama3/main_test.go b/relay/adaptor/aws/llama3/main_test.go new file mode 100644 index 00000000..d539eee8 --- /dev/null +++ b/relay/adaptor/aws/llama3/main_test.go @@ -0,0 +1,45 @@ +package aws_test + +import ( + "testing" + + aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/stretchr/testify/assert" +) + +func TestRenderPrompt(t *testing.T) { + messages := []relaymodel.Message{ + { + Role: "user", + Content: "What's your name?", + }, + } + prompt := aws.RenderPrompt(messages) + expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +` + assert.Equal(t, expected, prompt) + + messages = []relaymodel.Message{ + { + Role: "system", + Content: "Your name is Kat. You are a detective.", + }, + { + Role: "user", + Content: "What's your name?", + }, + { + Role: "assistant", + Content: "Kat", + }, + { + Role: "user", + Content: "What's your job?", + }, + } + prompt = aws.RenderPrompt(messages) + expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +` + assert.Equal(t, expected, prompt) +} diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go new file mode 100644 index 00000000..7b86c3b8 --- /dev/null +++ b/relay/adaptor/aws/llama3/model.go @@ -0,0 +1,29 @@ +package aws + +// Request is the request to AWS Llama3 +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html +type Request struct { + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` +} + +// Response is the response from AWS Llama3 +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html +type Response struct { + Generation string `json:"generation"` + PromptTokenCount int `json:"prompt_token_count"` + GenerationTokenCount int `json:"generation_token_count"` + StopReason string `json:"stop_reason"` +} + +// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None} +type StreamResponse struct { + Generation string `json:"generation"` + PromptTokenCount int `json:"prompt_token_count"` + GenerationTokenCount int `json:"generation_token_count"` + StopReason string `json:"stop_reason"` +} diff --git a/relay/adaptor/aws/registry.go b/relay/adaptor/aws/registry.go new file mode 100644 index 00000000..5f655480 --- /dev/null +++ b/relay/adaptor/aws/registry.go @@ -0,0 +1,39 @@ +package aws + +import ( + claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude" + llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" +) + +type AwsModelType int + +const ( + AwsClaude AwsModelType = iota + 1 + AwsLlama3 +) + +var ( + adaptors = map[string]AwsModelType{} +) + +func init() { + for model := range claude.AwsModelIDMap { + 
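// Illustrative sketch (not from this patch): this init seeds the model-name ->
// family table from each sub-package's AwsModelIDMap, so GetAdaptor is a plain
// map lookup:
//
//	if a := GetAdaptor("llama3-8b-8192"); a != nil {
//		// a is a *llama3.Adaptor satisfying utils.AwsAdapter
//	}
//
// An unregistered model maps to the zero AwsModelType, hits the default case,
// and returns nil, which the caller reports as "adaptor not found".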
adaptors[model] = AwsClaude + } + for model := range llama3.AwsModelIDMap { + adaptors[model] = AwsLlama3 + } +} + +func GetAdaptor(model string) utils.AwsAdapter { + adaptorType := adaptors[model] + switch adaptorType { + case AwsClaude: + return &claude.Adaptor{} + case AwsLlama3: + return &llama3.Adaptor{} + default: + return nil + } +} diff --git a/relay/adaptor/aws/utils/adaptor.go b/relay/adaptor/aws/utils/adaptor.go new file mode 100644 index 00000000..4cb880f2 --- /dev/null +++ b/relay/adaptor/aws/utils/adaptor.go @@ -0,0 +1,51 @@ +package utils + +import ( + "errors" + "io" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type AwsAdapter interface { + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) +} + +type Adaptor struct { + Meta *meta.Meta + AwsClient *bedrockruntime.Client +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ + Region: meta.Config.Region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), + }) +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} diff --git a/relay/adaptor/aws/utils/utils.go b/relay/adaptor/aws/utils/utils.go new file mode 100644 index 00000000..669dc628 --- /dev/null +++ b/relay/adaptor/aws/utils/utils.go @@ -0,0 +1,16 @@ +package utils + +import ( + "net/http" + + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func WrapErr(err error) *relaymodel.ErrorWithStatusCode { + return &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: err.Error(), + }, + } +} diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index b816e0f4..ebe70c32 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -5,6 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "sync" + "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" @@ -12,11 +19,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" - "sync" - "time" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 @@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var usage model.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, 
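// Illustrative sketch (not from this patch): utils.AwsAdapter above is the
// seam for adding Bedrock model families. A hypothetical mistral sub-package
// would only need the two methods plus a registry entry:
//
//	type Adaptor struct{}
//
//	func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
//		// translate the OpenAI-style request into this family's Bedrock payload
//	}
//
//	func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (*model.Usage, *model.ErrorWithStatusCode) {
//		// call InvokeModel / InvokeModelWithResponseStream and convert back
//	}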
atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - data = data[6:] - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var baiduResponse ChatStreamResponse - err := json.Unmarshal([]byte(data), &baiduResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if baiduResponse.Usage.TotalTokens != 0 { - usage.TotalTokens = baiduResponse.Usage.TotalTokens - usage.PromptTokens = baiduResponse.Usage.PromptTokens - usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens - } - response := streamResponseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { + continue } - }) + data = data[6:] + + var baiduResponse ChatStreamResponse + err := json.Unmarshal([]byte(data), &baiduResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if baiduResponse.Usage.TotalTokens != 0 { + usage.TotalTokens = baiduResponse.Usage.TotalTokens + usage.PromptTokens = baiduResponse.Usage.PromptTokens + usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens + } + response := streamResponseBaidu2OpenAI(&baiduResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go index 6ff6b0d3..97e3dbb2 100644 --- a/relay/adaptor/cloudflare/adaptor.go +++ b/relay/adaptor/cloudflare/adaptor.go @@ -5,11 +5,13 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) type Adaptor struct { @@ -27,8 +29,33 @@ func (a *Adaptor) Init(meta *meta.Meta) { a.meta = meta } +// WorkerAI cannot be used across accounts with AIGateWay +// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints +// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai +func (a *Adaptor) isAIGateWay(baseURL string) bool { + return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai") +} + func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, 
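// Illustrative sketch (not from this patch): the recurring refactor in this
// diff replaces the goroutine + channel + c.Stream pump with a synchronous
// scanner loop over the upstream SSE body. The common shape, with parse
// standing in for each provider's unmarshal step:
//
//	scanner := bufio.NewScanner(resp.Body)
//	scanner.Split(bufio.ScanLines)
//	common.SetEventStreamHeaders(c)
//	for scanner.Scan() {
//		line := strings.TrimPrefix(scanner.Text(), "data: ")
//		if err := render.ObjectData(c, parse(line)); err != nil {
//			logger.SysError(err.Error())
//		}
//	}
//	render.Done(c) // emits the final data: [DONE]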
error) { - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil + isAIGateWay := a.isAIGateWay(meta.BaseURL) + var urlPrefix string + if isAIGateWay { + urlPrefix = meta.BaseURL + } else { + urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID) + } + + switch meta.Mode { + case relaymode.ChatCompletions: + return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil + default: + if isAIGateWay { + return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil + } + return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { @@ -41,7 +68,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - return ConvertRequest(*request), nil + switch relayMode { + case relaymode.Completions: + return ConvertCompletionsRequest(*request), nil + case relaymode.ChatCompletions, relaymode.Embeddings: + return request, nil + default: + return nil, errors.New("not implemented") + } } func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go index f6d496f7..980a2891 100644 --- a/relay/adaptor/cloudflare/main.go +++ b/relay/adaptor/cloudflare/main.go @@ -2,12 +2,14 @@ package cloudflare import ( "bufio" - "bytes" "encoding/json" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/render" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -16,114 +18,66 @@ import ( "github.com/songquanpeng/one-api/relay/model" ) -func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { - var promptBuilder strings.Builder - for _, message := range textRequest.Messages { - promptBuilder.WriteString(message.StringContent()) - promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 - } - - return &Request{ - MaxTokens: textRequest.MaxTokens, - Prompt: promptBuilder.String(), - Stream: textRequest.Stream, - Temperature: textRequest.Temperature, - } -} - - -func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: cloudflareResponse.Result.Response, - }, - FinishReason: "stop", +func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request { + p, _ := textRequest.Prompt.(string) + return &Request{ + Prompt: p, + MaxTokens: textRequest.MaxTokens, + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, } - fullTextResponse := openai.TextResponse{ - Object: "chat.completion", - Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, - } - return &fullTextResponse -} - -func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = cloudflareResponse.Response - choice.Delta.Role = "assistant" - openaiResponse := openai.ChatCompletionsStreamResponse{ - Object: "chat.completion.chunk", - Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, - Created: 
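// Illustrative sketch (not from this patch): assuming the default base URL
// https://api.cloudflare.com, the GetRequestURL change above yields, for chat
// completions (account and gateway IDs are placeholders):
//
//	https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/v1/chat/completions
//
// and, when the channel points at AI Gateway (BaseURL ending in /workers-ai):
//
//	https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai/v1/chat/completions
//
// Other relay modes keep the per-model path: .../ai/run/{model} on the account
// API, or {gateway_prefix}/{model} through the gateway.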
helper.GetTimestamp(), - } - return &openaiResponse } func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < len("data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - dataChan <- data - } - stopChan <- true - }() common.SetEventStreamHeaders(c) id := helper.GetResponseID(c) - responseModel := c.GetString("original_model") + responseModel := c.GetString(ctxkey.OriginalModel) var responseText string - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var cloudflareResponse StreamResponse - err := json.Unmarshal([]byte(data), &cloudflareResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) - if response == nil { - return true - } - responseText += cloudflareResponse.Response - response.Id = id - response.Model = responseModel - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue } - }) - _ = resp.Body.Close() + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + if data == "[DONE]" { + break + } + + var response openai.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + for _, v := range response.Choices { + v.Delta.Role = "assistant" + responseText += v.Delta.StringContent() + } + response.Id = id + response.Model = modelName + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) return nil, usage } @@ -137,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var cloudflareResponse Response - err = json.Unmarshal(responseBody, &cloudflareResponse) + var response openai.TextResponse + err = json.Unmarshal(responseBody, &response) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - fullTextResponse := 
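// Illustrative sketch (not from this patch): after this rewrite the Cloudflare
// adaptor stops translating the bespoke {"result":{"response":"..."}} payload
// and decodes OpenAI-compatible chunks straight off the wire, e.g.:
//
//	data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"Hi"}}]}
//
// which is why the bespoke converters are being removed here.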
ResponseCloudflare2OpenAI(&cloudflareResponse) - fullTextResponse.Model = modelName - usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) - fullTextResponse.Usage = *usage - fullTextResponse.Id = helper.GetResponseID(c) - jsonResponse, err := json.Marshal(fullTextResponse) + response.Model = modelName + var responseText string + for _, v := range response.Choices { + responseText += v.Message.Content.(string) + } + usage := openai.ResponseText2Usage(responseText, modelName, promptTokens) + response.Usage = *usage + response.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(response) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + _, _ = c.Writer.Write(jsonResponse) return nil, usage } diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go index 0664ecd1..0d3bafe0 100644 --- a/relay/adaptor/cloudflare/model.go +++ b/relay/adaptor/cloudflare/model.go @@ -1,25 +1,13 @@ package cloudflare +import "github.com/songquanpeng/one-api/relay/model" + type Request struct { - Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Prompt string `json:"prompt,omitempty"` - Raw bool `json:"raw,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` -} - -type Result struct { - Response string `json:"response"` -} - -type Response struct { - Result Result `json:"result"` - Success bool `json:"success"` - Errors []string `json:"errors"` - Messages []string `json:"messages"` -} - -type StreamResponse struct { - Response string `json:"response"` + Messages []model.Message `json:"messages,omitempty"` + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` } diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 4bc3fa8d..45db437b 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -2,9 +2,9 @@ package cohere import ( "bufio" - "bytes" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - dataChan <- data - } - stopChan <- true - }() common.SetEventStreamHeaders(c) var usage model.Usage - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var cohereResponse StreamResponse - err := json.Unmarshal([]byte(data), 
&cohereResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, meta := StreamResponseCohere2OpenAI(&cohereResponse) - if meta != nil { - usage.PromptTokens += meta.Meta.Tokens.InputTokens - usage.CompletionTokens += meta.Meta.Tokens.OutputTokens - return true - } - if response == nil { - return true - } - response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) - response.Model = c.GetString("original_model") - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSuffix(data, "\r") + + var cohereResponse StreamResponse + err := json.Unmarshal([]byte(data), &cohereResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue } - }) - _ = resp.Body.Close() + + response, meta := StreamResponseCohere2OpenAI(&cohereResponse) + if meta != nil { + usage.PromptTokens += meta.Meta.Tokens.InputTokens + usage.CompletionTokens += meta.Meta.Tokens.OutputTokens + continue + } + if response == nil { + continue + } + + response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) + response.Model = c.GetString("original_model") + response.Created = createdTime + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage } diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go index 721c5d13..d0402a76 100644 --- a/relay/adaptor/coze/main.go +++ b/relay/adaptor/coze/main.go @@ -4,6 +4,11 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" @@ -12,9 +17,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) // https://www.coze.com/open @@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC var responseText string createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { - continue - } - if !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) var modelName string - c.Stream(func(w 
io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var cozeResponse StreamResponse - err := json.Unmarshal([]byte(data), &cozeResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, _ := StreamResponseCoze2OpenAI(&cozeResponse) - if response == nil { - return true - } - for _, choice := range response.Choices { - responseText += conv.AsString(choice.Delta.Content) - } - response.Model = modelName - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue } - }) - _ = resp.Body.Close() + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSuffix(data, "\r") + + var cozeResponse StreamResponse + err := json.Unmarshal([]byte(data), &cozeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, _ := StreamResponseCoze2OpenAI(&cozeResponse) + if response == nil { + continue + } + + for _, choice := range response.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + response.Model = modelName + response.Created = createdTime + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &responseText } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 74a7d5d5..51fd6aa8 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -275,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - data = strings.TrimSuffix(data, "\"") - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var geminiResponse ChatResponse - err := json.Unmarshal([]byte(data), &geminiResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + 
err.Error()) - return true - } - response := streamResponseGeminiChat2OpenAI(&geminiResponse) - if response == nil { - return true - } - responseText += response.Choices[0].Delta.StringContent() - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "data: ") { + continue } - }) + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\"") + + var geminiResponse ChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := streamResponseGeminiChat2OpenAI(&geminiResponse) + if response == nil { + continue + } + + responseText += response.Choices[0].Delta.StringContent() + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + return nil, responseText } diff --git a/relay/adaptor/novita/constants.go b/relay/adaptor/novita/constants.go new file mode 100644 index 00000000..c6618308 --- /dev/null +++ b/relay/adaptor/novita/constants.go @@ -0,0 +1,19 @@ +package novita + +// https://novita.ai/llm-api + +var ModelList = []string{ + "meta-llama/llama-3-8b-instruct", + "meta-llama/llama-3-70b-instruct", + "nousresearch/hermes-2-pro-llama-3-8b", + "nousresearch/nous-hermes-llama2-13b", + "mistralai/mistral-7b-instruct", + "cognitivecomputations/dolphin-mixtral-8x22b", + "sao10k/l3-70b-euryale-v2.1", + "sophosympatheia/midnight-rose-70b", + "gryphe/mythomax-l2-13b", + "Nous-Hermes-2-Mixtral-8x7B-DPO", + "lzlv_70b", + "teknium/openhermes-2.5-mistral-7b", + "microsoft/wizardlm-2-8x22b", +} diff --git a/relay/adaptor/novita/main.go b/relay/adaptor/novita/main.go new file mode 100644 index 00000000..80efa412 --- /dev/null +++ b/relay/adaptor/novita/main.go @@ -0,0 +1,15 @@ +package novita + +import ( + "fmt" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + if meta.Mode == relaymode.ChatCompletions { + return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode) +} diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index c5fe08e6..936a7e14 100644 --- a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -5,12 +5,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/image" @@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC 
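// Illustrative sketch (not from this patch): ollama replies with
// newline-delimited JSON objects rather than SSE, hence the custom SplitFunc
// below that cuts on "}\n"; the switch to data[0 : i+1] keeps the closing
// brace inside the token. Standalone (imports: bufio, bytes, fmt, strings):
//
//	s := bufio.NewScanner(strings.NewReader("{\"a\":1}\n{\"b\":2}\n"))
//	s.Split(func(data []byte, atEOF bool) (int, []byte, error) {
//		if atEOF && len(data) == 0 {
//			return 0, nil, nil
//		}
//		if i := bytes.Index(data, []byte("}\n")); i >= 0 {
//			return i + 2, data[:i+1], nil // token keeps the closing brace
//		}
//		if atEOF {
//			return len(data), data, nil
//		}
//		return 0, nil, nil
//	})
//	for s.Scan() {
//		fmt.Println(s.Text()) // {"a":1} then {"b":2}
//	}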
return 0, nil, nil } if i := strings.Index(string(data), "}\n"); i >= 0 { - return i + 2, data[0:i], nil + return i + 2, data[0 : i+1], nil } if atEOF { return len(data), data, nil } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := strings.TrimPrefix(scanner.Text(), "}") - dataChan <- data + "}" - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var ollamaResponse ChatResponse - err := json.Unmarshal([]byte(data), &ollamaResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if ollamaResponse.EvalCount != 0 { - usage.PromptTokens = ollamaResponse.PromptEvalCount - usage.CompletionTokens = ollamaResponse.EvalCount - usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount - } - response := streamResponseOllama2OpenAI(&ollamaResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + data = data + "}" + + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue } - }) + + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + + response := streamResponseOllama2OpenAI(&ollamaResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, &usage } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 9dba2aa1..120d21a6 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -3,17 +3,19 @@ package openai import ( "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/adaptor/novita" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) type Adaptor struct { @@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return minimax.GetRequestURL(meta) case channeltype.Doubao: return doubao.GetRequestURL(meta) + case channeltype.Novita: + return novita.GetRequestURL(meta) default: return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil } diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go index 5d5b4008..3445249c 100644 --- 
a/relay/adaptor/openai/compatible.go +++ b/relay/adaptor/openai/compatible.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/moonshot" + "github.com/songquanpeng/one-api/relay/adaptor/novita" "github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/togetherai" "github.com/songquanpeng/one-api/relay/channeltype" @@ -28,6 +29,7 @@ var CompatibleChannels = []int{ channeltype.StepFun, channeltype.DeepSeek, channeltype.TogetherAI, + channeltype.Novita, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -56,6 +58,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "together.ai", togetherai.ModelList case channeltype.Doubao: return "doubao", doubao.ModelList + case channeltype.Novita: + return "novita", novita.ModelList default: return "openai", ModelList } diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 72c675e1..9ee547b3 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -4,15 +4,18 @@ import ( "bufio" "bytes" "encoding/json" + "io" + "net/http" + "strings" + + "github.com/songquanpeng/one-api/common/render" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) const ( @@ -24,88 +27,72 @@ const ( func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) + scanner.Split(bufio.ScanLines) var usage *model.Usage - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < dataPrefixLength { // ignore blank line or wrong format - continue - } - if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { - continue - } - if strings.HasPrefix(data[dataPrefixLength:], done) { - dataChan <- data - continue - } - switch relayMode { - case relaymode.ChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - dataChan <- data // if error happened, pass the data to client - continue // just ignore the error - } - if len(streamResponse.Choices) == 0 { - // but for empty choice, we should not pass it to client, this is for azure - continue // just ignore empty choice - } - dataChan <- data - for _, choice := range streamResponse.Choices { - responseText += conv.AsString(choice.Delta.Content) - } - if streamResponse.Usage != nil { - usage = streamResponse.Usage - } - case relaymode.Completions: - dataChan <- data - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream 
response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } - } - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false + + doneRendered := false + for scanner.Scan() { + data := scanner.Text() + if len(data) < dataPrefixLength { // ignore blank line or wrong format + continue } - }) + if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { + continue + } + if strings.HasPrefix(data[dataPrefixLength:], done) { + render.StringData(c, data) + doneRendered = true + continue + } + switch relayMode { + case relaymode.ChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + render.StringData(c, data) // if error happened, pass the data to client + continue // just ignore the error + } + if len(streamResponse.Choices) == 0 { + // but for empty choice, we should not pass it to client, this is for azure + continue // just ignore empty choice + } + render.StringData(c, data) + for _, choice := range streamResponse.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + if streamResponse.Usage != nil { + usage = streamResponse.Usage + } + case relaymode.Completions: + render.StringData(c, data) + var streamResponse CompletionsStreamResponse + err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + for _, choice := range streamResponse.Choices { + responseText += choice.Text + } + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + if !doneRendered { + render.Done(c) + } + err := resp.Body.Close() if err != nil { return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil } + return nil, responseText, usage } @@ -149,7 +136,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - if textResponse.Usage.TotalTokens == 0 { + if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range textResponse.Choices { completionTokens += CountTokenText(choice.Message.StringContent(), modelName) diff --git a/relay/adaptor/palm/palm.go b/relay/adaptor/palm/palm.go index 1e60e7cd..d31784ec 100644 --- a/relay/adaptor/palm/palm.go +++ b/relay/adaptor/palm/palm.go @@ -3,6 +3,10 @@ package palm import ( "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -11,8 +15,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) // 
https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body @@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC responseText := "" responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) createdTime := helper.GetTimestamp() - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - logger.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return - } - err = resp.Body.Close() - if err != nil { - logger.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var palmResponse ChatResponse - err = json.Unmarshal(responseBody, &palmResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(palmResponse.Candidates) > 0 { - responseText = palmResponse.Candidates[0].Content - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + logger.SysError("error reading stream response: " + err.Error()) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } - }) - err := resp.Body.Close() + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), "" + } + + err = resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + + var palmResponse ChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), "" + } + + fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + if len(palmResponse.Candidates) > 0 { + responseText = palmResponse.Candidates[0].Content + } + + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), "" + } + + err = render.ObjectData(c, string(jsonResponse)) + if err != nil { + logger.SysError(err.Error()) + } + + render.Done(c) + return nil, responseText } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 0a57dcf7..365e33ae 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -8,6 +8,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strconv" + "strings" + "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" 
"github.com/songquanpeng/one-api/common/conv" @@ -17,11 +24,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" - "time" ) func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { @@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { var responseText string scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var TencentResponse ChatResponse - err := json.Unmarshal([]byte(data), &TencentResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := streamResponseTencent2OpenAI(&TencentResponse) - if len(response.Choices) != 0 { - responseText += conv.AsString(response.Choices[0].Delta.Content) - } - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue } - }) + data = strings.TrimPrefix(data, "data:") + + var tencentResponse ChatResponse + err := json.Unmarshal([]byte(data), &tencentResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := streamResponseTencent2OpenAI(&tencentResponse) + if len(response.Choices) != 0 { + responseText += conv.AsString(response.Choices[0].Delta.Content) + } + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + return nil, responseText } diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go index 31dcec71..12a56210 100644 --- a/relay/adaptor/xunfei/constants.go +++ b/relay/adaptor/xunfei/constants.go @@ -6,4 +6,5 @@ var ModelList = []string{ "SparkDesk-v2.1", "SparkDesk-v3.1", "SparkDesk-v3.5", + "SparkDesk-v4.0", } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 39b76e27..ef6120e5 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, 
xunfeiAppId string xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages - if strings.HasPrefix(domain, "generalv3") { + if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" { functions := make([]model.Function, len(request.Tools)) for i, tool := range request.Tools { functions[i] = tool.Function @@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string { return "generalv3" case "v3.5": return "generalv3.5" + case "v4.0": + return "4.0Ultra" } return "general" + apiVersion } diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go index 74a1a05e..ab3a5678 100644 --- a/relay/adaptor/zhipu/main.go +++ b/relay/adaptor/zhipu/main.go @@ -3,6 +3,13 @@ package zhipu import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "sync" + "time" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/songquanpeng/one-api/common" @@ -11,11 +18,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" - "sync" - "time" ) // https://open.bigmodel.cn/doc/api#chatglm_std @@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - metaChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - lines := strings.Split(data, "\n") - for i, line := range lines { - if len(line) < 5 { + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + lines := strings.Split(data, "\n") + for i, line := range lines { + if len(line) < 5 { + continue + } + if strings.HasPrefix(line, "data:") { + dataSegment := line[5:] + if i != len(lines)-1 { + dataSegment += "\n" + } + response := streamResponseZhipu2OpenAI(dataSegment) + err := render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + } + } else if strings.HasPrefix(line, "meta:") { + metaSegment := line[5:] + var zhipuResponse StreamMetaResponse + err := json.Unmarshal([]byte(metaSegment), &zhipuResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } - if line[:5] == "data:" { - dataChan <- line[5:] - if i != len(lines)-1 { - dataChan <- "\n" - } - } else if line[:5] == "meta:" { - metaChan <- line[5:] + response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) } + usage = zhipuUsage } } - stopChan <- true - }() - common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - response := streamResponseZhipu2OpenAI(data) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case data := <-metaChan: - var zhipuResponse StreamMetaResponse - err := json.Unmarshal([]byte(data), &zhipuResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) - jsonResponse, err := 
json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - usage = zhipuUsage - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, usage } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 3b289499..8a7d5743 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -2,6 +2,7 @@ package ratio import ( "encoding/json" + "fmt" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -70,12 +71,13 @@ var ModelRatio = map[string]float64{ "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing - "claude-instant-1.2": 0.8 / 1000 * USD, - "claude-2.0": 8.0 / 1000 * USD, - "claude-2.1": 8.0 / 1000 * USD, - "claude-3-haiku-20240307": 0.25 / 1000 * USD, - "claude-3-sonnet-20240229": 3.0 / 1000 * USD, - "claude-3-opus-20240229": 15.0 / 1000 * USD, + "claude-instant-1.2": 0.8 / 1000 * USD, + "claude-2.0": 8.0 / 1000 * USD, + "claude-2.1": 8.0 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-sonnet-20240229": 3.0 / 1000 * USD, + "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, + "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-3.5-8K": 0.012 * RMB, @@ -124,6 +126,7 @@ var ModelRatio = map[string]float64{ "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens @@ -167,6 +170,9 @@ var ModelRatio = map[string]float64{ "step-1v-32k": 0.024 * RMB, "step-1-32k": 0.024 * RMB, "step-1-200k": 0.15 * RMB, + // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ + "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens + "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens // https://cohere.com/pricing "command": 0.5, "command-nightly": 0.5, @@ -183,7 +189,11 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, } -var CompletionRatio = map[string]float64{} +var CompletionRatio = map[string]float64{ + // aws llama3 + "llama3-8b-8192(33)": 0.0006 / 0.0003, + "llama3-70b-8192(33)": 0.0035 / 0.00265, +} var DefaultModelRatio map[string]float64 var DefaultCompletionRatio map[string]float64 @@ -232,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &ModelRatio) } -func GetModelRatio(name string) float64 { +func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } - ratio, ok := ModelRatio[name] - if !ok { - ratio, ok = DefaultModelRatio[name] + model := 
fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := ModelRatio[model]; ok { + return ratio } - if !ok { - logger.SysError("model ratio not found: " + name) - return 30 + if ratio, ok := DefaultModelRatio[model]; ok { + return ratio } - return ratio + if ratio, ok := ModelRatio[name]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[name]; ok { + return ratio + } + logger.SysError("model ratio not found: " + name) + return 30 } func CompletionRatio2JSONString() string { @@ -263,7 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &CompletionRatio) } -func GetCompletionRatio(name string) float64 { +func GetCompletionRatio(name string, channelType int) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } + model := fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := CompletionRatio[model]; ok { + return ratio + } + if ratio, ok := DefaultCompletionRatio[model]; ok { + return ratio + } if ratio, ok := CompletionRatio[name]; ok { return ratio } diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index d8885ae9..d3891c16 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -42,5 +42,6 @@ const ( DeepL TogetherAI Doubao + Novita Dummy ) diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 513d183b..5177333b 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -42,6 +42,7 @@ var ChannelBaseURLs = []string{ "https://api-free.deepl.com", // 38 "https://api.together.xyz", // 39 "https://ark.cn-beijing.volces.com", // 40 + "https://api.novita.ai/v3/openai", // 41 } func init() { diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 8f9708d0..83040662 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,6 +7,10 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" @@ -21,9 +25,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { @@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := billingratio.GetModelRatio(audioModel) + modelRatio := billingratio.GetModelRatio(audioModel, channelType) groupRatio := billingratio.GetGroupRatio(group) ratio := modelRatio * groupRatio var quota int64 diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486..87d22f13 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "math" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -16,9 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "math" - "net/http" - "strings" ) func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { @@ -40,78 +41,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, 
nil } -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { - imageRequest := &relaymodel.ImageRequest{} - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - return imageRequest, nil -} - -func isValidImageSize(model string, size string) bool { - if model == "cogview-3" { - return true - } - _, ok := billingratio.ImageSizeRatios[model][size] - return ok -} - -func getImageSizeRatio(model string, size string) float64 { - ratio, ok := billingratio.ImageSizeRatios[model][size] - if !ok { - return 1 - } - return ratio -} - -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { - // model validation - hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) - if !hasValidSize { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - // check prompt length - if imageRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - // Number of generated images validation - if !isWithinRange(imageRequest.Model, imageRequest.N) { - // channel not azure - if meta.ChannelType != channeltype.Azure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } - } - return nil -} - -func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { - if imageRequest == nil { - return 0, errors.New("imageRequest is nil") - } - imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) - if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { - if imageRequest.Size == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - return imageCostRatio, nil -} - func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: @@ -167,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M return } var quota int64 - completionRatio := billingratio.GetCompletionRatio(textRequest.Model) + completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) diff --git a/relay/controller/image.go b/relay/controller/image.go index bfcadb84..7f0536ab 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -22,13 +23,84 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func isWithinRange(element string, value int) bool { - if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { - return false +func getImageRequest(c 
*gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err } - min := billingratio.ImageGenerationAmounts[element][0] - max := billingratio.ImageGenerationAmounts[element][1] - return value >= min && value <= max + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func isValidImagePromptLength(model string, promptLength int) bool { + maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] + return !ok || promptLength <= maxPromptLength +} + +func isWithinRange(element string, value int) bool { + amounts, ok := billingratio.ImageGenerationAmounts[element] + return !ok || (value >= amounts[0] && value <= amounts[1]) +} + +func getImageSizeRatio(model string, size string) float64 { + if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { + return ratio + } + return 1 +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + + // model validation + if !isValidImageSize(imageRequest.Model, imageRequest.Size) { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + + if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + return nil +} + +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { @@ -97,7 +169,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billingratio.GetModelRatio(imageModel) + modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) diff --git a/relay/controller/text.go b/relay/controller/text.go index 6ed19b1d..0d3c56b0 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,6 +4,9 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" 
"github.com/songquanpeng/one-api/relay" @@ -14,8 +17,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { @@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := billingratio.GetModelRatio(textRequest.Model) + modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota diff --git a/relay/model/message.go b/relay/model/message.go index 32a1055b..b908f989 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,10 +1,11 @@ package model type Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []Tool `json:"tool_calls,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go index 253dca35..75dbb8f7 100644 --- a/relay/model/tool.go +++ b/relay/model/tool.go @@ -2,13 +2,13 @@ package model type Tool struct { Id string `json:"id,omitempty"` - Type string `json:"type"` + Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty Function Function `json:"function"` } type Function struct { Description string `json:"description,omitempty"` - Name string `json:"name"` + Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty Parameters any `json:"parameters,omitempty"` // request Arguments any `json:"arguments,omitempty"` // response } diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js index 45a5b776..ef4acf14 100644 --- a/web/air/src/components/PersonalSetting.js +++ b/web/air/src/components/PersonalSetting.js @@ -47,7 +47,7 @@ const PersonalSetting = () => { const [countdown, setCountdown] = useState(30); const [affLink, setAffLink] = useState(''); const [systemToken, setSystemToken] = useState(''); - // const [models, setModels] = useState([]); + const [models, setModels] = useState([]); const [openTransfer, setOpenTransfer] = useState(false); const [transferAmount, setTransferAmount] = useState(0); @@ -72,7 +72,7 @@ const PersonalSetting = () => { console.log(userState); } ); - // loadModels().then(); + loadModels().then(); getAffLink().then(); setTransferAmount(getQuotaPerUnit()); }, []); @@ -127,16 +127,16 @@ const PersonalSetting = () => { } }; - // const loadModels = async () => { - // let res = await API.get(`/api/user/models`); - // const { success, message, data } = res.data; - // if (success) { - // setModels(data); - // console.log(data); - // } else { - // showError(message); - // } - // }; + const loadModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + setModels(data); + console.log(data); + } else { + showError(message); + } + }; const handleAffLinkClick = async (e) => { 
    e.target.select();
@@ -344,7 +344,7 @@ const PersonalSetting = () => {
              }
            >
              调用信息
-              {/* 可用模型
+              可用模型(可点击复制)
              {models.map((model) => (
@@ -355,7 +355,7 @@ const PersonalSetting = () => {
              ))}
-              */}
+              {/*
diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js
--- a/web/air/src/pages/Channel/EditChannel.js
+++ b/web/air/src/pages/Channel/EditChannel.js
@@ ... @@ const EditChannel = (props) => {
       let localModels = [];
       switch (value) {
         case 14:
-          localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"];
+          localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
          break;
        case 11:
          localModels = ['PaLM-2'];
@@ -78,7 +78,7 @@ const EditChannel = (props) => {
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
          break;
        case 18:
-          localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'];
+          localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
          break;
        case 19:
          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index 52d64083..881f66bd 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = {
  },
  33: {
    key: 33,
-    text: 'AWS Claude',
+    text: 'AWS',
    value: 33,
    color: 'primary'
  },
@@ -161,6 +161,12 @@ export const CHANNEL_OPTIONS = {
    value: 39,
    color: 'primary'
  },
+  41: {
+    key: 41,
+    text: 'Novita',
+    value: 41,
+    color: 'purple'
+  },
  8: {
    key: 8,
    text: '自定义渠道',
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index 88e1ea92..51b7c6c4 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -91,7 +91,7 @@ const typeConfig = {
      other: '版本号'
    },
    input: {
-      models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']
+      models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
    },
    prompt: {
      key: '按照如下格式输入:APPID|APISecret|APIKey',
diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js
index 1258ca5a..6025b7d9 100644
--- a/web/default/src/components/ChannelsTable.js
+++ b/web/default/src/components/ChannelsTable.js
@@ -1,5 +1,5 @@
 import React, { useEffect, useState } from 'react';
-import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
+import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
 import {
   API,
@@ -70,13 +70,33 @@ const ChannelsTable = () => {
    const res = await API.get(`/api/channel/?p=${startIdx}`);
    const { success, message, data } = res.data;
    if (success) {
-      if (startIdx === 0) {
-        setChannels(data);
-      } else {
-        let newChannels = [...channels];
-        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
-        setChannels(newChannels);
-      }
+      let localChannels = data.map((channel) => {
+        if (channel.models === '') {
+          channel.models = [];
+          channel.test_model = "";
+        } else {
+          channel.models = channel.models.split(',');
+          if (channel.models.length > 0) {
+            channel.test_model = channel.models[0];
+          }
+          channel.model_options = channel.models.map((model) => {
+            return {
+              key: model,
+              text: model,
+              value: model,
+            }
+          })
+          console.log('channel', channel)
+        }
+        return channel;
+      });
+      if (startIdx === 0) {
+        setChannels(localChannels);
+      } else {
+        let newChannels = [...channels];
+        newChannels.splice(startIdx
* ITEMS_PER_PAGE, data.length, ...localChannels); + setChannels(newChannels); + } } else { showError(message); } @@ -225,19 +245,31 @@ const ChannelsTable = () => { setSearching(false); }; - const testChannel = async (id, name, idx) => { - const res = await API.get(`/api/channel/test/${id}/`); - const { success, message, time } = res.data; + const switchTestModel = async (idx, model) => { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].test_model = model; + setChannels(newChannels); + }; + + const testChannel = async (id, name, idx, m) => { + const res = await API.get(`/api/channel/test/${id}?model=${m}`); + const { success, message, time, model } = res.data; if (success) { let newChannels = [...channels]; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].test_time = Date.now() / 1000; setChannels(newChannels); - showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); } + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].response_time = time * 1000; + newChannels[realIdx].test_time = Date.now() / 1000; + setChannels(newChannels); }; const testChannels = async (scope) => { @@ -405,6 +437,7 @@ const ChannelsTable = () => { > 优先级 + 测试模型 操作 @@ -459,13 +492,24 @@ const ChannelsTable = () => { basic /> + + { + switchTestModel(idx, data.value); + }} + /> +
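
The streaming handlers rewritten above (openai, palm, tencent, zhipu) all follow the same pattern: the goroutine-plus-channel pump driven by `c.Stream` is replaced by a plain synchronous loop that splits the upstream body with `bufio.ScanLines`, forwards each `data:` line as it arrives, logs `scanner.Err()` after the loop, and emits a final `data: [DONE]` unless the upstream already sent one. Below is a minimal, self-contained sketch of that loop using only the standard library; `relaySSE` and the bare `http.ResponseWriter` stand in for the gin context and one-api's `render.StringData`/`render.Done` helpers, so the names are illustrative rather than the project's API.

```go
package main

import (
	"bufio"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
)

// relaySSE copies an upstream SSE stream to the client line by line.
// It mirrors the refactored handlers: bufio.ScanLines instead of a custom
// split function, no goroutine/channel pair, and a guaranteed [DONE] marker.
func relaySSE(w http.ResponseWriter, upstream io.Reader) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	flusher, _ := w.(http.Flusher)

	doneRendered := false
	scanner := bufio.NewScanner(upstream)
	scanner.Split(bufio.ScanLines)
	for scanner.Scan() {
		line := strings.TrimSuffix(scanner.Text(), "\r") // some upstreams send \r\n
		if !strings.HasPrefix(line, "data:") {
			continue // skip blank lines and malformed chunks
		}
		if strings.HasPrefix(strings.TrimSpace(line[5:]), "[DONE]") {
			doneRendered = true
		}
		fmt.Fprintf(w, "%s\n\n", line)
		if flusher != nil {
			flusher.Flush() // push each event to the client immediately
		}
	}
	if err := scanner.Err(); err != nil {
		log.Println("error reading stream: " + err.Error())
	}
	if !doneRendered {
		fmt.Fprint(w, "data: [DONE]\n\n")
	}
}

func main() {
	http.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) {
		// A canned upstream body for demonstration purposes.
		upstream := strings.NewReader("data: {\"delta\":\"hi\"}\n\ndata: [DONE]\n\n")
		relaySSE(w, upstream)
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```

Dropping the goroutine also removes the unbuffered-channel handoff, so a client that disconnects mid-stream can no longer strand a blocked sender.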
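
`GetModelRatio` and `GetCompletionRatio` now take the channel type and consult a channel-qualified key before the bare model name, which is how the AWS-priced `llama3-*-8192(33)` entries are matched (33 being the AWS channel type). The following compilable sketch shows just that lookup order; the table entries are examples, not the project's full pricing, which lives in relay/billing/ratio/model.go.

```go
package main

import "fmt"

// ratios maps either "model" or "model(channelType)" to a price multiplier.
// Illustrative values only.
var ratios = map[string]float64{
	"llama3-8b-8192":     0.5,            // hypothetical fallback for other channels
	"llama3-8b-8192(33)": 0.0003 / 0.002, // AWS override, channel type 33
}

// getModelRatio resolves the channel-specific key first, then the plain
// model name, then falls back to the default of 30 used by the project.
func getModelRatio(name string, channelType int) float64 {
	if ratio, ok := ratios[fmt.Sprintf("%s(%d)", name, channelType)]; ok {
		return ratio
	}
	if ratio, ok := ratios[name]; ok {
		return ratio
	}
	return 30
}

func main() {
	fmt.Println(getModelRatio("llama3-8b-8192", 33)) // 0.15: AWS-specific price wins
	fmt.Println(getModelRatio("llama3-8b-8192", 1))  // 0.5: bare model name fallback
	fmt.Println(getModelRatio("unknown-model", 1))   // 30: default ratio
}
```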
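
Similarly, `getImageCostRatio`, moved into relay/controller/image.go above, prices dall-e-3 `hd` renders at 2x for 1024x1024 and 1.5x for the larger sizes, on top of the per-size ratio. A small sketch of just that arithmetic follows; the size table here is an assumed subset of `billingratio.ImageSizeRatios`, not the real values.

```go
package main

import "fmt"

// imageSizeRatios is an illustrative subset of the billing table.
var imageSizeRatios = map[string]map[string]float64{
	"dall-e-3": {
		"1024x1024": 1,
		"1024x1792": 2,
		"1792x1024": 2,
	},
}

// getImageCostRatio mirrors the controller logic: unknown sizes cost 1,
// and dall-e-3 "hd" quality multiplies the base ratio by 2 or 1.5.
func getImageCostRatio(model, size, quality string) float64 {
	ratio, ok := imageSizeRatios[model][size]
	if !ok {
		ratio = 1
	}
	if quality == "hd" && model == "dall-e-3" {
		if size == "1024x1024" {
			ratio *= 2
		} else {
			ratio *= 1.5
		}
	}
	return ratio
}

func main() {
	fmt.Println(getImageCostRatio("dall-e-3", "1024x1024", "hd"))       // 2
	fmt.Println(getImageCostRatio("dall-e-3", "1792x1024", "hd"))       // 3
	fmt.Println(getImageCostRatio("dall-e-3", "1024x1024", "standard")) // 1
}
```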