diff --git a/common/constants.go b/common/constants.go index 93d35b77..f704fc98 100644 --- a/common/constants.go +++ b/common/constants.go @@ -96,7 +96,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 600) // unit is second +var ConnectTimeout = GetOrDefault("CONNECT_TIMEOUT", 5) // unit is second const ( RequestIdKey = "X-Oneapi-Request-Id" diff --git a/common/requester/http_client.go b/common/requester/http_client.go index 469f2459..d049d3c2 100644 --- a/common/requester/http_client.go +++ b/common/requester/http_client.go @@ -1,70 +1,80 @@ package requester import ( + "context" "fmt" + "net" "net/http" "net/url" "one-api/common" - "sync" "time" "golang.org/x/net/proxy" ) -type HTTPClient struct{} +type ContextKey string -var clientPool = &sync.Pool{ - New: func() interface{} { - return &http.Client{} - }, -} +const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr" +const ProxySock5AddrKey ContextKey = "proxySock5Addr" -func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client { - client := clientPool.Get().(*http.Client) - - if common.RelayTimeout > 0 { - client.Timeout = time.Duration(common.RelayTimeout) * time.Second +func proxyFunc(req *http.Request) (*url.URL, error) { + proxyAddr := req.Context().Value(ProxyHTTPAddrKey) + if proxyAddr == nil { + return nil, nil } - if proxyAddr != "" { - err := h.setProxy(client, proxyAddr) - if err != nil { - common.SysError(err.Error()) - return client - } - } - - return client -} - -func (h *HTTPClient) returnClientToPool(client *http.Client) { - // 清除代理设置 - client.Transport = nil - clientPool.Put(client) -} - -func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error { - proxyURL, err := url.Parse(proxyAddr) + proxyURL, err := url.Parse(proxyAddr.(string)) if err != nil { - return fmt.Errorf("error parsing proxy address: %w", err) + return nil, fmt.Errorf("error parsing proxy address: %w", err) } switch proxyURL.Scheme { case "http", "https": - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - case "socks5": - dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct) - if err != nil { - return fmt.Errorf("error creating socks5 dialer: %w", err) - } - client.Transport = &http.Transport{ - Dial: dialer.Dial, - } - default: - return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) + return proxyURL, nil } - return nil + return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) +} + +func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) { + // 设置TCP超时 + dialer := &net.Dialer{ + Timeout: time.Duration(common.ConnectTimeout) * time.Second, + KeepAlive: 30 * time.Second, + } + + // 从上下文中获取代理地址 + proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string) + if !ok { + return dialer.DialContext(ctx, network, addr) + } + + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return nil, fmt.Errorf("error parsing proxy address: %w", err) + } + + proxyDialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("error creating socks5 dialer: %w", err) + } + + return proxyDialer.Dial(network, addr) +} + +var HTTPClient *http.Client + +func init() { + trans := &http.Transport{ + DialContext: socks5ProxyFunc, + Proxy: proxyFunc, + } + + HTTPClient = &http.Client{ + Transport: trans, + } + + if common.RelayTimeout != 0 { + HTTPClient.Timeout = time.Duration(common.RelayTimeout) * time.Second + } } diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go index 8f7e5dd0..17d870a9 100644 --- a/common/requester/http_requester.go +++ b/common/requester/http_requester.go @@ -3,6 +3,7 @@ package requester import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -18,7 +19,6 @@ import ( type HttpErrorHandler func(*http.Response) *types.OpenAIError type HTTPRequester struct { - HTTPClient HTTPClient requestBuilder RequestBuilder CreateFormBuilder func(io.Writer) FormBuilder ErrorHandler HttpErrorHandler @@ -31,7 +31,6 @@ type HTTPRequester struct { // 如果 errorHandler 为 nil,那么会使用一个默认的错误处理函数。 func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester { return &HTTPRequester{ - HTTPClient: HTTPClient{}, requestBuilder: NewRequestBuilder(), CreateFormBuilder: func(body io.Writer) FormBuilder { return NewFormBuilder(body) @@ -48,6 +47,21 @@ type requestOptions struct { type requestOption func(*requestOptions) +func (r *HTTPRequester) getContext() context.Context { + if r.proxyAddr == "" { + return context.Background() + } + + // 如果是以 socks5:// 开头的地址,那么使用 socks5 代理 + if strings.HasPrefix(r.proxyAddr, "socks5://") { + return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr) + } + + // 否则使用 http 代理 + return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr) + +} + // 创建请求 func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) { args := &requestOptions{ @@ -57,7 +71,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) for _, setter := range setters { setter(args) } - req, err := r.requestBuilder.Build(method, url, args.body, args.header) + req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header) if err != nil { return nil, err } @@ -67,9 +81,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) // 发送请求 func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) { - client := r.HTTPClient.getClientFromPool(r.proxyAddr) - resp, err := client.Do(req) - r.HTTPClient.returnClientToPool(client) + resp, err := HTTPClient.Do(req) if err != nil { return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } @@ -105,9 +117,7 @@ func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp // 发送请求 RAW func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) { // 发送请求 - client := r.HTTPClient.getClientFromPool(r.proxyAddr) - resp, err := client.Do(req) - r.HTTPClient.returnClientToPool(client) + resp, err := HTTPClient.Do(req) if err != nil { return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } diff --git a/common/requester/request_builder.go b/common/requester/request_builder.go index 88cebe1a..a7f4ee83 100644 --- a/common/requester/request_builder.go +++ b/common/requester/request_builder.go @@ -2,13 +2,14 @@ package requester import ( "bytes" + "context" "io" "net/http" "one-api/common" ) type RequestBuilder interface { - Build(method, url string, body any, header http.Header) (*http.Request, error) + Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) } type HTTPRequestBuilder struct { @@ -22,6 +23,7 @@ func NewRequestBuilder() *HTTPRequestBuilder { } func (b *HTTPRequestBuilder) Build( + ctx context.Context, method string, url string, body any, @@ -40,7 +42,7 @@ func (b *HTTPRequestBuilder) Build( bodyReader = bytes.NewBuffer(reqBytes) } } - req, err = http.NewRequest(method, url, bodyReader) + req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) if err != nil { return } diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 7ae0fa21..687ffe7c 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -120,6 +120,34 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest { ToolChoice: request.ToolChoice, } + // 如果有图片的话,并且是base64编码的图片,需要把前缀去掉 + if zhipuRequest.Model == "glm-4v" { + for i := range zhipuRequest.Messages { + contentList, ok := zhipuRequest.Messages[i].Content.([]any) + if !ok { + continue + } + for j := range contentList { + contentMap, ok := contentList[j].(map[string]any) + if !ok || contentMap["type"] != "image_url" { + continue + } + imageUrl, ok := contentMap["image_url"].(map[string]any) + if !ok { + continue + } + url, ok := imageUrl["url"].(string) + if !ok || !strings.HasPrefix(url, "data:image/") { + continue + } + imageUrl["url"] = strings.Split(url, ",")[1] + contentMap["image_url"] = imageUrl + contentList[j] = contentMap + } + zhipuRequest.Messages[i].Content = contentList + } + } + if request.Functions != nil { zhipuRequest.Tools = make([]ZhipuTool, 0, len(request.Functions)) for _, function := range request.Functions {