diff --git a/common/client.go b/common/client.go index 9aa5e3ed..8d81f9e4 100644 --- a/common/client.go +++ b/common/client.go @@ -1,10 +1,13 @@ package common import ( + "bytes" "encoding/json" "fmt" "io" "net/http" + "one-api/types" + "strconv" "time" "github.com/gin-gonic/gin" @@ -79,28 +82,74 @@ func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http return req, nil } -func (c *Client) SendRequest(req *http.Request, response any) error { - +func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) { // 发送请求 resp, err := HttpClient.Do(req) if err != nil { - return err + return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } - defer resp.Body.Close() + if !outputResp { + defer resp.Body.Close() + } // 处理响应 if IsFailureStatusCode(resp) { - return fmt.Errorf("status code: %d", resp.StatusCode) + return nil, HandleErrorResp(resp) } // 解析响应 - err = DecodeResponse(resp.Body, response) + if outputResp { + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + err = DecodeResponse(tee, response) + + // 将响应体重新写入 resp.Body + resp.Body = io.NopCloser(&buf) + } else { + err = DecodeResponse(resp.Body, nil) + } if err != nil { - return err + return nil, types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) } - return nil + if outputResp { + return resp, nil + } + + return nil, nil +} + +// 处理错误响应 +func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: types.OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var errorResponse types.OpenAIErrorResponse + err = json.Unmarshal(responseBody, &errorResponse) + if err != nil { + return + } + if errorResponse.Error.Type != "" { + openAIErrorWithStatusCode.OpenAIError = errorResponse.Error + } else { + openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody) + } + return } func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 6a8a41ba..721bfe13 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -116,7 +116,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa } else { aliResponse := &AliChatResponse{} - errWithCode = p.SendRequest(req, aliResponse) + errWithCode = p.SendRequest(req, aliResponse, false) if errWithCode != nil { return } @@ -159,7 +159,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, } if common.IsFailureStatusCode(resp) { - return nil, p.HandleErrorResp(resp) + return nil, common.HandleErrorResp(resp) } defer resp.Body.Close() diff --git a/providers/ali/embeddings.go b/providers/ali/embeddings.go index b3ce200f..0ee85531 100644 --- a/providers/ali/embeddings.go +++ b/providers/ali/embeddings.go @@ -63,7 +63,7 @@ func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelM } aliEmbeddingResponse := &AliEmbeddingResponse{} - errWithCode = p.SendRequest(req, aliEmbeddingResponse) + errWithCode = p.SendRequest(req, aliEmbeddingResponse, false) if errWithCode != nil { return } diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index 0cf19751..8fc7dafe 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -95,7 +95,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel } else { baiduChatRequest := &BaiduChatResponse{} - errWithCode = p.SendRequest(req, baiduChatRequest) + errWithCode = p.SendRequest(req, baiduChatRequest, false) if errWithCode != nil { return } @@ -132,7 +132,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage } if common.IsFailureStatusCode(resp) { - return nil, p.HandleErrorResp(resp) + return nil, common.HandleErrorResp(resp) } defer resp.Body.Close() diff --git a/providers/baidu/embeddings.go b/providers/baidu/embeddings.go index 9a26e1a5..bdceaf31 100644 --- a/providers/baidu/embeddings.go +++ b/providers/baidu/embeddings.go @@ -59,7 +59,7 @@ func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isMode } baiduEmbeddingResponse := &BaiduEmbeddingResponse{} - errWithCode = p.SendRequest(req, baiduEmbeddingResponse) + errWithCode = p.SendRequest(req, baiduEmbeddingResponse, false) if errWithCode != nil { return } diff --git a/providers/base/common.go b/providers/base/common.go index ef182493..a7b19104 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/types" - "strconv" "strings" "github.com/gin-gonic/gin" @@ -54,42 +53,42 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) { } // 发送请求 -func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { - // 发送请求 - resp, err := common.HttpClient.Do(req) - if err != nil { - return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) + resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true) + if openAIErrorWithStatusCode != nil { + return } defer resp.Body.Close() - // 处理响应 - if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp) - } - - // 解析响应 - err = common.DecodeResponse(resp.Body, response) - if err != nil { - return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) - } - openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp) if openAIErrorWithStatusCode != nil { return } - jsonResponse, err := json.Marshal(openAIResponse) - if err != nil { - return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) - } - p.Context.Writer.Header().Set("Content-Type", "application/json") - p.Context.Writer.WriteHeader(resp.StatusCode) - _, err = p.Context.Writer.Write(jsonResponse) + if rawOutput { + for k, v := range resp.Header { + p.Context.Writer.Header().Set(k, v[0]) + } - if err != nil { - return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) + p.Context.Writer.WriteHeader(resp.StatusCode) + _, err := io.Copy(p.Context.Writer, resp.Body) + if err != nil { + return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + } + } else { + jsonResponse, err := json.Marshal(openAIResponse) + if err != nil { + return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + p.Context.Writer.Header().Set("Content-Type", "application/json") + p.Context.Writer.WriteHeader(resp.StatusCode) + _, err = p.Context.Writer.Write(jsonResponse) + + if err != nil { + return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) + } } return nil @@ -107,7 +106,7 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC // 处理响应 if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp) + return common.HandleErrorResp(resp) } for k, v := range resp.Header { @@ -124,38 +123,6 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC return nil } -// 处理错误响应 -func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { - openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - OpenAIError: types.OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var errorResponse types.OpenAIErrorResponse - err = json.Unmarshal(responseBody, &errorResponse) - if err != nil { - return - } - if errorResponse.Error.Type != "" { - openAIErrorWithStatusCode.OpenAIError = errorResponse.Error - } else { - openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody) - } - return -} - func (p *BaseProvider) SupportAPI(relayMode int) bool { switch relayMode { case common.RelayModeChatCompletions: diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 660db19a..bef595fc 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -108,7 +108,7 @@ func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isMode PromptTokens: promptTokens, }, } - errWithCode = p.SendRequest(req, claudeResponse) + errWithCode = p.SendRequest(req, claudeResponse, false) if errWithCode != nil { return } @@ -141,7 +141,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro } if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp), "" + return common.HandleErrorResp(resp), "" } defer resp.Body.Close() diff --git a/providers/closeai/balance.go b/providers/closeai/balance.go index 85b22c02..ae649766 100644 --- a/providers/closeai/balance.go +++ b/providers/closeai/balance.go @@ -1,6 +1,7 @@ package closeai import ( + "errors" "fmt" "one-api/common" "one-api/model" @@ -19,9 +20,9 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) // 发送请求 var response OpenAICreditGrants - err = client.SendRequest(req, &response) - if err != nil { - return 0, err + _, errWithCode := common.SendRequest(req, &response, false) + if errWithCode != nil { + return 0, errors.New(errWithCode.OpenAIError.Message) } channel.UpdateBalance(response.TotalAvailable) diff --git a/providers/openai/base.go b/providers/openai/base.go index a2c6a008..f8be57c1 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -91,50 +91,6 @@ func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (reques return } -// 发送请求 -func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { - - // 发送请求 - resp, err := common.HttpClient.Do(req) - if err != nil { - return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) - } - - defer resp.Body.Close() - - // 处理响应 - if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp) - } - - // 创建一个 bytes.Buffer 来存储响应体 - var buf bytes.Buffer - tee := io.TeeReader(resp.Body, &buf) - - // 解析响应 - err = common.DecodeResponse(tee, response) - if err != nil { - return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) - } - - openAIErrorWithStatusCode = response.responseHandler(resp) - if openAIErrorWithStatusCode != nil { - return - } - - for k, v := range resp.Header { - p.Context.Writer.Header().Set(k, v[0]) - } - - p.Context.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(p.Context.Writer, &buf) - if err != nil { - return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - - return nil -} - // 发送流式请求 func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { @@ -144,7 +100,7 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro } if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp), "" + return common.HandleErrorResp(resp), "" } defer resp.Body.Close() diff --git a/providers/openai/chat.go b/providers/openai/chat.go index a937d004..41905f92 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -6,7 +6,7 @@ import ( "one-api/types" ) -func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, @@ -14,7 +14,7 @@ func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWi } return } - return nil + return nil, nil } func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) { @@ -59,7 +59,7 @@ func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isMode } else { openAIProviderChatResponse := &OpenAIProviderChatResponse{} - errWithCode = p.sendRequest(req, openAIProviderChatResponse) + errWithCode = p.SendRequest(req, openAIProviderChatResponse, true) if errWithCode != nil { return } diff --git a/providers/openai/completion.go b/providers/openai/completion.go index e446c93b..cee0256b 100644 --- a/providers/openai/completion.go +++ b/providers/openai/completion.go @@ -6,7 +6,7 @@ import ( "one-api/types" ) -func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, @@ -14,7 +14,7 @@ func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) } return } - return nil + return nil, nil } func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) { @@ -59,7 +59,7 @@ func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isMode } } else { - errWithCode = p.sendRequest(req, openAIProviderCompletionResponse) + errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true) if errWithCode != nil { return } diff --git a/providers/openai/embeddings.go b/providers/openai/embeddings.go index 641caa49..6cc48d7f 100644 --- a/providers/openai/embeddings.go +++ b/providers/openai/embeddings.go @@ -6,7 +6,7 @@ import ( "one-api/types" ) -func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, @@ -14,7 +14,7 @@ func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) } return } - return nil + return nil, nil } func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { @@ -34,7 +34,7 @@ func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isMod } openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{} - errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse) + errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true) if errWithCode != nil { return } diff --git a/providers/openai/moderation.go b/providers/openai/moderation.go index 67df21c1..2eceb12d 100644 --- a/providers/openai/moderation.go +++ b/providers/openai/moderation.go @@ -6,7 +6,7 @@ import ( "one-api/types" ) -func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, @@ -14,7 +14,7 @@ func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) } return } - return nil + return nil, nil } func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { @@ -34,7 +34,7 @@ func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isMo } openAIProviderModerationResponse := &OpenAIProviderModerationResponse{} - errWithCode = p.sendRequest(req, openAIProviderModerationResponse) + errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true) if errWithCode != nil { return } diff --git a/providers/openaisb/balance.go b/providers/openaisb/balance.go index 72ea530e..8a789d44 100644 --- a/providers/openaisb/balance.go +++ b/providers/openaisb/balance.go @@ -21,9 +21,9 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var response OpenAISBUsageResponse - err = client.SendRequest(req, &response) + _, errWithCode := common.SendRequest(req, &response, false) if err != nil { - return 0, err + return 0, errors.New(errWithCode.OpenAIError.Message) } if response.Data == nil { diff --git a/providers/palm/chat.go b/providers/palm/chat.go index 3159bdf0..158e8aad 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -103,7 +103,7 @@ func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelM PromptTokens: promptTokens, }, } - errWithCode = p.SendRequest(req, palmChatResponse) + errWithCode = p.SendRequest(req, palmChatResponse, false) if errWithCode != nil { return } @@ -135,7 +135,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW } if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp), "" + return common.HandleErrorResp(resp), "" } defer resp.Body.Close() diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 52608630..46e05f8c 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -111,7 +111,7 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod } else { tencentResponse := &TencentChatResponse{} - errWithCode = p.SendRequest(req, tencentResponse) + errWithCode = p.SendRequest(req, tencentResponse, false) if errWithCode != nil { return } @@ -147,7 +147,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr } if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp), "" + return common.HandleErrorResp(resp), "" } defer resp.Body.Close() diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index c4d24509..7254effe 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -101,7 +101,7 @@ func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModel } else { zhipuResponse := &ZhipuResponse{} - errWithCode = p.SendRequest(req, zhipuResponse) + errWithCode = p.SendRequest(req, zhipuResponse, false) if errWithCode != nil { return } @@ -146,7 +146,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError } if common.IsFailureStatusCode(resp) { - return p.HandleErrorResp(resp), nil + return common.HandleErrorResp(resp), nil } defer resp.Body.Close()