From fe94656260ac0c8d70b8b9de0d048bea453088b5 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 31 Aug 2023 00:44:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=A4=9A=E8=B7=AF=E5=A4=8D?= =?UTF-8?q?=E7=94=A8bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/midjourney.go | 4 ++++ controller/relay-mj.go | 6 +++++- controller/relay-text.go | 19 ++++++++++++++++++- controller/relay.go | 4 ++-- middleware/auth.go | 17 +++++++++++++---- 5 files changed, 42 insertions(+), 8 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index 567cf736..ff7f5115 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -37,22 +37,26 @@ func UpdateMidjourneyTask() { jsonStr, err := json.Marshal(requestBody) if err != nil { log.Printf("UpdateMidjourneyTask: %v", err) + continue } req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr)) if err != nil { log.Printf("UpdateMidjourneyTask: %v", err) + continue } req.Header.Set("Content-Type", "application/json") req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu") resp, err := httpClient.Do(req) if err != nil { log.Printf("UpdateMidjourneyTask: %v", err) + continue } defer resp.Body.Close() var response []Midjourney err = json.NewDecoder(resp.Body).Decode(&response) if err != nil { log.Printf("UpdateMidjourneyTask: %v", err) + continue } for _, responseItem := range response { var midjourneyTask *model.Midjourney diff --git a/controller/relay-mj.go b/controller/relay-mj.go index 8989f036..e1f6a2cd 100644 --- a/controller/relay-mj.go +++ b/controller/relay-mj.go @@ -248,6 +248,10 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) + //mjToken := "" + //if c.Request.Header.Get("Authorization") != "" { + // mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1] + //} req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) // print request header log.Printf("request header: %s", req.Header) @@ -353,7 +357,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { Progress: "0%", FailReason: "", } - if midjResponse.Code == 4 { + if midjResponse.Code == 4 || midjResponse.Code == 24 { midjourneyTask.FailReason = midjResponse.Description } err = midjourneyTask.Insert() diff --git a/controller/relay-text.go b/controller/relay-text.go index deb9efaf..841bb4c3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "io" + "log" "net/http" "one-api/common" "one-api/model" @@ -278,6 +279,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if apiType != APITypeXunfei { // cause xunfei use websocket req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + // 设置GetBody函数,该函数返回一个新的io.ReadCloser,该io.ReadCloser返回与原始请求体相同的数据 + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(requestBody), nil + } if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } @@ -308,7 +313,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) + //req.Header.Set("Connection", c.Request.Header.Get("Connection")) + req.Close = true resp, err = httpClient.Do(req) if err != nil { return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) @@ -324,8 +331,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { + //print resp body + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Println("read resp err body failed", err) + } + log.Println("resp body:", string(body)) + errStr := fmt.Sprintf("bad status code: %d", resp.StatusCode) + if resp.StatusCode == 503 { + errStr = string(body) + } return errorWrapper( - fmt.Errorf("bad status code: %d", resp.StatusCode), "bad_status_code", resp.StatusCode) + fmt.Errorf(errStr), "bad_status_code", resp.StatusCode) } } diff --git a/controller/relay.go b/controller/relay.go index ce9f986c..d5498f9c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -205,7 +205,7 @@ func Relay(c *gin.Context) { }) } channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + common.SysError(fmt.Sprintf("relay error (channel #%d): %v ", channelId, err)) // https://platform.openai.com/docs/guides/error-codes/api-errors if shouldDisableChannel(&err.OpenAIError) { channelId := c.GetInt("channel_id") @@ -259,7 +259,7 @@ func RelayMidjourney(c *gin.Context) { // channelId := c.GetInt("channel_id") // channelName := c.GetString("channel_name") // disableChannel(channelId, channelName, err.Result) - //} + //};'''''''''''''''''''''''''''''''' } } diff --git a/middleware/auth.go b/middleware/auth.go index 060e005c..8686eae2 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -85,10 +85,19 @@ func RootAuth() func(c *gin.Context) { func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { key := c.Request.Header.Get("Authorization") - key = strings.TrimPrefix(key, "Bearer ") - key = strings.TrimPrefix(key, "sk-") - parts := strings.Split(key, "-") - key = parts[0] + parts := make([]string, 0) + if key == "" { + key = c.Request.Header.Get("mj-api-secret") + key = strings.TrimPrefix(key, "Bearer ") + key = strings.TrimPrefix(key, "sk-") + parts := strings.Split(key, "-") + key = parts[0] + } else { + key = strings.TrimPrefix(key, "Bearer ") + key = strings.TrimPrefix(key, "sk-") + parts := strings.Split(key, "-") + key = parts[0] + } token, err := model.ValidateUserToken(key) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{