From 5a58426859e6c128392079e80544c340316db307 Mon Sep 17 00:00:00 2001
From: Ghostz <137054651+ye4293@users.noreply.github.com>
Date: Sun, 30 Jun 2024 16:09:16 +0800
Subject: [PATCH 01/26] fix minimax empty log (#1560)

---
 relay/adaptor/openai/main.go | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go
index 72c675e1..07cb967f 100644
--- a/relay/adaptor/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -4,15 +4,16 @@ import (
     "bufio"
     "bytes"
     "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/conv"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/model"
     "github.com/songquanpeng/one-api/relay/relaymode"
-    "io"
-    "net/http"
-    "strings"
 )

 const (
@@ -149,7 +150,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }

-    if textResponse.Usage.TotalTokens == 0 {
+    if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
         completionTokens := 0
         for _, choice := range textResponse.Choices {
             completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
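[PATCH 01] widens the usage fallback in Handler so it also fires when an upstream (here MiniMax) reports zero prompt and completion tokens rather than just a zero total. A minimal sketch of the resulting control flow; the final reassignment of textResponse.Usage is an assumption about the part of the hunk not shown above, not verbatim upstream code:

    // Hedged sketch: recount completion tokens locally when the upstream
    // reply carries no usable usage numbers.
    if textResponse.Usage.TotalTokens == 0 ||
        (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
        completionTokens := 0
        for _, choice := range textResponse.Choices {
            completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
        }
        textResponse.Usage = model.Usage{ // assumed tail of the hunk
            PromptTokens:     promptTokens,
            CompletionTokens: completionTokens,
            TotalTokens:      promptTokens + completionTokens,
        }
    }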
From 8cc1ee63605d36cc20d096e2be786fc533870833 Mon Sep 17 00:00:00 2001
From: Leo Q
Date: Sun, 30 Jun 2024 16:12:16 +0800
Subject: [PATCH 02/26] ci: use codecov to upload coverage report (#1583)

---
 .github/workflows/ci.yml | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 89ba75cd..698acdf1 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -45,17 +45,15 @@ jobs:

   code_coverage:
     name: "Code coverage report"
-    if: github.event_name == 'pull_request' # Do not run when workflow is triggered by push to main branch
     runs-on: ubuntu-latest
     needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job
     steps:
-      - uses: fgrosse/go-coverage-report@v1.0.2 # Consider using a Git revision for maximum security
-        with:
-          coverage-artifact-name: "code-coverage" # can be omitted if you used this default value
-          coverage-file-name: "coverage.txt" # can be omitted if you used this default value
+      - uses: codecov/codecov-action@v4
+        with:
+          use_oidc: true

   commit_lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - uses: wagoid/commitlint-github-action@v6
\ No newline at end of file
+      - uses: wagoid/commitlint-github-action@v6

From 34cb147a744e717404ebccd566cdf1b753ef78a1 Mon Sep 17 00:00:00 2001
From: igophper <34326532+igophper@users.noreply.github.com>
Date: Sun, 30 Jun 2024 16:13:43 +0800
Subject: [PATCH 03/26] refactor: replace hardcoded string with ctxkey constant (#1579)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: 江杭辉
---
 common/ctxkey/key.go | 1 +
 common/gin.go        | 7 +++----
 controller/relay.go  | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go
index 6c640870..90556b3a 100644
--- a/common/ctxkey/key.go
+++ b/common/ctxkey/key.go
@@ -19,4 +19,5 @@ const (
     TokenName       = "token_name"
     BaseURL         = "base_url"
     AvailableModels = "available_models"
+    KeyRequestBody  = "key_request_body"
 )
diff --git a/common/gin.go b/common/gin.go
index b6ef96a6..549d3279 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -4,14 +4,13 @@ import (
     "bytes"
     "encoding/json"
     "github.com/gin-gonic/gin"
+    "github.com/songquanpeng/one-api/common/ctxkey"
     "io"
     "strings"
 )

-const KeyRequestBody = "key_request_body"
-
 func GetRequestBody(c *gin.Context) ([]byte, error) {
-    requestBody, _ := c.Get(KeyRequestBody)
+    requestBody, _ := c.Get(ctxkey.KeyRequestBody)
     if requestBody != nil {
         return requestBody.([]byte), nil
     }
@@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
         return nil, err
     }
     _ = c.Request.Body.Close()
-    c.Set(KeyRequestBody, requestBody)
+    c.Set(ctxkey.KeyRequestBody, requestBody)
     return requestBody.([]byte), nil
 }

diff --git a/controller/relay.go b/controller/relay.go
index 5d8ac690..932e023b 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -48,7 +48,7 @@ func Relay(c *gin.Context) {
         logger.Debugf(ctx, "request body: %s", string(requestBody))
     }
     channelId := c.GetInt(ctxkey.ChannelId)
-    userId := c.GetInt("id")
+    userId := c.GetInt(ctxkey.Id)
     bizErr := relayHelper(c, relayMode)
     if bizErr == nil {
         monitor.Emit(channelId, true)
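[PATCH 03] moves the request-body cache key from a package-local string into the shared common/ctxkey package, so every reader and writer of the gin context uses one identifier. A hypothetical illustration of the benefit — cacheBody and cachedBody are invented names, not part of the patch — is that a misspelled constant fails to compile, while a misspelled string key would silently miss the cache:

    func cacheBody(c *gin.Context, body []byte) {
        c.Set(ctxkey.KeyRequestBody, body) // same constant on both sides
    }

    func cachedBody(c *gin.Context) ([]byte, bool) {
        v, ok := c.Get(ctxkey.KeyRequestBody)
        if !ok {
            return nil, false
        }
        b, ok := v.([]byte)
        return b, ok
    }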
From b70a07e814c5907e044f45dac32cb02ab1e51efc Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 30 Jun 2024 16:19:49 +0800
Subject: [PATCH 04/26] fix: fix ci

---
 .github/workflows/ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 698acdf1..30ac5f82 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -50,7 +50,7 @@ jobs:
     steps:
       - uses: codecov/codecov-action@v4
         with:
-          use_oidc: true
+          token: ${{ secrets.CODECOV_TOKEN }}

   commit_lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3

From f25aaf7752a6f1719f445bb3d2d62863774e626b Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Sun, 30 Jun 2024 16:21:48 +0800
Subject: [PATCH 05/26] chore(deps): bump golang.org/x/image from 0.16.0 to 0.18.0 (#1568)

Bumps [golang.org/x/image](https://github.com/golang/image) from 0.16.0 to 0.18.0.
- [Commits](https://github.com/golang/image/compare/v0.16.0...v0.18.0)

---
updated-dependencies:
- dependency-name: golang.org/x/image
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 go.mod | 4 ++--
 go.sum | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/go.mod b/go.mod
index 7a396314..2d0df03f 100644
--- a/go.mod
+++ b/go.mod
@@ -24,7 +24,7 @@ require (
     github.com/smartystreets/goconvey v1.8.1
     github.com/stretchr/testify v1.9.0
     golang.org/x/crypto v0.23.0
-    golang.org/x/image v0.16.0
+    golang.org/x/image v0.18.0
     gorm.io/driver/mysql v1.5.6
     gorm.io/driver/postgres v1.5.7
     gorm.io/driver/sqlite v1.5.5
@@ -80,7 +80,7 @@ require (
     golang.org/x/net v0.25.0 // indirect
     golang.org/x/sync v0.7.0 // indirect
     golang.org/x/sys v0.20.0 // indirect
-    golang.org/x/text v0.15.0 // indirect
+    golang.org/x/text v0.16.0 // indirect
     google.golang.org/protobuf v1.34.1 // indirect
     gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index 4c1aac95..ab04845c 100644
--- a/go.sum
+++ b/go.sum
@@ -154,8 +154,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
 golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
-golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
-golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
+golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
+golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
 golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
 golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
 golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
@@ -164,8 +164,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
 golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
-golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
+golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
 google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

From ae1cd29f943b31d4c12dffecd166b621b1ac2400 Mon Sep 17 00:00:00 2001
From: shaoyun
Date: Sun, 30 Jun 2024 16:25:25 +0800
Subject: [PATCH 06/26] feat: added support for Claude Sonnet 3.5 (#1567)

---
 relay/adaptor/anthropic/constants.go     |  1 +
 relay/adaptor/aws/main.go                | 13 +++++++------
 relay/billing/ratio/model.go             | 13 +++++++------
 web/air/src/pages/Channel/EditChannel.js |  2 +-
 4 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go
index cadcedc8..143d1efc 100644
--- a/relay/adaptor/anthropic/constants.go
+++ b/relay/adaptor/anthropic/constants.go
@@ -5,4 +5,5 @@ var ModelList = []string{
     "claude-3-haiku-20240307",
     "claude-3-sonnet-20240229",
     "claude-3-opus-20240229",
+    "claude-3-5-sonnet-20240620",
 }
diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go
index 0776f985..5d29597c 100644
--- a/relay/adaptor/aws/main.go
+++ b/relay/adaptor/aws/main.go
@@ -33,12 +33,13 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode {

 // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
 var awsModelIDMap = map[string]string{
-    "claude-instant-1.2":       "anthropic.claude-instant-v1",
-    "claude-2.0":               "anthropic.claude-v2",
-    "claude-2.1":               "anthropic.claude-v2:1",
-    "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
-    "claude-3-opus-20240229":   "anthropic.claude-3-opus-20240229-v1:0",
-    "claude-3-haiku-20240307":  "anthropic.claude-3-haiku-20240307-v1:0",
+    "claude-instant-1.2":         "anthropic.claude-instant-v1",
+    "claude-2.0":                 "anthropic.claude-v2",
+    "claude-2.1":                 "anthropic.claude-v2:1",
+    "claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
+    "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+    "claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
+    "claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0",
 }

 func awsModelID(requestModel string) (string, error) {
diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go
index 3b289499..b1a8a5b4 100644
--- a/relay/billing/ratio/model.go
+++ b/relay/billing/ratio/model.go
@@ -70,12 +70,13 @@ var ModelRatio = map[string]float64{
     "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
     "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
     // https://www.anthropic.com/api#pricing
-    "claude-instant-1.2":       0.8 / 1000 * USD,
-    "claude-2.0":               8.0 / 1000 * USD,
-    "claude-2.1":               8.0 / 1000 * USD,
-    "claude-3-haiku-20240307":  0.25 / 1000 * USD,
-    "claude-3-sonnet-20240229": 3.0 / 1000 * USD,
-    "claude-3-opus-20240229":   15.0 / 1000 * USD,
+    "claude-instant-1.2":         0.8 / 1000 * USD,
+    "claude-2.0":                 8.0 / 1000 * USD,
+    "claude-2.1":                 8.0 / 1000 * USD,
+    "claude-3-haiku-20240307":    0.25 / 1000 * USD,
+    "claude-3-sonnet-20240229":   3.0 / 1000 * USD,
+    "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
+    "claude-3-opus-20240229":     15.0 / 1000 * USD,
     // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
     "ERNIE-4.0-8K": 0.120 * RMB,
     "ERNIE-3.5-8K": 0.012 * RMB,
diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js
index efb2cee8..d63fa8fa 100644
--- a/web/air/src/pages/Channel/EditChannel.js
+++ b/web/air/src/pages/Channel/EditChannel.js
@@ -63,7 +63,7 @@ const EditChannel = (props) => {
     let localModels = [];
     switch (value) {
       case 14:
-        localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"];
+        localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
         break;
       case 11:
         localModels = ['PaLM-2'];

From b21b3b5b460502a40217b3973cf6ee5f44c916f9 Mon Sep 17 00:00:00 2001
From: zijiren <84728412+zijiren233@users.noreply.github.com>
Date: Sun, 30 Jun 2024 18:36:33 +0800
Subject: [PATCH 07/26] refactor: abusing goroutines and channel (#1561)

* refactor: abusing goroutines

* fix: trim data prefix

* refactor: move functions to render package

* refactor: add back trim & flush

---------

Co-authored-by: JustSong
---
 common/render/render.go          |  29 +++
 relay/adaptor/aiproxy/main.go    |  97 +++++++++++------
 relay/adaptor/ali/main.go        |  91 ++++++++++-----
 relay/adaptor/anthropic/main.go  | 102 ++++++++------
++++++++++++------------ relay/adaptor/baidu/main.go | 96 ++++++++++------------ relay/adaptor/cloudflare/main.go | 122 +++++++++++++--------------- relay/adaptor/cohere/main.go | 101 +++++++++++------------- relay/adaptor/coze/main.go | 113 ++++++++++++-------------- relay/adaptor/gemini/main.go | 91 +++++++++------------ relay/adaptor/ollama/main.go | 75 +++++++++--------- relay/adaptor/openai/main.go | 131 +++++++++++++------------------ relay/adaptor/palm/palm.go | 93 +++++++++++----------- relay/adaptor/tencent/main.go | 98 ++++++++++------------- relay/adaptor/zhipu/main.go | 103 +++++++++++------------- 14 files changed, 614 insertions(+), 728 deletions(-) create mode 100644 common/render/render.go diff --git a/common/render/render.go b/common/render/render.go new file mode 100644 index 00000000..646b3777 --- /dev/null +++ b/common/render/render.go @@ -0,0 +1,29 @@ +package render + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "strings" +) + +func StringData(c *gin.Context, str string) { + str = strings.TrimPrefix(str, "data: ") + str = strings.TrimSuffix(str, "\r") + c.Render(-1, common.CustomEvent{Data: "data: " + str}) + c.Writer.Flush() +} + +func ObjectData(c *gin.Context, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + StringData(c, string(jsonData)) + return nil +} + +func Done(c *gin.Context) { + StringData(c, "[DONE]") +} diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go index 01a568f6..d64b6809 100644 --- a/relay/adaptor/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -4,6 +4,12 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strconv" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -12,10 +18,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" ) // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 @@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var usage model.Usage + var documents []LibraryDocument scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - var documents []LibraryDocument - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var AIProxyLibraryResponse LibraryStreamResponse - err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if len(AIProxyLibraryResponse.Documents) != 0 { - documents = 
-            }
-            response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            response := documentsAIProxyLibrary(documents)
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < 5 || data[:5] != "data:" {
+            continue
         }
-    })
-    err := resp.Body.Close()
+        data = data[5:]
+
+        var AIProxyLibraryResponse LibraryStreamResponse
+        err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+        if len(AIProxyLibraryResponse.Documents) != 0 {
+            documents = AIProxyLibraryResponse.Documents
+        }
+        response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    response := documentsAIProxyLibrary(documents)
+    err := render.ObjectData(c, response)
+    if err != nil {
+        logger.SysError(err.Error())
+    }
+    render.Done(c)
+
+    err = resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
+
     return nil, &usage
 }
diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go
index 0462c26b..f9039dbe 100644
--- a/relay/adaptor/ali/main.go
+++ b/relay/adaptor/ali/main.go
@@ -3,15 +3,17 @@ package ali
 import (
     "bufio"
     "encoding/json"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "strings"
 )

 // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
         }
         return 0, nil, nil
     })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < 5 { // ignore blank line or wrong format
-                continue
-            }
-            if data[:5] != "data:" {
-                continue
-            }
-            data = data[5:]
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+
     common.SetEventStreamHeaders(c)
-    //lastResponseText := ""
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            var aliResponse ChatResponse
-            err := json.Unmarshal([]byte(data), &aliResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            if aliResponse.Usage.OutputTokens != 0 {
-                usage.PromptTokens = aliResponse.Usage.InputTokens
-                usage.CompletionTokens = aliResponse.Usage.OutputTokens
-                usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
-            }
-            response := streamResponseAli2OpenAI(&aliResponse)
-            if response == nil {
-                return true
-            }
-            //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
-            //lastResponseText = aliResponse.Output.Text
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < 5 || data[:5] != "data:" {
+            continue
         }
-    })
+        data = data[5:]
+
+        var aliResponse ChatResponse
+        err := json.Unmarshal([]byte(data), &aliResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+        if aliResponse.Usage.OutputTokens != 0 {
+            usage.PromptTokens = aliResponse.Usage.InputTokens
+            usage.CompletionTokens = aliResponse.Usage.OutputTokens
+            usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+        }
+        response := streamResponseAli2OpenAI(&aliResponse)
+        if response == nil {
+            continue
+        }
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
     err := resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go
index a8de185c..c817a9d1 100644
--- a/relay/adaptor/anthropic/main.go
+++ b/relay/adaptor/anthropic/main.go
@@ -4,6 +4,7 @@ import (
     "bufio"
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"
@@ -169,64 +170,59 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
         }
         return 0, nil, nil
     })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data :=
 scanner.Text()
-            if len(data) < 6 {
-                continue
-            }
-            if !strings.HasPrefix(data, "data:") {
-                continue
-            }
-            data = strings.TrimPrefix(data, "data:")
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+
     common.SetEventStreamHeaders(c)
+
     var usage model.Usage
     var modelName string
     var id string
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            // some implementations may add \r at the end of data
-            data = strings.TrimSpace(data)
-            var claudeResponse StreamResponse
-            err := json.Unmarshal([]byte(data), &claudeResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
-            if meta != nil {
-                usage.PromptTokens += meta.Usage.InputTokens
-                usage.CompletionTokens += meta.Usage.OutputTokens
-                modelName = meta.Model
-                id = fmt.Sprintf("chatcmpl-%s", meta.Id)
-                return true
-            }
-            if response == nil {
-                return true
-            }
-            response.Id = id
-            response.Model = modelName
-            response.Created = createdTime
-            jsonStr, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < 6 || !strings.HasPrefix(data, "data:") {
+            continue
         }
-    })
-    _ = resp.Body.Close()
+
+        data = strings.TrimPrefix(data, "data:")
+        data = strings.TrimSpace(data)
+
+        var claudeResponse StreamResponse
+        err := json.Unmarshal([]byte(data), &claudeResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+
+        response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
+        if meta != nil {
+            usage.PromptTokens += meta.Usage.InputTokens
+            usage.CompletionTokens += meta.Usage.OutputTokens
+            modelName = meta.Model
+            id = fmt.Sprintf("chatcmpl-%s", meta.Id)
+            continue
+        }
+        if response == nil {
+            continue
+        }
+
+        response.Id = id
+        response.Model = modelName
+        response.Created = createdTime
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
+    err := resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
     return nil, &usage
 }
diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go
index b816e0f4..ebe70c32 100644
--- a/relay/adaptor/baidu/main.go
+++ b/relay/adaptor/baidu/main.go
@@ -5,6 +5,13 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+    "strings"
+    "sync"
+    "time"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/client"
@@ -12,11 +19,6 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "strings"
-    "sync"
-    "time"
 )

 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
     var usage model.Usage
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < 6 { // ignore blank line or wrong format
-                continue
-            }
-            data = data[6:]
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+    scanner.Split(bufio.ScanLines)
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            var baiduResponse ChatStreamResponse
-            err := json.Unmarshal([]byte(data), &baiduResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            if baiduResponse.Usage.TotalTokens != 0 {
-                usage.TotalTokens = baiduResponse.Usage.TotalTokens
-                usage.PromptTokens = baiduResponse.Usage.PromptTokens
-                usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
-            }
-            response := streamResponseBaidu2OpenAI(&baiduResponse)
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < 6 {
+            continue
         }
-    })
+        data = data[6:]
+
+        var baiduResponse ChatStreamResponse
+        err := json.Unmarshal([]byte(data), &baiduResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+        if baiduResponse.Usage.TotalTokens != 0 {
+            usage.TotalTokens = baiduResponse.Usage.TotalTokens
+            usage.PromptTokens = baiduResponse.Usage.PromptTokens
+            usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
+        }
+        response := streamResponseBaidu2OpenAI(&baiduResponse)
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
     err := resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go
index f6d496f7..c76520a2 100644
--- a/relay/adaptor/cloudflare/main.go
+++ b/relay/adaptor/cloudflare/main.go
@@ -2,8 +2,8 @@ package cloudflare

 import (
     "bufio"
-    "bytes"
     "encoding/json"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"
@@ -17,21 +17,20 @@ import (
 )

 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
-   var promptBuilder strings.Builder
-   for _, message := range textRequest.Messages {
-       promptBuilder.WriteString(message.StringContent())
-       promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
-   }
+    var promptBuilder strings.Builder
+    for _, message := range textRequest.Messages {
+        promptBuilder.WriteString(message.StringContent())
+        promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
+    }

-   return &Request{
-       MaxTokens:   textRequest.MaxTokens,
-       Prompt:      promptBuilder.String(),
-       Stream:      textRequest.Stream,
-       Temperature: textRequest.Temperature,
-   }
+    return &Request{
+        MaxTokens:   textRequest.MaxTokens,
+        Prompt:      promptBuilder.String(),
+        Stream:      textRequest.Stream,
+        Temperature: textRequest.Temperature,
+    }
 }

 func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
     choice := openai.TextResponseChoice{
         Index: 0,
@@ -63,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai

 func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := bytes.IndexByte(data, '\n'); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
+    scanner.Split(bufio.ScanLines)

-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < len("data: ") {
-                continue
-            }
-            data = strings.TrimPrefix(data, "data: ")
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
     common.SetEventStreamHeaders(c)
     id := helper.GetResponseID(c)
     responseModel := c.GetString("original_model")
     var responseText string
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            // some implementations
 may add \r at the end of data
-            data = strings.TrimSuffix(data, "\r")
-            var cloudflareResponse StreamResponse
-            err := json.Unmarshal([]byte(data), &cloudflareResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
-            if response == nil {
-                return true
-            }
-            responseText += cloudflareResponse.Response
-            response.Id = id
-            response.Model = responseModel
-            jsonStr, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < len("data: ") {
+            continue
         }
-    })
-    _ = resp.Body.Close()
+        data = strings.TrimPrefix(data, "data: ")
+        data = strings.TrimSuffix(data, "\r")
+
+        var cloudflareResponse StreamResponse
+        err := json.Unmarshal([]byte(data), &cloudflareResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+
+        response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
+        if response == nil {
+            continue
+        }
+
+        responseText += cloudflareResponse.Response
+        response.Id = id
+        response.Model = responseModel
+
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
+    err := resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+
     usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
     return nil, usage
 }
diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go
index 4bc3fa8d..45db437b 100644
--- a/relay/adaptor/cohere/main.go
+++ b/relay/adaptor/cohere/main.go
@@ -2,9 +2,9 @@ package cohere

 import (
     "bufio"
-    "bytes"
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"
@@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
     createdTime := helper.GetTimestamp()
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := bytes.IndexByte(data, '\n'); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
+    scanner.Split(bufio.ScanLines)

-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
     common.SetEventStreamHeaders(c)
     var usage model.Usage
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            // some implementations may add \r at the end of data
-            data = strings.TrimSuffix(data, "\r")
-            var cohereResponse StreamResponse
-            err := json.Unmarshal([]byte(data), &cohereResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
-            if meta != nil {
-                usage.PromptTokens +=
 meta.Meta.Tokens.InputTokens
-                usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
-                return true
-            }
-            if response == nil {
-                return true
-            }
-            response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
-            response.Model = c.GetString("original_model")
-            response.Created = createdTime
-            jsonStr, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        data = strings.TrimSuffix(data, "\r")
+
+        var cohereResponse StreamResponse
+        err := json.Unmarshal([]byte(data), &cohereResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
         }
-    })
-    _ = resp.Body.Close()
+
+        response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
+        if meta != nil {
+            usage.PromptTokens += meta.Meta.Tokens.InputTokens
+            usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
+            continue
+        }
+        if response == nil {
+            continue
+        }
+
+        response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
+        response.Model = c.GetString("original_model")
+        response.Created = createdTime
+
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
+    err := resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+
     return nil, &usage
 }
diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go
index 721c5d13..d0402a76 100644
--- a/relay/adaptor/coze/main.go
+++ b/relay/adaptor/coze/main.go
@@ -4,6 +4,11 @@ import (
     "bufio"
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/conv"
@@ -12,9 +17,6 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "strings"
 )

 // https://www.coze.com/open
@@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
     var responseText string
     createdTime := helper.GetTimestamp()
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < 5 {
-                continue
-            }
-            if !strings.HasPrefix(data, "data:") {
-                continue
-            }
-            data = strings.TrimPrefix(data, "data:")
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+    scanner.Split(bufio.ScanLines)
+
     common.SetEventStreamHeaders(c)
     var modelName string
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            // some implementations may add \r at the end of data
-            data = strings.TrimSuffix(data, "\r")
-            var cozeResponse StreamResponse
-            err := json.Unmarshal([]byte(data), &cozeResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
-            if response == nil {
-                return true
-            }
-            for _, choice := range response.Choices {
-                responseText += conv.AsString(choice.Delta.Content)
-            }
-            response.Model = modelName
-            response.Created = createdTime
-            jsonStr, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+            continue
         }
-    })
-    _ = resp.Body.Close()
+        data = strings.TrimPrefix(data, "data:")
+        data = strings.TrimSuffix(data, "\r")
+
+        var cozeResponse StreamResponse
+        err := json.Unmarshal([]byte(data), &cozeResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+
+        response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
+        if response == nil {
+            continue
+        }
+
+        for _, choice := range response.Choices {
+            responseText += conv.AsString(choice.Delta.Content)
+        }
+        response.Model = modelName
+        response.Created = createdTime
+
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
+    err := resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+
     return nil, &responseText
 }
diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 74a7d5d5..51fd6aa8 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -4,6 +4,7 @@ import (
     "bufio"
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"
@@ -275,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
     responseText := ""
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            data = strings.TrimSpace(data)
-            if !strings.HasPrefix(data, "data: ") {
-                continue
-            }
-            data = strings.TrimPrefix(data, "data: ")
-            data = strings.TrimSuffix(data, "\"")
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+    scanner.Split(bufio.ScanLines)
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            var geminiResponse ChatResponse
-            err := json.Unmarshal([]byte(data), &geminiResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            response := streamResponseGeminiChat2OpenAI(&geminiResponse)
-            if response == nil {
-                return true
-            }
-            responseText += response.Choices[0].Delta.StringContent()
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        data = strings.TrimSpace(data)
+        if !strings.HasPrefix(data, "data: ") {
+            continue
         }
-    })
+        data = strings.TrimPrefix(data, "data: ")
+        data = strings.TrimSuffix(data, "\"")
+
+        var geminiResponse ChatResponse
+        err := json.Unmarshal([]byte(data), &geminiResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+
+        response := streamResponseGeminiChat2OpenAI(&geminiResponse)
+        if response == nil {
+            continue
+        }
+
+        responseText += response.Choices[0].Delta.StringContent()
+
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
     err := resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
+
     return nil, responseText
 }
diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go
index c5fe08e6..936a7e14 100644
--- a/relay/adaptor/ollama/main.go
+++ b/relay/adaptor/ollama/main.go
@@ -5,12 +5,14 @@ import (
     "context"
     "encoding/json"
     "fmt"
-    "github.com/songquanpeng/one-api/common/helper"
-    "github.com/songquanpeng/one-api/common/random"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"

+    "github.com/songquanpeng/one-api/common/helper"
+    "github.com/songquanpeng/one-api/common/random"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/image"
@@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
             return 0, nil, nil
         }
         if i := strings.Index(string(data), "}\n"); i >= 0 {
-            return i + 2, data[0:i], nil
+            return i + 2, data[0 : i+1], nil
         }
         if atEOF {
             return len(data), data, nil
         }
         return 0, nil, nil
     })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := strings.TrimPrefix(scanner.Text(), "}")
-            dataChan <- data + "}"
-        }
-        stopChan <- true
-    }()
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            var ollamaResponse ChatResponse
-            err := json.Unmarshal([]byte(data), &ollamaResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            if ollamaResponse.EvalCount != 0 {
-                usage.PromptTokens = ollamaResponse.PromptEvalCount
-                usage.CompletionTokens = ollamaResponse.EvalCount
-                usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
-            }
-            response := streamResponseOllama2OpenAI(&ollamaResponse)
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := strings.TrimPrefix(scanner.Text(), "}")
+        data = data + "}"
+
+        var ollamaResponse ChatResponse
+        err := json.Unmarshal([]byte(data), &ollamaResponse)
+        if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) + continue } - }) + + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + + response := streamResponseOllama2OpenAI(&ollamaResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, &usage } diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 07cb967f..1d534644 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -25,88 +26,68 @@ const ( func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) + scanner.Split(bufio.ScanLines) var usage *model.Usage - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < dataPrefixLength { // ignore blank line or wrong format - continue - } - if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { - continue - } - if strings.HasPrefix(data[dataPrefixLength:], done) { - dataChan <- data - continue - } - switch relayMode { - case relaymode.ChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - dataChan <- data // if error happened, pass the data to client - continue // just ignore the error - } - if len(streamResponse.Choices) == 0 { - // but for empty choice, we should not pass it to client, this is for azure - continue // just ignore empty choice - } - dataChan <- data - for _, choice := range streamResponse.Choices { - responseText += conv.AsString(choice.Delta.Content) - } - if streamResponse.Usage != nil { - usage = streamResponse.Usage - } - case relaymode.Completions: - dataChan <- data - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } - } - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false + + for 
 scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < dataPrefixLength { // ignore blank line or wrong format
+            continue
         }
-    })
+        if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
+            continue
+        }
+        if strings.HasPrefix(data[dataPrefixLength:], done) {
+            render.StringData(c, data)
+            continue
+        }
+        switch relayMode {
+        case relaymode.ChatCompletions:
+            var streamResponse ChatCompletionsStreamResponse
+            err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
+            if err != nil {
+                logger.SysError("error unmarshalling stream response: " + err.Error())
+                render.StringData(c, data) // if error happened, pass the data to client
+                continue                   // just ignore the error
+            }
+            if len(streamResponse.Choices) == 0 {
+                // but for empty choice, we should not pass it to client, this is for azure
+                continue // just ignore empty choice
+            }
+            render.StringData(c, data)
+            for _, choice := range streamResponse.Choices {
+                responseText += conv.AsString(choice.Delta.Content)
+            }
+            if streamResponse.Usage != nil {
+                usage = streamResponse.Usage
+            }
+        case relaymode.Completions:
+            render.StringData(c, data)
+            var streamResponse CompletionsStreamResponse
+            err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
+            if err != nil {
+                logger.SysError("error unmarshalling stream response: " + err.Error())
+                continue
+            }
+            for _, choice := range streamResponse.Choices {
+                responseText += choice.Text
+            }
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
     err := resp.Body.Close()
     if err != nil {
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
     }
+
     return nil, responseText, usage
 }
diff --git a/relay/adaptor/palm/palm.go b/relay/adaptor/palm/palm.go
index 1e60e7cd..d31784ec 100644
--- a/relay/adaptor/palm/palm.go
+++ b/relay/adaptor/palm/palm.go
@@ -3,6 +3,10 @@ package palm
 import (
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/helper"
@@ -11,8 +15,6 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
 )

 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
     responseText := ""
     responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
     createdTime := helper.GetTimestamp()
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        responseBody, err := io.ReadAll(resp.Body)
-        if err != nil {
-            logger.SysError("error reading stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        err = resp.Body.Close()
-        if err != nil {
-            logger.SysError("error closing stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        var palmResponse ChatResponse
-        err = json.Unmarshal(responseBody, &palmResponse)
-        if err != nil {
-            logger.SysError("error unmarshalling stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
-        fullTextResponse.Id = responseId
-        fullTextResponse.Created = createdTime
-        if len(palmResponse.Candidates) > 0 {
-            responseText = palmResponse.Candidates[0].Content
-        }
-        jsonResponse, err :=
 json.Marshal(fullTextResponse)
-        if err != nil {
-            logger.SysError("error marshalling stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        dataChan <- string(jsonResponse)
-        stopChan <- true
-    }()
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            c.Render(-1, common.CustomEvent{Data: "data: " + data})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        logger.SysError("error reading stream response: " + err.Error())
+        err := resp.Body.Close()
+        if err != nil {
+            return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
         }
-    })
-    err := resp.Body.Close()
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
+    }
+
+    err = resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
+
+    var palmResponse ChatResponse
+    err = json.Unmarshal(responseBody, &palmResponse)
+    if err != nil {
+        logger.SysError("error unmarshalling stream response: " + err.Error())
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
+    }
+
+    fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
+    fullTextResponse.Id = responseId
+    fullTextResponse.Created = createdTime
+    if len(palmResponse.Candidates) > 0 {
+        responseText = palmResponse.Candidates[0].Content
+    }
+
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        logger.SysError("error marshalling stream response: " + err.Error())
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
+    }
+
+    err = render.ObjectData(c, string(jsonResponse))
+    if err != nil {
+        logger.SysError(err.Error())
+    }
+
+    render.Done(c)
+
     return nil, responseText
 }
diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go
index 0a57dcf7..365e33ae 100644
--- a/relay/adaptor/tencent/main.go
+++ b/relay/adaptor/tencent/main.go
@@ -8,6 +8,13 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+    "strconv"
+    "strings"
+    "time"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/conv"
@@ -17,11 +24,6 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "strconv"
-    "strings"
-    "time"
 )

 func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
     var responseText string
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < 5 { // ignore blank line or wrong format
-                continue
-            }
-            if data[:5] != "data:" {
-                continue
-            }
-            data = data[5:]
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+    scanner.Split(bufio.ScanLines)
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            var TencentResponse ChatResponse
-            err := json.Unmarshal([]byte(data), &TencentResponse)
-            if err != nil {
-                logger.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            response := streamResponseTencent2OpenAI(&TencentResponse)
-            if len(response.Choices) != 0 {
-                responseText += conv.AsString(response.Choices[0].Delta.Content)
-            }
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+            continue
         }
-    })
+        data = strings.TrimPrefix(data, "data:")
+
+        var tencentResponse ChatResponse
+        err := json.Unmarshal([]byte(data), &tencentResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+
+        response := streamResponseTencent2OpenAI(&tencentResponse)
+        if len(response.Choices) != 0 {
+            responseText += conv.AsString(response.Choices[0].Delta.Content)
+        }
+
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
     err := resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
+
     return nil, responseText
 }
diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go
index 74a1a05e..ab3a5678 100644
--- a/relay/adaptor/zhipu/main.go
+++ b/relay/adaptor/zhipu/main.go
@@ -3,6 +3,13 @@ package zhipu
 import (
     "bufio"
     "encoding/json"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+    "strings"
+    "sync"
+    "time"
+
     "github.com/gin-gonic/gin"
     "github.com/golang-jwt/jwt"
     "github.com/songquanpeng/one-api/common"
@@ -11,11 +18,6 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "strings"
-    "sync"
-    "time"
 )

 // https://open.bigmodel.cn/doc/api#chatglm_std
@@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
         }
         return 0, nil, nil
     })
-    dataChan := make(chan string)
-    metaChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            lines := strings.Split(data, "\n")
-            for i, line := range lines {
-                if len(line) < 5 {
+
+    common.SetEventStreamHeaders(c)
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        lines := strings.Split(data, "\n")
+        for i, line := range lines {
+            if len(line) < 5 {
+                continue
+            }
+            if strings.HasPrefix(line, "data:") {
+                dataSegment := line[5:]
+                if i != len(lines)-1 {
+                    dataSegment += "\n"
+                }
+                response := streamResponseZhipu2OpenAI(dataSegment)
+                err := render.ObjectData(c, response)
+                if err != nil {
+                    logger.SysError("error marshalling stream response: " + err.Error())
+                }
+            } else if strings.HasPrefix(line, "meta:") {
+                metaSegment := line[5:]
+                var zhipuResponse StreamMetaResponse
+                err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
&zhipuResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } - if line[:5] == "data:" { - dataChan <- line[5:] - if i != len(lines)-1 { - dataChan <- "\n" - } - } else if line[:5] == "meta:" { - metaChan <- line[5:] + response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) } + usage = zhipuUsage } } - stopChan <- true - }() - common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - response := streamResponseZhipu2OpenAI(data) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case data := <-metaChan: - var zhipuResponse StreamMetaResponse - err := json.Unmarshal([]byte(data), &zhipuResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - usage = zhipuUsage - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, usage } From d0369b114f6b9a34a926979b309ac1fd052db698 Mon Sep 17 00:00:00 2001 From: lihangfu <280001404@qq.com> Date: Sun, 30 Jun 2024 19:37:07 +0800 Subject: [PATCH 08/26] feat: support spark4.0 ultra (#1569) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: 支持v3最新协议的腾讯混元(#1452) * feat: 支持Spark4.0 Ultra --------- Co-authored-by: lihangfu --- relay/adaptor/xunfei/constants.go | 1 + relay/adaptor/xunfei/main.go | 2 ++ relay/billing/ratio/model.go | 1 + web/air/src/pages/Channel/EditChannel.js | 2 +- web/berry/src/views/Channel/type/Config.js | 2 +- 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go index 31dcec71..12a56210 100644 --- a/relay/adaptor/xunfei/constants.go +++ b/relay/adaptor/xunfei/constants.go @@ -6,4 +6,5 @@ var ModelList = []string{ "SparkDesk-v2.1", "SparkDesk-v3.1", "SparkDesk-v3.5", + "SparkDesk-v4.0", } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 39b76e27..7cf413a4 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string { return "generalv3" case "v3.5": return "generalv3.5" + case "v4.0": + return "4.0Ultra" } return "general" + apiVersion } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index b1a8a5b4..56d31e13 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -125,6 +125,7 @@ var ModelRatio = map[string]float64{ "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 
1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js index d63fa8fa..73fd2da2 100644 --- a/web/air/src/pages/Channel/EditChannel.js +++ b/web/air/src/pages/Channel/EditChannel.js @@ -78,7 +78,7 @@ const EditChannel = (props) => { localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; break; case 18: - localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']; + localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; break; case 19: localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 88e1ea92..51b7c6c4 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -91,7 +91,7 @@ const typeConfig = { other: '版本号' }, input: { - models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'] + models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'] }, prompt: { key: '按照如下格式输入:APPID|APISecret|APIKey', From c135d74f136813ff26731c2e78bcf2fc3dc3daed Mon Sep 17 00:00:00 2001 From: Shi Jilin <40982122+shijilin0116@users.noreply.github.com> Date: Sun, 30 Jun 2024 19:38:02 +0800 Subject: [PATCH 09/26] feat: support Spark4.0 Ultra (#1575) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fix SparkDesk Function Call (修复 Spark Pro/Max函数调用只会返回普通对话回答而不是Function Call回答的问题 * feat: support Spark4.0 Ultra --- relay/adaptor/xunfei/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 7cf413a4..ef6120e5 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages - if strings.HasPrefix(domain, "generalv3") { + if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" { functions := make([]model.Function, len(request.Tools)) for i, tool := range request.Tools { functions[i] = tool.Function From fecaece71b700b43ba11c161a3f8a971af204971 Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 30 Jun 2024 19:52:33 +0800 Subject: [PATCH 10/26] fix: fix size not support during image generation (#1564) Fixes #1224, #1068 --- relay/controller/helper.go | 72 -------------------------------- relay/controller/image.go | 84 +++++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 78 deletions(-) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486..c47cb558 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { - imageRequest := &relaymodel.ImageRequest{} - err := 
common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - return imageRequest, nil -} - -func isValidImageSize(model string, size string) bool { - if model == "cogview-3" { - return true - } - _, ok := billingratio.ImageSizeRatios[model][size] - return ok -} - -func getImageSizeRatio(model string, size string) float64 { - ratio, ok := billingratio.ImageSizeRatios[model][size] - if !ok { - return 1 - } - return ratio -} - -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { - // model validation - hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) - if !hasValidSize { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - // check prompt length - if imageRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - // Number of generated images validation - if !isWithinRange(imageRequest.Model, imageRequest.N) { - // channel not azure - if meta.ChannelType != channeltype.Azure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } - } - return nil -} - -func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { - if imageRequest == nil { - return 0, errors.New("imageRequest is nil") - } - imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) - if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { - if imageRequest.Size == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - return imageCostRatio, nil -} - func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: diff --git a/relay/controller/image.go b/relay/controller/image.go index 691c7c0e..e6245226 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -20,13 +21,84 @@ import ( "net/http" ) -func isWithinRange(element string, value int) bool { - if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { - return false +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err } - min := billingratio.ImageGenerationAmounts[element][0] - max := billingratio.ImageGenerationAmounts[element][1] - return value >= min && value <= max + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" || 
billingratio.ImageSizeRatios[model] == nil { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func isValidImagePromptLength(model string, promptLength int) bool { + maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] + return !ok || promptLength <= maxPromptLength +} + +func isWithinRange(element string, value int) bool { + amounts, ok := billingratio.ImageGenerationAmounts[element] + return !ok || (value >= amounts[0] && value <= amounts[1]) +} + +func getImageSizeRatio(model string, size string) float64 { + if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { + return ratio + } + return 1 +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + + // model validation + if !isValidImageSize(imageRequest.Model, imageRequest.Size) { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + + if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + return nil +} + +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { From d936817de9866b3e8b6bf1a0f741a2a4eb6c3bd4 Mon Sep 17 00:00:00 2001 From: Darkside Date: Sun, 30 Jun 2024 19:57:30 +0800 Subject: [PATCH 11/26] docs: add related projects (#1562) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 成达 --- README.en.md | 10 ++++++---- README.md | 12 +++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/README.en.md b/README.en.md index bce47353..db96a858 100644 --- a/README.en.md +++ b/README.en.md @@ -101,7 +101,7 @@ Nginx reference configuration: ``` server{ server_name openai.justsong.cn; # Modify your domain name accordingly - + location / { client_max_body_size 64m; proxy_http_version 1.1; @@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`. 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: ```shell git clone https://github.com/songquanpeng/one-api.git - + # Build the frontend cd one-api/web/default npm install npm run build - + # Build the backend cd ../.. go mod download @@ -287,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the + Double-check that your interface address and API Key are correct. 
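One convention shared by the reworked image helpers above is worth noting: a model missing from the billing tables now passes validation (isValidImageSize, isValidImagePromptLength and isWithinRange all succeed on a table miss), which is what stops valid requests for unlisted models from being rejected (#1224, #1068). A compact sketch of that semantics, with a stand-in table in place of the real billingratio maps:

```go
package main

import "fmt"

// stand-in for billingratio.ImageGenerationAmounts: {min, max} images per request
var imageGenerationAmounts = map[string][2]int{
	"dall-e-2": {1, 10},
	"dall-e-3": {1, 1},
}

func isWithinRange(model string, n int) bool {
	amounts, ok := imageGenerationAmounts[model]
	// unknown model: nothing to enforce, so the request is allowed through
	return !ok || (n >= amounts[0] && n <= amounts[1])
}

func main() {
	fmt.Println(isWithinRange("dall-e-3", 2))       // false: exceeds the table's limit
	fmt.Println(isWithinRange("some-new-model", 8)) // true: model not in the table
}
```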
## Related Projects -[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM +* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM +* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller. +* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization. ## Note This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. diff --git a/README.md b/README.md index 8f59a14a..b5168264 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 > [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 -> +> > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > [!WARNING] @@ -144,7 +144,7 @@ Nginx 的参考配置: ``` server{ server_name openai.justsong.cn; # 请根据实际情况修改你的域名 - + location / { client_max_body_size 64m; proxy_http_version 1.1; @@ -189,12 +189,12 @@ docker-compose ps 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: ```shell git clone https://github.com/songquanpeng/one-api.git - + # 构建前端 cd one-api/web/default npm install npm run build - + # 构建后端 cd ../.. go mod download @@ -321,7 +321,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo 例如对于 OpenAI 的官方库: ```bash OPENAI_API_KEY="sk-xxxxxx" -OPENAI_API_BASE="https://:/v1" +OPENAI_API_BASE="https://:/v1" ``` ```mermaid @@ -448,6 +448,8 @@ https://openai.justsong.cn ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 +* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 +* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。 ## 注意 From 1ce1e529ee547989cdbcc1ab04163fbce608d887 Mon Sep 17 00:00:00 2001 From: Leo Q Date: Tue, 2 Jul 2024 00:05:47 +0800 Subject: [PATCH 12/26] ci: skip archive, upload directly (#1586) --- .github/workflows/ci.yml | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30ac5f82..36798711 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,21 +36,9 @@ jobs: # in the next step as well as the next job. - name: Test run: go test -cover -coverprofile=coverage.txt ./... 
- - - name: Archive code coverage results - uses: actions/upload-artifact@v4 + - uses: codecov/codecov-action@v4 with: - name: code-coverage - path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step - - code_coverage: - name: "Code coverage report" - runs-on: ubuntu-latest - needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job - steps: - - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} + token: ${{ secrets.CODECOV_TOKEN }} commit_lint: runs-on: ubuntu-latest From 0fc07ea55897a9d74380da2767b9bfa25e71cbd3 Mon Sep 17 00:00:00 2001 From: Mikey Date: Tue, 2 Jul 2024 00:12:01 +0800 Subject: [PATCH 13/26] feat: add support for Claude 3 tool use (function calling) (#1587) * feat: add tool support for AWS & Claude * fix: add {} for openai compatibility in streaming tool_use --- relay/adaptor/anthropic/main.go | 121 +++++++++++++++++++++++++++++-- relay/adaptor/anthropic/model.go | 21 ++++++ relay/adaptor/aws/main.go | 24 +++++- relay/adaptor/aws/model.go | 3 + relay/model/message.go | 9 ++- relay/model/tool.go | 4 +- 6 files changed, 168 insertions(+), 14 deletions(-) diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index c817a9d1..d3e306c8 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -29,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string { return "stop" case "max_tokens": return "length" + case "tool_use": + return "tool_calls" default: return *reason } } func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + claudeTools := make([]Tool, 0, len(textRequest.Tools)) + + for _, tool := range textRequest.Tools { + if params, ok := tool.Function.Parameters.(map[string]any); ok { + claudeTools = append(claudeTools, Tool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: InputSchema{ + Type: params["type"].(string), + Properties: params["properties"], + Required: params["required"], + }, + }) + } + } + claudeRequest := Request{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -42,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { TopP: textRequest.TopP, TopK: textRequest.TopK, Stream: textRequest.Stream, + Tools: claudeTools, + } + if len(claudeTools) > 0 { + claudeToolChoice := struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + }{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output + if choice, ok := textRequest.ToolChoice.(map[string]any); ok { + if function, ok := choice["function"].(map[string]any); ok { + claudeToolChoice.Type = "tool" + claudeToolChoice.Name = function["name"].(string) + } + } else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok { + if toolChoiceType == "any" { + claudeToolChoice.Type = toolChoiceType + } + } + claudeRequest.ToolChoice = claudeToolChoice } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 @@ -64,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { if message.IsStringContent() { content.Type = "text" content.Text = message.StringContent() + if message.Role == "tool" { + claudeMessage.Role = "user" + content.Type = "tool_result" + content.Content = content.Text + content.Text = "" + content.ToolUseId = message.ToolCallId + } claudeMessage.Content = append(claudeMessage.Content, content) + for i := range message.ToolCalls { + 
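// Shape mapping this loop implements (sketch, not code from the patch):
//   OpenAI tool call: {"id":"call_1","type":"function",
//                      "function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}
//   Claude content:   {"type":"tool_use","id":"call_1","name":"get_weather",
//                      "input":{"city":"SF"}}
// The JSON-string `arguments` is decoded into a structured `input` object, and
// the call id is preserved so a later tool_result block can reference it via
// tool_use_id.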
inputParam := make(map[string]any) + _ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam) + claudeMessage.Content = append(claudeMessage.Content, Content{ + Type: "tool_use", + Id: message.ToolCalls[i].Id, + Name: message.ToolCalls[i].Function.Name, + Input: inputParam, + }) + } claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) continue } @@ -97,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo var response *Response var responseText string var stopReason string + tools := make([]model.Tool, 0) + switch claudeResponse.Type { case "message_start": return nil, claudeResponse.Message case "content_block_start": if claudeResponse.ContentBlock != nil { responseText = claudeResponse.ContentBlock.Text + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, model.Tool{ + Id: claudeResponse.ContentBlock.Id, + Type: "function", + Function: model.Function{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } } case "content_block_delta": if claudeResponse.Delta != nil { responseText = claudeResponse.Delta.Text + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, model.Tool{ + Function: model.Function{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } } case "message_delta": if claudeResponse.Usage != nil { @@ -120,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = responseText + if len(tools) > 0 { + choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... + choice.Delta.ToolCalls = tools + } choice.Delta.Role = "assistant" finishReason := stopReasonClaude2OpenAI(&stopReason) if finishReason != "null" { @@ -136,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { if len(claudeResponse.Content) > 0 { responseText = claudeResponse.Content[0].Text } + tools := make([]model.Tool, 0) + for _, v := range claudeResponse.Content { + if v.Type == "tool_use" { + args, _ := json.Marshal(v.Input) + tools = append(tools, model.Tool{ + Id: v.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: model.Function{ + Name: v.Name, + Arguments: string(args), + }, + }) + } + } choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: responseText, - Name: nil, + Role: "assistant", + Content: responseText, + Name: nil, + ToolCalls: tools, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } @@ -176,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC var usage model.Usage var modelName string var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice for scanner.Scan() { data := scanner.Text() @@ -196,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - modelName = meta.Model - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - continue + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. 
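// message_start is the only meta event carrying an id and model name; the
// id-less meta event at the end of the stream is the finish_reason carrier,
// handled in the else branch below (including the "{}" patch-up for tool
// calls that finished with empty arguments).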
+ modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + continue + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } } if response == nil { continue @@ -207,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC response.Id = id response.Model = modelName response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 32b187cd..47f76629 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -16,6 +16,12 @@ type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` Source *ImageSource `json:"source,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` } type Message struct { @@ -23,6 +29,18 @@ type Message struct { Content []Content `json:"content"` } +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + type Request struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -33,6 +51,8 @@ type Request struct { Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` //Metadata `json:"metadata,omitempty"` } @@ -61,6 +81,7 @@ type Response struct { type Delta struct { Type string `json:"type"` Text string `json:"text"` + PartialJson string `json:"partial_json,omitempty"` StopReason *string `json:"stop_reason"` StopSequence *string `json:"stop_sequence"` } diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 5d29597c..72f40ddc 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "io" "net/http" @@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E c.Writer.Header().Set("Content-Type", "text/event-stream") var usage relaymodel.Usage var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { @@ -163,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - return true + if 
len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } } if response == nil { return true @@ -172,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E response.Id = id response.Model = c.GetString(ctxkey.OriginalModel) response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } jsonStr, err := json.Marshal(response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error()) diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/model.go index bcbfb584..6d00b688 100644 --- a/relay/adaptor/aws/model.go +++ b/relay/adaptor/aws/model.go @@ -9,9 +9,12 @@ type Request struct { // AnthropicVersion should be "bedrock-2023-05-31" AnthropicVersion string `json:"anthropic_version"` Messages []anthropic.Message `json:"messages"` + System string `json:"system,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` + Tools []anthropic.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } diff --git a/relay/model/message.go b/relay/model/message.go index 32a1055b..b908f989 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,10 +1,11 @@ package model type Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []Tool `json:"tool_calls,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go index 253dca35..75dbb8f7 100644 --- a/relay/model/tool.go +++ b/relay/model/tool.go @@ -2,13 +2,13 @@ package model type Tool struct { Id string `json:"id,omitempty"` - Type string `json:"type"` + Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty Function Function `json:"function"` } type Function struct { Description string `json:"description,omitempty"` - Name string `json:"name"` + Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty Parameters any `json:"parameters,omitempty"` // request Arguments any `json:"arguments,omitempty"` // response } From 274fcf3d76299e1e56a670a4c203e130d3561a0e Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Wed, 3 Jul 2024 20:50:40 +0800 Subject: [PATCH 14/26] refactor: init db (#1590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 江杭辉 --- main.go | 22 ++--- model/main.go | 
219 ++++++++++++++++++++++++++++++++------------------ 2 files changed, 150 insertions(+), 91 deletions(-) diff --git a/main.go b/main.go index 4afbe5dd..67a3cd95 100644 --- a/main.go +++ b/main.go @@ -27,27 +27,19 @@ func main() { common.Init() logger.SetupLogger() logger.SysLogf("One API %s started", common.Version) - if os.Getenv("GIN_MODE") != "debug" { + + if os.Getenv("GIN_MODE") != gin.DebugMode { gin.SetMode(gin.ReleaseMode) } if config.DebugEnabled { logger.SysLog("running in debug mode") } - var err error + // Initialize SQL Database - model.DB, err = model.InitDB("SQL_DSN") - if err != nil { - logger.FatalLog("failed to initialize database: " + err.Error()) - } - if os.Getenv("LOG_SQL_DSN") != "" { - logger.SysLog("using secondary database for table logs") - model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") - if err != nil { - logger.FatalLog("failed to initialize secondary database: " + err.Error()) - } - } else { - model.LOG_DB = model.DB - } + model.InitDB() + model.InitLogDB() + + var err error err = model.CreateRootAccountIfNeed() if err != nil { logger.FatalLog("database init error: " + err.Error()) diff --git a/model/main.go b/model/main.go index 4b5323c4..11752404 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "database/sql" "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -60,90 +61,156 @@ func CreateRootAccountIfNeed() error { } func chooseDB(envName string) (*gorm.DB, error) { - if os.Getenv(envName) != "" { - dsn := os.Getenv(envName) - if strings.HasPrefix(dsn, "postgres://") { - // Use PostgreSQL - logger.SysLog("using PostgreSQL as database") - common.UsingPostgreSQL = true - return gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - PreferSimpleProtocol: true, // disables implicit prepared statement usage - }), &gorm.Config{ - PrepareStmt: true, // precompile SQL - }) - } + dsn := os.Getenv(envName) + + switch { + case strings.HasPrefix(dsn, "postgres://"): + // Use PostgreSQL + return openPostgreSQL(dsn) + case dsn != "": // Use MySQL - logger.SysLog("using MySQL as database") - common.UsingMySQL = true - return gorm.Open(mysql.Open(dsn), &gorm.Config{ - PrepareStmt: true, // precompile SQL - }) + return openMySQL(dsn) + default: + // Use SQLite + return openSQLite() } - // Use SQLite - logger.SysLog("SQL_DSN not set, using SQLite as database") - common.UsingSQLite = true - config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) - return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ +} + +func openPostgreSQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using PostgreSQL as database") + common.UsingPostgreSQL = true + return gorm.Open(postgres.New(postgres.Config{ + DSN: dsn, + PreferSimpleProtocol: true, // disables implicit prepared statement usage + }), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } -func InitDB(envName string) (db *gorm.DB, err error) { - db, err = chooseDB(envName) - if err == nil { - if config.DebugSQLEnabled { - db = db.Debug() - } - sqlDB, err := db.DB() - if err != nil { - return nil, err - } - sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) +func openMySQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using MySQL as database") + common.UsingMySQL = true + return gorm.Open(mysql.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) 
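// chooseDB (above) now dispatches purely on the DSN: a postgres:// prefix
// selects openPostgreSQL, any other non-empty DSN is treated as MySQL, and an
// empty DSN falls back to openSQLite; the behaviour matches the old nested
// version.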
+} - if !config.IsMasterNode { - return db, err - } - if common.UsingMySQL { - _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded - } - logger.SysLog("database migration started") - err = db.AutoMigrate(&Channel{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Token{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&User{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Option{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Redemption{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Ability{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Log{}) - if err != nil { - return nil, err - } - logger.SysLog("database migrated") - return db, err - } else { - logger.FatalLog(err) +func openSQLite() (*gorm.DB, error) { + logger.SysLog("SQL_DSN not set, using SQLite as database") + common.UsingSQLite = true + dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func InitDB() { + var err error + DB, err = chooseDB("SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize database: " + err.Error()) + return } - return db, err + + sqlDB := setDBConns(DB) + + if !config.IsMasterNode { + return + } + + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + } + + logger.SysLog("database migration started") + if err = migrateDB(); err != nil { + logger.FatalLog("failed to migrate database: " + err.Error()) + return + } + logger.SysLog("database migrated") +} + +func migrateDB() error { + var err error + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Token{}); err != nil { + return err + } + if err = DB.AutoMigrate(&User{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Option{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Redemption{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Ability{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Log{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + return nil +} + +func InitLogDB() { + if os.Getenv("LOG_SQL_DSN") == "" { + LOG_DB = DB + return + } + + logger.SysLog("using secondary database for table logs") + var err error + LOG_DB, err = chooseDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + return + } + + setDBConns(LOG_DB) + + if !config.IsMasterNode { + return + } + + logger.SysLog("secondary database migration started") + err = migrateLOGDB() + if err != nil { + logger.FatalLog("failed to migrate secondary database: " + err.Error()) + return + } + logger.SysLog("secondary database migrated") +} + +func migrateLOGDB() error { + var err error + if err = LOG_DB.AutoMigrate(&Log{}); err != nil { + return err + } + return nil +} + +func setDBConns(db *gorm.DB) *sql.DB { + if config.DebugSQLEnabled { + db = db.Debug() + } + + sqlDB, err := db.DB() + if err != nil { + logger.FatalLog("failed to connect database: " + err.Error()) + return nil + } + + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) 
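// Pool sizing is unchanged by the refactor: SQL_MAX_IDLE_CONNS (default 100),
// SQL_MAX_OPEN_CONNS (default 1000) and SQL_MAX_LIFETIME in seconds (default
// 60) are still read from the environment on every call.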
+ return sqlDB } func closeDB(db *gorm.DB) error { From c4fe57c16512372b84f0765c78d3e0b2d1eef912 Mon Sep 17 00:00:00 2001 From: LinZeliang Date: Wed, 3 Jul 2024 20:53:29 +0800 Subject: [PATCH 15/26] feat: support one or more log file (#1400) Co-authored-by: Laisky.Cai --- common/config/config.go | 3 +++ common/logger/logger.go | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/common/config/config.go b/common/config/config.go index 4f1c25b6..3f321c87 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -145,6 +145,9 @@ var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") var GeminiVersion = env.String("GEMINI_VERSION", "v1") + +var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) + var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) diff --git a/common/logger/logger.go b/common/logger/logger.go index f725c619..d1022932 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -27,7 +27,12 @@ var setupLogOnce sync.Once func SetupLogger() { setupLogOnce.Do(func() { if LogDir != "" { - logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + var logPath string + if config.OnlyOneLogFile { + logPath = filepath.Join(LogDir, "oneapi.log") + } else { + logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + } fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") From ec6ad248104045d7b67effc72867d9f4a31e55fe Mon Sep 17 00:00:00 2001 From: Leo Q Date: Wed, 3 Jul 2024 22:23:49 +0800 Subject: [PATCH 16/26] feat: support smtp without auth (#1101) --- common/message/email.go | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/common/message/email.go b/common/message/email.go index b06782db..187ac8c3 100644 --- a/common/message/email.go +++ b/common/message/email.go @@ -6,11 +6,16 @@ import ( "encoding/base64" "fmt" "github.com/songquanpeng/one-api/common/config" + "net" "net/smtp" "strings" "time" ) +func shouldAuth() bool { + return config.SMTPAccount != "" || config.SMTPToken != "" +} + func SendEmail(subject string, receiver string, content string) error { if receiver == "" { return fmt.Errorf("receiver is empty") @@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error { "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) to := strings.Split(receiver, ";") - if config.SMTPPort == 465 { - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - ServerName: config.SMTPServer, + if config.SMTPPort == 465 || !shouldAuth() { + // need advanced client + var conn net.Conn + var err error + if config.SMTPPort == 465 { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: config.SMTPServer, + } + conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) + } else { + conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), 
tlsConfig) if err != nil { return err } @@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error { return err } defer client.Close() - if err = client.Auth(auth); err != nil { - return err + if shouldAuth() { + if err = client.Auth(auth); err != nil { + return err + } } if err = client.Mail(config.SMTPFrom); err != nil { return err From 273be557975b758c4e6ee36165daeab772895b58 Mon Sep 17 00:00:00 2001 From: Leo Q Date: Thu, 4 Jul 2024 08:35:41 +0800 Subject: [PATCH 17/26] feat(ui): show available models for air theme (#1595) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(ui): air 主题显示可用模型 * chore: 改为全角括号 --- web/air/src/components/PersonalSetting.js | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js index 45a5b776..ef4acf14 100644 --- a/web/air/src/components/PersonalSetting.js +++ b/web/air/src/components/PersonalSetting.js @@ -47,7 +47,7 @@ const PersonalSetting = () => { const [countdown, setCountdown] = useState(30); const [affLink, setAffLink] = useState(''); const [systemToken, setSystemToken] = useState(''); - // const [models, setModels] = useState([]); + const [models, setModels] = useState([]); const [openTransfer, setOpenTransfer] = useState(false); const [transferAmount, setTransferAmount] = useState(0); @@ -72,7 +72,7 @@ const PersonalSetting = () => { console.log(userState); } ); - // loadModels().then(); + loadModels().then(); getAffLink().then(); setTransferAmount(getQuotaPerUnit()); }, []); @@ -127,16 +127,16 @@ const PersonalSetting = () => { } }; - // const loadModels = async () => { - // let res = await API.get(`/api/user/models`); - // const { success, message, data } = res.data; - // if (success) { - // setModels(data); - // console.log(data); - // } else { - // showError(message); - // } - // }; + const loadModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + setModels(data); + console.log(data); + } else { + showError(message); + } + }; const handleAffLinkClick = async (e) => { e.target.select(); @@ -344,7 +344,7 @@ const PersonalSetting = () => { } > 调用信息 - {/* 可用模型 +

可用模型（可点击复制）

                    {models.map((model) => (
@@ -355,7 +355,7 @@ const PersonalSetting = () => {
                    ))} -
*/} + {/* Date: Fri, 5 Jul 2024 18:05:16 +0800 Subject: [PATCH 18/26] feat: support test specific model (#1600) --- controller/channel-test.go | 36 ++++++----- web/default/src/components/ChannelsTable.js | 70 +++++++++++++++++---- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index b8c41819..f8327284 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -27,15 +28,15 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - - "github.com/gin-gonic/gin" ) -func buildTestRequest() *relaymodel.GeneralOpenAIRequest { +func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { + if model == "" { + model = "gpt-3.5-turbo" + } testRequest := &relaymodel.GeneralOpenAIRequest{ MaxTokens: 2, - Stream: false, - Model: "gpt-3.5-turbo", + Model: model, } testMessage := relaymodel.Message{ Role: "user", @@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest { return testRequest } -func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { +func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ @@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } adaptor.Init(meta) - var modelName string - modelList := adaptor.GetModelList() + modelName := request.Model modelMap := channel.GetModelMapping() - if len(modelList) != 0 { - modelName = modelList[0] - } if modelName == "" || !strings.Contains(channel.Models, modelName) { modelNames := strings.Split(channel.Models, ",") if len(modelNames) > 0 { @@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error modelName = modelMap[modelName] } } - request := buildTestRequest() + meta.OriginModelName, meta.ActualModelName = request.Model, modelName request.Model = modelName - meta.OriginModelName, meta.ActualModelName = modelName, modelName convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil @@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) { }) return } + model := c.Query("model") + testRequest := buildTestRequest(model) tik := time.Now() - err, _ = testChannel(channel) + err, _ = testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() + if err != nil { + milliseconds = 0 + } go channel.UpdateResponseTime(milliseconds) consumedTime := float64(milliseconds) / 1000.0 if err != nil { @@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) { "success": false, "message": err.Error(), "time": consumedTime, + "model": model, }) return } @@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) { "success": true, "message": "", "time": consumedTime, + "model": model, }) return } @@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error { for _, channel := range channels { isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() - err, openaiErr := 
testChannel(channel) + testRequest := buildTestRequest("") + err, openaiErr := testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) if config.AutomaticDisableChannelEnabled { monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } else { diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js index 1258ca5a..6025b7d9 100644 --- a/web/default/src/components/ChannelsTable.js +++ b/web/default/src/components/ChannelsTable.js @@ -1,5 +1,5 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; +import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; import { API, @@ -70,13 +70,33 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/?p=${startIdx}`); const { success, message, data } = res.data; if (success) { - if (startIdx === 0) { - setChannels(data); - } else { - let newChannels = [...channels]; - newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); - setChannels(newChannels); - } + let localChannels = data.map((channel) => { + if (channel.models === '') { + channel.models = []; + channel.test_model = ""; + } else { + channel.models = channel.models.split(','); + if (channel.models.length > 0) { + channel.test_model = channel.models[0]; + } + channel.model_options = channel.models.map((model) => { + return { + key: model, + text: model, + value: model, + } + }) + console.log('channel', channel) + } + return channel; + }); + if (startIdx === 0) { + setChannels(localChannels); + } else { + let newChannels = [...channels]; + newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels); + setChannels(newChannels); + } } else { showError(message); } @@ -225,19 +245,31 @@ const ChannelsTable = () => { setSearching(false); }; - const testChannel = async (id, name, idx) => { - const res = await API.get(`/api/channel/test/${id}/`); - const { success, message, time } = res.data; + const switchTestModel = async (idx, model) => { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].test_model = model; + setChannels(newChannels); + }; + + const testChannel = async (id, name, idx, m) => { + const res = await API.get(`/api/channel/test/${id}?model=${m}`); + const { success, message, time, model } = res.data; if (success) { let newChannels = [...channels]; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].test_time = Date.now() / 1000; setChannels(newChannels); - showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); } + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].response_time = time * 1000; + newChannels[realIdx].test_time = Date.now() / 1000; + setChannels(newChannels); }; const testChannels = async (scope) => { @@ -405,6 +437,7 @@ const ChannelsTable = () => { > 优先级 + 测试模型 操作 @@ -459,13 
+492,24 @@ const ChannelsTable = () => { basic /> + + { + switchTestModel(idx, data.value); + }} + /> +
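With both halves of this patch in place, a channel can be probed with any model it serves instead of the first entry in the adaptor's model list. A small sketch of the request the new dropdown drives (assumed base path; the fallback mirrors buildTestRequest above):

```go
package main

import "fmt"

// testChannelPath mirrors the URL the updated ChannelsTable issues; an empty
// model makes the server-side buildTestRequest fall back to "gpt-3.5-turbo".
func testChannelPath(channelId int, model string) string {
	path := fmt.Sprintf("/api/channel/test/%d", channelId)
	if model != "" {
		path += "?model=" + model
	}
	return path
}

func main() {
	fmt.Println(testChannelPath(42, ""))      // /api/channel/test/42
	fmt.Println(testChannelPath(42, "glm-4")) // /api/channel/test/42?model=glm-4
}
```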