♻️ refactor: 重构http请求函数

This commit is contained in:
Martial BE 2023-11-30 13:49:35 +08:00
parent 96dc7614e6
commit 7c6dee7390
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
17 changed files with 116 additions and 143 deletions

View File

@ -1,10 +1,13 @@
package common
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/types"
"strconv"
"time"
"github.com/gin-gonic/gin"
@ -79,28 +82,74 @@ func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http
return req, nil
}
func (c *Client) SendRequest(req *http.Request, response any) error {
func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
resp, err := HttpClient.Do(req)
if err != nil {
return err
return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
if !outputResp {
defer resp.Body.Close()
}
// 处理响应
if IsFailureStatusCode(resp) {
return fmt.Errorf("status code: %d", resp.StatusCode)
return nil, HandleErrorResp(resp)
}
// 解析响应
err = DecodeResponse(resp.Body, response)
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// 将响应体重新写入 resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = DecodeResponse(resp.Body, nil)
}
if err != nil {
return err
return nil, types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
return nil
if outputResp {
return resp, nil
}
return nil, nil
}
// 处理错误响应
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var errorResponse types.OpenAIErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Type != "" {
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody)
}
return
}
func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {

View File

@ -116,7 +116,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
} else {
aliResponse := &AliChatResponse{}
errWithCode = p.SendRequest(req, aliResponse)
errWithCode = p.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return
}
@ -159,7 +159,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
}
if common.IsFailureStatusCode(resp) {
return nil, p.HandleErrorResp(resp)
return nil, common.HandleErrorResp(resp)
}
defer resp.Body.Close()

View File

@ -63,7 +63,7 @@ func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelM
}
aliEmbeddingResponse := &AliEmbeddingResponse{}
errWithCode = p.SendRequest(req, aliEmbeddingResponse)
errWithCode = p.SendRequest(req, aliEmbeddingResponse, false)
if errWithCode != nil {
return
}

View File

@ -95,7 +95,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} else {
baiduChatRequest := &BaiduChatResponse{}
errWithCode = p.SendRequest(req, baiduChatRequest)
errWithCode = p.SendRequest(req, baiduChatRequest, false)
if errWithCode != nil {
return
}
@ -132,7 +132,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage
}
if common.IsFailureStatusCode(resp) {
return nil, p.HandleErrorResp(resp)
return nil, common.HandleErrorResp(resp)
}
defer resp.Body.Close()

View File

@ -59,7 +59,7 @@ func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isMode
}
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
errWithCode = p.SendRequest(req, baiduEmbeddingResponse)
errWithCode = p.SendRequest(req, baiduEmbeddingResponse, false)
if errWithCode != nil {
return
}

View File

@ -7,7 +7,6 @@ import (
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"strings"
"github.com/gin-gonic/gin"
@ -54,42 +53,42 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
}
// 发送请求
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// 发送请求
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true)
if openAIErrorWithStatusCode != nil {
return
}
defer resp.Body.Close()
// 处理响应
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp)
}
// 解析响应
err = common.DecodeResponse(resp.Body, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if rawOutput {
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
if err != nil {
return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(p.Context.Writer, resp.Body)
if err != nil {
return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
} else {
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if err != nil {
return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
}
return nil
@ -107,7 +106,7 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
// 处理响应
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp)
return common.HandleErrorResp(resp)
}
for k, v := range resp.Header {
@ -124,38 +123,6 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
return nil
}
// 处理错误响应
func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var errorResponse types.OpenAIErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Type != "" {
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody)
}
return
}
func (p *BaseProvider) SupportAPI(relayMode int) bool {
switch relayMode {
case common.RelayModeChatCompletions:

View File

@ -108,7 +108,7 @@ func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isMode
PromptTokens: promptTokens,
},
}
errWithCode = p.SendRequest(req, claudeResponse)
errWithCode = p.SendRequest(req, claudeResponse, false)
if errWithCode != nil {
return
}
@ -141,7 +141,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
}
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp), ""
return common.HandleErrorResp(resp), ""
}
defer resp.Body.Close()

View File

@ -1,6 +1,7 @@
package closeai
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
@ -19,9 +20,9 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
// 发送请求
var response OpenAICreditGrants
err = client.SendRequest(req, &response)
if err != nil {
return 0, err
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)

View File

@ -91,50 +91,6 @@ func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (reques
return
}
// 发送请求
func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// 发送请求
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
// 处理响应
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp)
}
// 创建一个 bytes.Buffer 来存储响应体
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
// 解析响应
err = common.DecodeResponse(tee, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIErrorWithStatusCode = response.responseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(p.Context.Writer, &buf)
if err != nil {
return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
return nil
}
// 发送流式请求
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
@ -144,7 +100,7 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
}
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp), ""
return common.HandleErrorResp(resp), ""
}
defer resp.Body.Close()

View File

@ -6,7 +6,7 @@ import (
"one-api/types"
)
func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
@ -14,7 +14,7 @@ func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWi
}
return
}
return nil
return nil, nil
}
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
@ -59,7 +59,7 @@ func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isMode
} else {
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
errWithCode = p.sendRequest(req, openAIProviderChatResponse)
errWithCode = p.SendRequest(req, openAIProviderChatResponse, true)
if errWithCode != nil {
return
}

View File

@ -6,7 +6,7 @@ import (
"one-api/types"
)
func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
@ -14,7 +14,7 @@ func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response)
}
return
}
return nil
return nil, nil
}
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
@ -59,7 +59,7 @@ func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isMode
}
} else {
errWithCode = p.sendRequest(req, openAIProviderCompletionResponse)
errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true)
if errWithCode != nil {
return
}

View File

@ -6,7 +6,7 @@ import (
"one-api/types"
)
func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
@ -14,7 +14,7 @@ func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response)
}
return
}
return nil
return nil, nil
}
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
@ -34,7 +34,7 @@ func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isMod
}
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true)
if errWithCode != nil {
return
}

View File

@ -6,7 +6,7 @@ import (
"one-api/types"
)
func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
@ -14,7 +14,7 @@ func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response)
}
return
}
return nil
return nil, nil
}
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
@ -34,7 +34,7 @@ func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isMo
}
openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
errWithCode = p.sendRequest(req, openAIProviderModerationResponse)
errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true)
if errWithCode != nil {
return
}

View File

@ -21,9 +21,9 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response OpenAISBUsageResponse
err = client.SendRequest(req, &response)
_, errWithCode := common.SendRequest(req, &response, false)
if err != nil {
return 0, err
return 0, errors.New(errWithCode.OpenAIError.Message)
}
if response.Data == nil {

View File

@ -103,7 +103,7 @@ func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelM
PromptTokens: promptTokens,
},
}
errWithCode = p.SendRequest(req, palmChatResponse)
errWithCode = p.SendRequest(req, palmChatResponse, false)
if errWithCode != nil {
return
}
@ -135,7 +135,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
}
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp), ""
return common.HandleErrorResp(resp), ""
}
defer resp.Body.Close()

View File

@ -111,7 +111,7 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod
} else {
tencentResponse := &TencentChatResponse{}
errWithCode = p.SendRequest(req, tencentResponse)
errWithCode = p.SendRequest(req, tencentResponse, false)
if errWithCode != nil {
return
}
@ -147,7 +147,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
}
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp), ""
return common.HandleErrorResp(resp), ""
}
defer resp.Body.Close()

View File

@ -101,7 +101,7 @@ func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} else {
zhipuResponse := &ZhipuResponse{}
errWithCode = p.SendRequest(req, zhipuResponse)
errWithCode = p.SendRequest(req, zhipuResponse, false)
if errWithCode != nil {
return
}
@ -146,7 +146,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
}
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp), nil
return common.HandleErrorResp(resp), nil
}
defer resp.Body.Close()