From 436a81f1f6a543bb8c1fc2d2b499c6c1ed72541f Mon Sep 17 00:00:00 2001 From: wood Date: Mon, 11 Dec 2023 01:01:34 +0800 Subject: [PATCH] 1 --- common/image/image.go | 47 ++++++++ common/image/image_test.go | 154 +++++++++++++++++++++++++++ controller/relay-audio.go | 85 +++++++++++++-- controller/relay-text.go | 3 + controller/relay-utils.go | 105 +++++++++++++++++- controller/relay.go | 23 +++- go.mod | 14 ++- go.sum | 10 +- middleware/distributor.go | 2 + middleware/recover.go | 26 +++++ router/relay-router.go | 2 +- web/src/pages/Channel/EditChannel.js | 14 +++ 12 files changed, 459 insertions(+), 26 deletions(-) create mode 100644 common/image/image.go create mode 100644 common/image/image_test.go create mode 100644 middleware/recover.go diff --git a/common/image/image.go b/common/image/image.go new file mode 100644 index 00000000..cbb656ad --- /dev/null +++ b/common/image/image.go @@ -0,0 +1,47 @@ +package image + +import ( + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "net/http" + "regexp" + "strings" + + _ "golang.org/x/image/webp" +) + +func GetImageSizeFromUrl(url string) (width int, height int, err error) { + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +var ( + reg = regexp.MustCompile(`data:image/([^;]+);base64,`) +) + +func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { + encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") + base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) + img, _, err := image.DecodeConfig(base64) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +func GetImageSize(image string) (width int, height int, err error) { + if strings.HasPrefix(image, "data:image/") { + return GetImageSizeFromBase64(image) + } + return GetImageSizeFromUrl(image) +} diff --git a/common/image/image_test.go b/common/image/image_test.go new file mode 100644 index 00000000..366eda6e --- /dev/null +++ b/common/image/image_test.go @@ -0,0 +1,154 @@ +package image_test + +import ( + "encoding/base64" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strconv" + "strings" + "testing" + + img "one-api/common/image" + + "github.com/stretchr/testify/assert" + _ "golang.org/x/image/webp" +) + +type CountingReader struct { + reader io.Reader + BytesRead int +} + +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.BytesRead += n + return n, err +} + +var ( + cases = []struct { + url string + format string + width int + height int + }{ + {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, + {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, + {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, + {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, + {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, + } +) + +func TestDecode(t *testing.T) { + // Bytes read: varies sometimes + // jpeg: 1063892 + // png: 294462 + // webp: 99529 + // gif: 956153 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 4096 + // png: 4096 + // webp: 4096 + // gif: 4096 + // jpeg#01: 4096 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestBase64(t *testing.T) { + // Bytes read: + // jpeg: 1063892 + // png: 294462 + // webp: 99072 + // gif: 953856 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 1536 + // png: 768 + // webp: 768 + // gif: 1536 + // jpeg#01: 3840 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestGetImageSize(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + width, height, err := img.GetImageSize(c.url) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 9e78dadc..2247f4c7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -1,6 +1,7 @@ package controller import ( + "bufio" "bytes" "context" "encoding/json" @@ -102,7 +103,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } - requestBody := c.Request.Body + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -144,12 +151,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + var openAIErr TextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + } } - quota = countTokenText(whisperResponse.Text, audioModel) + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + default: + return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + } + if err != nil { + return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = countTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -187,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } return nil } + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse WhisperJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 97f43604..24b5e9a7 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -360,6 +360,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } case APITypeTencent: req.Header.Set("Authorization", apiKey) case APITypePaLM: diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 3031b084..1d074b33 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,10 +3,13 @@ package controller import ( "context" "encoding/json" + "errors" "fmt" "io" + "math" "net/http" "one-api/common" + "one-api/common/image" "one-api/model" "strconv" "strings" @@ -87,7 +90,33 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.StringContent()) + switch v := message.Content.(type) { + case string: + tokenNum += getTokenNum(tokenEncoder, v) + case []any: + for _, it := range v { + m := it.(map[string]any) + switch m["type"] { + case "text": + tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) + case "image_url": + imageUrl, ok := m["image_url"].(map[string]any) + if ok { + url := imageUrl["url"].(string) + detail := "" + if imageUrl["detail"] != nil { + detail = imageUrl["detail"].(string) + } + imageTokens, err := countImageTokens(url, detail) + if err != nil { + common.SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + } + } + } + } tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -98,13 +127,81 @@ func countTokenMessages(messages []Message, model string) int { return tokenNum } +const ( + lowDetailCost = 85 + highDetailCostPerTile = 170 + additionalCost = 85 +) + +// https://platform.openai.com/docs/guides/vision/calculating-costs +// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb +func countImageTokens(url string, detail string) (_ int, err error) { + var fetchSize = true + var width, height int + // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding + // detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. + // According to the official guide, "low" disable the high-res model, + // and only receive low-res 512px x 512px version of the image, indicating + // that image is treated as low-res when size is smaller than 512px x 512px, + // then we can assume that image size larger than 512px x 512px is treated + // as high-res. Then we have the following logic: + // if detail == "" || detail == "auto" { + // width, height, err = image.GetImageSize(url) + // if err != nil { + // return 0, err + // } + // fetchSize = false + // // not sure if this is correct + // if width > 512 || height > 512 { + // detail = "high" + // } else { + // detail = "low" + // } + // } + + // However, in my test, it seems to be always the same as "high". + // The following image, which is 125x50, is still treated as high-res, taken + // 255 tokens in the response of non-stream chat completion api. + // https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg + if detail == "" || detail == "auto" { + // assume by test, not sure if this is correct + detail = "high" + } + switch detail { + case "low": + return lowDetailCost, nil + case "high": + if fetchSize { + width, height, err = image.GetImageSize(url) + if err != nil { + return 0, err + } + } + if width > 2048 || height > 2048 { // max(width, height) > 2048 + ratio := float64(2048) / math.Max(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + if width > 768 && height > 768 { // min(width, height) > 768 + ratio := float64(768) / math.Min(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) + result := numSquares*highDetailCostPerTile + additionalCost + return result, nil + default: + return 0, errors.New("invalid detail option") + } +} + func countTokenInput(input any, model string) int { - switch input.(type) { + switch v := input.(type) { case string: - return countTokenText(input.(string), model) + return countTokenText(v, model) case []string: text := "" - for _, s := range input.([]string) { + for _, s := range v { text += s } return countTokenText(text, model) diff --git a/controller/relay.go b/controller/relay.go index 86c17a9a..bb734eef 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -141,10 +141,31 @@ type ImageRequest struct { User string `json:"user,omitempty"` } -type WhisperResponse struct { +type WhisperJSONResponse struct { Text string `json:"text,omitempty"` } +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + type TextToSpeechRequest struct { Model string `json:"model" binding:"required"` Input string `json:"input" binding:"required"` diff --git a/go.mod b/go.mod index e4756efd..99e90345 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.14.0 + golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -27,6 +29,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -52,20 +55,15 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/sergi/go-diff v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp/typeparams v0.0.0-20221212164502-fae10dda9338 // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/net v0.17.0 // indirect - golang.org/x/sync v0.4.0 // indirect - golang.org/x/sys v0.14.0 // indirect - golang.org/x/telemetry v0.0.0-20231114163143-69313e640400 // indirect - golang.org/x/text v0.13.0 // indirect - golang.org/x/tools v0.14.1-0.20231114185516-c9d3e7de13fd // indirect - golang.org/x/tools/gopls v0.14.2 // indirect - golang.org/x/vuln v1.0.1 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect honnef.co/go/tools v0.4.5 // indirect diff --git a/go.sum b/go.sum index dc30a8d0..d19e80c4 100644 --- a/go.sum +++ b/go.sum @@ -159,10 +159,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp/typeparams v0.0.0-20221212164502-fae10dda9338 h1:2O2DON6y3XMJiQRAS1UWU+54aec2uopH3x7MAiqGW6Y= -golang.org/x/exp/typeparams v0.0.0-20221212164502-fae10dda9338/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= @@ -186,8 +184,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.14.1-0.20231114185516-c9d3e7de13fd h1:Oku7E+OCrXHyst1dG1z10etCTxewCHXNFLRlyMPbh3w= golang.org/x/tools v0.14.1-0.20231114185516-c9d3e7de13fd/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= diff --git a/middleware/distributor.go b/middleware/distributor.go index c4ddc3a0..8be986c9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -89,6 +89,8 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } c.Next() } diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 00000000..c3a3d748 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func RelayPanicRecover() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + common.SysError(fmt.Sprintf("panic detected: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), + "type": "one_api_panic", + }, + }) + c.Abort() + } + }() + c.Next() + } +} diff --git a/router/relay-router.go b/router/relay-router.go index 24edc9a9..56ab9b28 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 0f4d99de..252fcff7 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -357,6 +357,20 @@ const EditChannel = () => { ) } + { + inputs.type === 17 && ( + + + + ) + }