🐛 fix: glm-4v support base64 image (#81)

* 💄 improve: http client changes proxy using context

* 🐛 fix: glm-4v support base64 image
This commit is contained in:
Buer 2024-03-01 18:52:39 +08:00 committed by GitHub
parent fab465e82a
commit 7c78ed9fad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 109 additions and 58 deletions

View File

@ -96,7 +96,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 600) // unit is second
var ConnectTimeout = GetOrDefault("CONNECT_TIMEOUT", 5) // unit is second
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"

View File

@ -1,70 +1,80 @@
package requester package requester
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
"sync"
"time" "time"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
) )
type HTTPClient struct{} type ContextKey string
var clientPool = &sync.Pool{ const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
New: func() interface{} { const ProxySock5AddrKey ContextKey = "proxySock5Addr"
return &http.Client{}
},
}
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client { func proxyFunc(req *http.Request) (*url.URL, error) {
client := clientPool.Get().(*http.Client) proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
if proxyAddr == nil {
if common.RelayTimeout > 0 { return nil, nil
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
} }
if proxyAddr != "" { proxyURL, err := url.Parse(proxyAddr.(string))
err := h.setProxy(client, proxyAddr)
if err != nil {
common.SysError(err.Error())
return client
}
}
return client
}
func (h *HTTPClient) returnClientToPool(client *http.Client) {
// 清除代理设置
client.Transport = nil
clientPool.Put(client)
}
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil { if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err) return nil, fmt.Errorf("error parsing proxy address: %w", err)
} }
switch proxyURL.Scheme { switch proxyURL.Scheme {
case "http", "https": case "http", "https":
client.Transport = &http.Transport{ return proxyURL, nil
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
} }
return nil return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 设置TCP超时
dialer := &net.Dialer{
Timeout: time.Duration(common.ConnectTimeout) * time.Second,
KeepAlive: 30 * time.Second,
}
// 从上下文中获取代理地址
proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string)
if !ok {
return dialer.DialContext(ctx, network, addr)
}
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return nil, fmt.Errorf("error parsing proxy address: %w", err)
}
proxyDialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("error creating socks5 dialer: %w", err)
}
return proxyDialer.Dial(network, addr)
}
var HTTPClient *http.Client
func init() {
trans := &http.Transport{
DialContext: socks5ProxyFunc,
Proxy: proxyFunc,
}
HTTPClient = &http.Client{
Transport: trans,
}
if common.RelayTimeout != 0 {
HTTPClient.Timeout = time.Duration(common.RelayTimeout) * time.Second
}
} }

View File

@ -3,6 +3,7 @@ package requester
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -18,7 +19,6 @@ import (
type HttpErrorHandler func(*http.Response) *types.OpenAIError type HttpErrorHandler func(*http.Response) *types.OpenAIError
type HTTPRequester struct { type HTTPRequester struct {
HTTPClient HTTPClient
requestBuilder RequestBuilder requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder CreateFormBuilder func(io.Writer) FormBuilder
ErrorHandler HttpErrorHandler ErrorHandler HttpErrorHandler
@ -31,7 +31,6 @@ type HTTPRequester struct {
// 如果 errorHandler 为 nil那么会使用一个默认的错误处理函数。 // 如果 errorHandler 为 nil那么会使用一个默认的错误处理函数。
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester { func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
return &HTTPRequester{ return &HTTPRequester{
HTTPClient: HTTPClient{},
requestBuilder: NewRequestBuilder(), requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder { CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body) return NewFormBuilder(body)
@ -48,6 +47,21 @@ type requestOptions struct {
type requestOption func(*requestOptions) type requestOption func(*requestOptions)
func (r *HTTPRequester) getContext() context.Context {
if r.proxyAddr == "" {
return context.Background()
}
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
if strings.HasPrefix(r.proxyAddr, "socks5://") {
return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr)
}
// 否则使用 http 代理
return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr)
}
// 创建请求 // 创建请求
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) { func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
args := &requestOptions{ args := &requestOptions{
@ -57,7 +71,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
for _, setter := range setters { for _, setter := range setters {
setter(args) setter(args)
} }
req, err := r.requestBuilder.Build(method, url, args.body, args.header) req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -67,9 +81,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
// 发送请求 // 发送请求
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) { func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
client := r.HTTPClient.getClientFromPool(r.proxyAddr) resp, err := HTTPClient.Do(req)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil { if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
} }
@ -105,9 +117,7 @@ func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp
// 发送请求 RAW // 发送请求 RAW
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) { func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求 // 发送请求
client := r.HTTPClient.getClientFromPool(r.proxyAddr) resp, err := HTTPClient.Do(req)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil { if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
} }

View File

@ -2,13 +2,14 @@ package requester
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
) )
type RequestBuilder interface { type RequestBuilder interface {
Build(method, url string, body any, header http.Header) (*http.Request, error) Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
} }
type HTTPRequestBuilder struct { type HTTPRequestBuilder struct {
@ -22,6 +23,7 @@ func NewRequestBuilder() *HTTPRequestBuilder {
} }
func (b *HTTPRequestBuilder) Build( func (b *HTTPRequestBuilder) Build(
ctx context.Context,
method string, method string,
url string, url string,
body any, body any,
@ -40,7 +42,7 @@ func (b *HTTPRequestBuilder) Build(
bodyReader = bytes.NewBuffer(reqBytes) bodyReader = bytes.NewBuffer(reqBytes)
} }
} }
req, err = http.NewRequest(method, url, bodyReader) req, err = http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil { if err != nil {
return return
} }

View File

@ -120,6 +120,34 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
ToolChoice: request.ToolChoice, ToolChoice: request.ToolChoice,
} }
// 如果有图片的话并且是base64编码的图片需要把前缀去掉
if zhipuRequest.Model == "glm-4v" {
for i := range zhipuRequest.Messages {
contentList, ok := zhipuRequest.Messages[i].Content.([]any)
if !ok {
continue
}
for j := range contentList {
contentMap, ok := contentList[j].(map[string]any)
if !ok || contentMap["type"] != "image_url" {
continue
}
imageUrl, ok := contentMap["image_url"].(map[string]any)
if !ok {
continue
}
url, ok := imageUrl["url"].(string)
if !ok || !strings.HasPrefix(url, "data:image/") {
continue
}
imageUrl["url"] = strings.Split(url, ",")[1]
contentMap["image_url"] = imageUrl
contentList[j] = contentMap
}
zhipuRequest.Messages[i].Content = contentList
}
}
if request.Functions != nil { if request.Functions != nil {
zhipuRequest.Tools = make([]ZhipuTool, 0, len(request.Functions)) zhipuRequest.Tools = make([]ZhipuTool, 0, len(request.Functions))
for _, function := range request.Functions { for _, function := range request.Functions {