From fb24d024a7d339ee2eac8747830cbfe44f8a4d32 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Tue, 26 Dec 2023 18:42:39 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add=20channel=20proxy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/client.go | 64 ++++++++++++++++---- go.mod | 2 +- go.sum | 2 + model/channel.go | 1 + providers/aigc2d/balance.go | 2 +- providers/aiproxy/balance.go | 2 +- providers/ali/chat.go | 4 +- providers/api2d/balance.go | 2 +- providers/api2gpt/balance.go | 2 +- providers/azure/image_generations.go | 3 +- providers/azure/type.go | 1 + providers/baidu/base.go | 4 +- providers/baidu/chat.go | 4 +- providers/base/common.go | 6 +- providers/claude/chat.go | 4 +- providers/closeai/balance.go | 2 +- providers/gemini/chat.go | 4 +- providers/openai/balance.go | 4 +- providers/openai/base.go | 4 +- providers/openaisb/balance.go | 2 +- providers/palm/chat.go | 4 +- providers/tencent/chat.go | 4 +- providers/zhipu/chat.go | 4 +- web/src/views/Channel/component/EditModal.js | 22 +++++++ web/src/views/Channel/component/NameLabel.js | 53 ++++++++++++++++ web/src/views/Channel/component/TableRow.js | 5 +- web/src/views/Channel/type/Config.js | 3 + 27 files changed, 181 insertions(+), 33 deletions(-) create mode 100644 web/src/views/Channel/component/NameLabel.js diff --git a/common/client.go b/common/client.go index 91903ac0..3e94e92b 100644 --- a/common/client.go +++ b/common/client.go @@ -6,23 +6,61 @@ import ( "fmt" "io" "net/http" + "net/url" "one-api/types" "strconv" + "sync" "time" "github.com/gin-gonic/gin" + "golang.org/x/net/proxy" ) -var HttpClient *http.Client +var clientPool = &sync.Pool{ + New: func() interface{} { + return &http.Client{} + }, +} -func init() { - if RelayTimeout == 0 { - HttpClient = &http.Client{} - } else { - HttpClient = &http.Client{ - Timeout: time.Duration(RelayTimeout) * time.Second, +func GetHttpClient(proxyAddr string) *http.Client { + client := clientPool.Get().(*http.Client) + + if RelayTimeout > 0 { + client.Timeout = time.Duration(RelayTimeout) * time.Second + } + + if proxyAddr != "" { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + SysError("Error parsing proxy address: " + err.Error()) + return client + } + + switch proxyURL.Scheme { + case "http", "https": + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + case "socks5": + dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct) + if err != nil { + SysError("Error creating SOCKS5 dialer: " + err.Error()) + return client + } + client.Transport = &http.Transport{ + Dial: dialer.Dial, + } + default: + SysError("Unsupported proxy scheme: " + proxyURL.Scheme) } } + + return client + +} + +func PutHttpClient(c *http.Client) { + clientPool.Put(c) } type Client struct { @@ -92,12 +130,14 @@ func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http return req, nil } -func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) { +func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) { // 发送请求 - resp, err := HttpClient.Do(req) + client := GetHttpClient(proxyAddr) + resp, err := client.Do(req) if err != nil { return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } + PutHttpClient(client) if !outputResp { defer resp.Body.Close() @@ -210,8 +250,10 @@ func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.Open return } -func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { - resp, err := HttpClient.Do(req) +func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) { + client := GetHttpClient(proxyAddr) + resp, err := client.Do(req) + PutHttpClient(client) if err != nil { return } diff --git a/go.mod b/go.mod index 81a59a52..fbb34d58 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/net v0.17.0 // indirect + golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index b0c52c8c..2517022f 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/model/channel.go b/model/channel.go index 8bcb8a96..b8352862 100644 --- a/model/channel.go +++ b/model/channel.go @@ -25,6 +25,7 @@ type Channel struct { UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` + Proxy string `json:"proxy" gorm:"type:varchar(255);default:''"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/providers/aigc2d/balance.go b/providers/aigc2d/balance.go index cb1613bb..070950c3 100644 --- a/providers/aigc2d/balance.go +++ b/providers/aigc2d/balance.go @@ -19,7 +19,7 @@ func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var response base.BalanceResponse - _, errWithCode := common.SendRequest(req, &response, false) + _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/aiproxy/balance.go b/providers/aiproxy/balance.go index 82a1653e..66f96d2b 100644 --- a/providers/aiproxy/balance.go +++ b/providers/aiproxy/balance.go @@ -20,7 +20,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var response AIProxyUserOverviewResponse - _, errWithCode := common.SendRequest(req, &response, false) + _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 723d8c33..a59ad636 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -157,10 +157,12 @@ func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage usage = &types.Usage{} // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return nil, common.HandleErrorResp(resp) diff --git a/providers/api2d/balance.go b/providers/api2d/balance.go index 520c04c3..67f9d8ae 100644 --- a/providers/api2d/balance.go +++ b/providers/api2d/balance.go @@ -19,7 +19,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var response base.BalanceResponse - _, errWithCode := common.SendRequest(req, &response, false) + _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/api2gpt/balance.go b/providers/api2gpt/balance.go index a8872b40..1288e8a8 100644 --- a/providers/api2gpt/balance.go +++ b/providers/api2gpt/balance.go @@ -19,7 +19,7 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var response base.BalanceResponse - _, errWithCode := common.SendRequest(req, &response, false) + _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/azure/image_generations.go b/providers/azure/image_generations.go index 15ca1b07..294f66fe 100644 --- a/providers/azure/image_generations.go +++ b/providers/azure/image_generations.go @@ -38,7 +38,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons for i := 0; i < 3; i++ { // 休眠 2 秒 time.Sleep(2 * time.Second) - _, errWithCode = common.SendRequest(req, &getImageAzureResponse, false) + _, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy) fmt.Println("getImageAzureResponse", getImageAzureResponse) if errWithCode != nil { return @@ -81,6 +81,7 @@ func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isMo if request.Model == "dall-e-2" { imageAzureResponse := &ImageAzureResponse{ Header: headers, + Proxy: p.Channel.Proxy, } errWithCode = p.SendRequest(req, imageAzureResponse, false) } else { diff --git a/providers/azure/type.go b/providers/azure/type.go index 7452fee1..a6f677f0 100644 --- a/providers/azure/type.go +++ b/providers/azure/type.go @@ -10,6 +10,7 @@ type ImageAzureResponse struct { Status string `json:"status,omitempty"` Error ImageAzureError `json:"error,omitempty"` Header map[string]string `json:"header,omitempty"` + Proxy string `json:"proxy,omitempty"` } type ImageAzureError struct { diff --git a/providers/baidu/base.go b/providers/baidu/base.go index b365e0fb..7dea85b6 100644 --- a/providers/baidu/base.go +++ b/providers/baidu/base.go @@ -105,10 +105,12 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo return nil, err } - resp, err := common.HttpClient.Do(req) + httpClient := common.GetHttpClient(p.Channel.Proxy) + resp, err := httpClient.Do(req) if err != nil { return nil, err } + common.PutHttpClient(httpClient) defer resp.Body.Close() diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index 0c424b66..e675adc2 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -130,10 +130,12 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag usage = &types.Usage{} // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return nil, common.HandleErrorResp(resp) diff --git a/providers/base/common.go b/providers/base/common.go index 02d63e57..7f2d1a41 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -65,7 +65,7 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) { func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { defer req.Body.Close() - resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true) + resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy) if openAIErrorWithStatusCode != nil { return } @@ -108,10 +108,12 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC defer req.Body.Close() // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } + common.PutHttpClient(client) defer resp.Body.Close() diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 02f309b0..d2094926 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -142,10 +142,12 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro defer req.Body.Close() // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return common.HandleErrorResp(resp), "" diff --git a/providers/closeai/balance.go b/providers/closeai/balance.go index 80665df2..82a99432 100644 --- a/providers/closeai/balance.go +++ b/providers/closeai/balance.go @@ -18,7 +18,7 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) // 发送请求 var response OpenAICreditGrants - _, errWithCode := common.SendRequest(req, &response, false) + _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index 721d4475..c9c30244 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -217,10 +217,12 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*ty defer req.Body.Close() // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return common.HandleErrorResp(resp), "" diff --git a/providers/openai/balance.go b/providers/openai/balance.go index 8a616ea9..4c27ec6c 100644 --- a/providers/openai/balance.go +++ b/providers/openai/balance.go @@ -20,7 +20,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var subscription OpenAISubscriptionResponse - _, errWithCode := common.SendRequest(req, &subscription, false) + _, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } @@ -38,7 +38,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) { return 0, err } usage := OpenAIUsageResponse{} - _, errWithCode = common.SendRequest(req, &usage, false) + _, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy) balance := subscription.HardLimitUSD - usage.TotalUsage/100 channel.UpdateBalance(balance) diff --git a/providers/openai/base.go b/providers/openai/base.go index 42edf1b2..99db8079 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -111,10 +111,12 @@ func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (reques func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { defer req.Body.Close() - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return common.HandleErrorResp(resp), "" diff --git a/providers/openaisb/balance.go b/providers/openaisb/balance.go index f03bef97..d67b03f4 100644 --- a/providers/openaisb/balance.go +++ b/providers/openaisb/balance.go @@ -21,7 +21,7 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) { // 发送请求 var response OpenAISBUsageResponse - _, errWithCode := common.SendRequest(req, &response, false) + _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) if err != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/palm/chat.go b/providers/palm/chat.go index 67aa52e0..81dd1777 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -134,10 +134,12 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW defer req.Body.Close() // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return common.HandleErrorResp(resp), "" diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 339a5a39..1965c549 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -147,10 +147,12 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) { defer req.Body.Close() // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return common.HandleErrorResp(resp), "" diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index a22f9815..7e58c6e5 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -145,10 +145,12 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*typ defer req.Body.Close() // 发送请求 - resp, err := common.HttpClient.Do(req) + client := common.GetHttpClient(p.Channel.Proxy) + resp, err := client.Do(req) if err != nil { return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil } + common.PutHttpClient(client) if common.IsFailureStatusCode(resp) { return common.HandleErrorResp(resp), nil diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js index c91c97cb..bef5b5bc 100644 --- a/web/src/views/Channel/component/EditModal.js +++ b/web/src/views/Channel/component/EditModal.js @@ -35,6 +35,7 @@ const validationSchema = Yup.object().shape({ type: Yup.number().required('渠道 不能为空'), key: Yup.string().when('is_edit', { is: false, then: Yup.string().required('密钥 不能为空') }), other: Yup.string(), + proxy: Yup.string(), models: Yup.array().min(1, '模型 不能为空'), groups: Yup.array().min(1, '用户组 不能为空'), base_url: Yup.string().when('type', { @@ -442,6 +443,27 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { {inputPrompt.model_mapping} )} + + {inputLabel.proxy} + + {touched.proxy && errors.proxy ? ( + + {errors.proxy} + + ) : ( + {inputPrompt.proxy} + )} +