diff --git a/common/client.go b/common/client.go
index 91903ac0..3e94e92b 100644
--- a/common/client.go
+++ b/common/client.go
@@ -6,23 +6,61 @@ import (
"fmt"
"io"
"net/http"
+ "net/url"
"one-api/types"
"strconv"
+ "sync"
"time"
"github.com/gin-gonic/gin"
+ "golang.org/x/net/proxy"
)
-var HttpClient *http.Client
+var clientPool = &sync.Pool{
+ New: func() interface{} {
+ return &http.Client{}
+ },
+}
-func init() {
- if RelayTimeout == 0 {
- HttpClient = &http.Client{}
- } else {
- HttpClient = &http.Client{
- Timeout: time.Duration(RelayTimeout) * time.Second,
+func GetHttpClient(proxyAddr string) *http.Client {
+ client := clientPool.Get().(*http.Client)
+
+ if RelayTimeout > 0 {
+ client.Timeout = time.Duration(RelayTimeout) * time.Second
+ }
+
+ if proxyAddr != "" {
+ proxyURL, err := url.Parse(proxyAddr)
+ if err != nil {
+ SysError("Error parsing proxy address: " + err.Error())
+ return client
+ }
+
+ switch proxyURL.Scheme {
+ case "http", "https":
+ client.Transport = &http.Transport{
+ Proxy: http.ProxyURL(proxyURL),
+ }
+ case "socks5":
+ dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
+ if err != nil {
+ SysError("Error creating SOCKS5 dialer: " + err.Error())
+ return client
+ }
+ client.Transport = &http.Transport{
+ Dial: dialer.Dial,
+ }
+ default:
+ SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
}
}
+
+ return client
+
+}
+
+func PutHttpClient(c *http.Client) {
+ clientPool.Put(c)
}
type Client struct {
@@ -92,12 +130,14 @@ func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http
return req, nil
}
-func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
+func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
- resp, err := HttpClient.Do(req)
+ client := GetHttpClient(proxyAddr)
+ resp, err := client.Do(req)
if err != nil {
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
+ PutHttpClient(client)
if !outputResp {
defer resp.Body.Close()
@@ -210,8 +250,10 @@ func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.Open
return
}
-func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
- resp, err := HttpClient.Do(req)
+func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
+ client := GetHttpClient(proxyAddr)
+ resp, err := client.Do(req)
+ PutHttpClient(client)
if err != nil {
return
}
diff --git a/go.mod b/go.mod
index 81a59a52..fbb34d58 100644
--- a/go.mod
+++ b/go.mod
@@ -58,7 +58,7 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
- golang.org/x/net v0.17.0 // indirect
+ golang.org/x/net v0.19.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
diff --git a/go.sum b/go.sum
index b0c52c8c..2517022f 100644
--- a/go.sum
+++ b/go.sum
@@ -159,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
+golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
+golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
diff --git a/model/channel.go b/model/channel.go
index 8bcb8a96..b8352862 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -25,6 +25,7 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
+ Proxy string `json:"proxy" gorm:"type:varchar(255);default:''"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
diff --git a/providers/aigc2d/balance.go b/providers/aigc2d/balance.go
index cb1613bb..070950c3 100644
--- a/providers/aigc2d/balance.go
+++ b/providers/aigc2d/balance.go
@@ -19,7 +19,7 @@ func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response base.BalanceResponse
- _, errWithCode := common.SendRequest(req, &response, false)
+ _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
diff --git a/providers/aiproxy/balance.go b/providers/aiproxy/balance.go
index 82a1653e..66f96d2b 100644
--- a/providers/aiproxy/balance.go
+++ b/providers/aiproxy/balance.go
@@ -20,7 +20,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response AIProxyUserOverviewResponse
- _, errWithCode := common.SendRequest(req, &response, false)
+ _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
diff --git a/providers/ali/chat.go b/providers/ali/chat.go
index 723d8c33..a59ad636 100644
--- a/providers/ali/chat.go
+++ b/providers/ali/chat.go
@@ -157,10 +157,12 @@ func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage
usage = &types.Usage{}
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)
diff --git a/providers/api2d/balance.go b/providers/api2d/balance.go
index 520c04c3..67f9d8ae 100644
--- a/providers/api2d/balance.go
+++ b/providers/api2d/balance.go
@@ -19,7 +19,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response base.BalanceResponse
- _, errWithCode := common.SendRequest(req, &response, false)
+ _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
diff --git a/providers/api2gpt/balance.go b/providers/api2gpt/balance.go
index a8872b40..1288e8a8 100644
--- a/providers/api2gpt/balance.go
+++ b/providers/api2gpt/balance.go
@@ -19,7 +19,7 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response base.BalanceResponse
- _, errWithCode := common.SendRequest(req, &response, false)
+ _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
diff --git a/providers/azure/image_generations.go b/providers/azure/image_generations.go
index 15ca1b07..294f66fe 100644
--- a/providers/azure/image_generations.go
+++ b/providers/azure/image_generations.go
@@ -38,7 +38,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
for i := 0; i < 3; i++ {
// 休眠 2 秒
time.Sleep(2 * time.Second)
- _, errWithCode = common.SendRequest(req, &getImageAzureResponse, false)
+ _, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy)
fmt.Println("getImageAzureResponse", getImageAzureResponse)
if errWithCode != nil {
return
@@ -81,6 +81,7 @@ func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isMo
if request.Model == "dall-e-2" {
imageAzureResponse := &ImageAzureResponse{
Header: headers,
+ Proxy: p.Channel.Proxy,
}
errWithCode = p.SendRequest(req, imageAzureResponse, false)
} else {
diff --git a/providers/azure/type.go b/providers/azure/type.go
index 7452fee1..a6f677f0 100644
--- a/providers/azure/type.go
+++ b/providers/azure/type.go
@@ -10,6 +10,7 @@ type ImageAzureResponse struct {
Status string `json:"status,omitempty"`
Error ImageAzureError `json:"error,omitempty"`
Header map[string]string `json:"header,omitempty"`
+ Proxy string `json:"proxy,omitempty"`
}
type ImageAzureError struct {
diff --git a/providers/baidu/base.go b/providers/baidu/base.go
index b365e0fb..7dea85b6 100644
--- a/providers/baidu/base.go
+++ b/providers/baidu/base.go
@@ -105,10 +105,12 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo
return nil, err
}
- resp, err := common.HttpClient.Do(req)
+ httpClient := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
+ common.PutHttpClient(httpClient)
defer resp.Body.Close()
diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go
index 0c424b66..e675adc2 100644
--- a/providers/baidu/chat.go
+++ b/providers/baidu/chat.go
@@ -130,10 +130,12 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag
usage = &types.Usage{}
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)
diff --git a/providers/base/common.go b/providers/base/common.go
index 02d63e57..7f2d1a41 100644
--- a/providers/base/common.go
+++ b/providers/base/common.go
@@ -65,7 +65,7 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
- resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true)
+ resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
if openAIErrorWithStatusCode != nil {
return
}
@@ -108,10 +108,12 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
defer req.Body.Close()
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
+ common.PutHttpClient(client)
defer resp.Body.Close()
diff --git a/providers/claude/chat.go b/providers/claude/chat.go
index 02f309b0..d2094926 100644
--- a/providers/claude/chat.go
+++ b/providers/claude/chat.go
@@ -142,10 +142,12 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
defer req.Body.Close()
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
diff --git a/providers/closeai/balance.go b/providers/closeai/balance.go
index 80665df2..82a99432 100644
--- a/providers/closeai/balance.go
+++ b/providers/closeai/balance.go
@@ -18,7 +18,7 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
// 发送请求
var response OpenAICreditGrants
- _, errWithCode := common.SendRequest(req, &response, false)
+ _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go
index 721d4475..c9c30244 100644
--- a/providers/gemini/chat.go
+++ b/providers/gemini/chat.go
@@ -217,10 +217,12 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*ty
defer req.Body.Close()
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
diff --git a/providers/openai/balance.go b/providers/openai/balance.go
index 8a616ea9..4c27ec6c 100644
--- a/providers/openai/balance.go
+++ b/providers/openai/balance.go
@@ -20,7 +20,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var subscription OpenAISubscriptionResponse
- _, errWithCode := common.SendRequest(req, &subscription, false)
+ _, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
@@ -38,7 +38,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
return 0, err
}
usage := OpenAIUsageResponse{}
- _, errWithCode = common.SendRequest(req, &usage, false)
+ _, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)
diff --git a/providers/openai/base.go b/providers/openai/base.go
index 42edf1b2..99db8079 100644
--- a/providers/openai/base.go
+++ b/providers/openai/base.go
@@ -111,10 +111,12 @@ func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (reques
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
defer req.Body.Close()
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
diff --git a/providers/openaisb/balance.go b/providers/openaisb/balance.go
index f03bef97..d67b03f4 100644
--- a/providers/openaisb/balance.go
+++ b/providers/openaisb/balance.go
@@ -21,7 +21,7 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response OpenAISBUsageResponse
- _, errWithCode := common.SendRequest(req, &response, false)
+ _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if err != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
diff --git a/providers/palm/chat.go b/providers/palm/chat.go
index 67aa52e0..81dd1777 100644
--- a/providers/palm/chat.go
+++ b/providers/palm/chat.go
@@ -134,10 +134,12 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
defer req.Body.Close()
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go
index 339a5a39..1965c549 100644
--- a/providers/tencent/chat.go
+++ b/providers/tencent/chat.go
@@ -147,10 +147,12 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close()
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go
index a22f9815..7e58c6e5 100644
--- a/providers/zhipu/chat.go
+++ b/providers/zhipu/chat.go
@@ -145,10 +145,12 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*typ
defer req.Body.Close()
// 发送请求
- resp, err := common.HttpClient.Do(req)
+ client := common.GetHttpClient(p.Channel.Proxy)
+ resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
}
+ common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), nil
diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js
index c91c97cb..bef5b5bc 100644
--- a/web/src/views/Channel/component/EditModal.js
+++ b/web/src/views/Channel/component/EditModal.js
@@ -35,6 +35,7 @@ const validationSchema = Yup.object().shape({
type: Yup.number().required('渠道 不能为空'),
key: Yup.string().when('is_edit', { is: false, then: Yup.string().required('密钥 不能为空') }),
other: Yup.string(),
+ proxy: Yup.string(),
models: Yup.array().min(1, '模型 不能为空'),
groups: Yup.array().min(1, '用户组 不能为空'),
base_url: Yup.string().when('type', {
@@ -442,6 +443,27 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
{inputPrompt.model_mapping}
)}
+
+ {inputLabel.proxy}
+
+ {touched.proxy && errors.proxy ? (
+
+ {errors.proxy}
+
+ ) : (
+ {inputPrompt.proxy}
+ )}
+