🐛 fix: glm-4v support base64 image (#81)
* 💄 improve: http client changes proxy using context * 🐛 fix: glm-4v support base64 image
This commit is contained in:
parent
fab465e82a
commit
7c78ed9fad
@ -96,7 +96,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
|||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 600) // unit is second
|
||||||
|
var ConnectTimeout = GetOrDefault("CONNECT_TIMEOUT", 5) // unit is second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RequestIdKey = "X-Oneapi-Request-Id"
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
|
@ -1,70 +1,80 @@
|
|||||||
package requester
|
package requester
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HTTPClient struct{}
|
type ContextKey string
|
||||||
|
|
||||||
var clientPool = &sync.Pool{
|
const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
|
||||||
New: func() interface{} {
|
const ProxySock5AddrKey ContextKey = "proxySock5Addr"
|
||||||
return &http.Client{}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
|
func proxyFunc(req *http.Request) (*url.URL, error) {
|
||||||
client := clientPool.Get().(*http.Client)
|
proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
|
||||||
|
if proxyAddr == nil {
|
||||||
if common.RelayTimeout > 0 {
|
return nil, nil
|
||||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if proxyAddr != "" {
|
proxyURL, err := url.Parse(proxyAddr.(string))
|
||||||
err := h.setProxy(client, proxyAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
return nil, fmt.Errorf("error parsing proxy address: %w", err)
|
||||||
return client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HTTPClient) returnClientToPool(client *http.Client) {
|
|
||||||
// 清除代理设置
|
|
||||||
client.Transport = nil
|
|
||||||
clientPool.Put(client)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
|
|
||||||
proxyURL, err := url.Parse(proxyAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error parsing proxy address: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch proxyURL.Scheme {
|
switch proxyURL.Scheme {
|
||||||
case "http", "https":
|
case "http", "https":
|
||||||
client.Transport = &http.Transport{
|
return proxyURL, nil
|
||||||
Proxy: http.ProxyURL(proxyURL),
|
|
||||||
}
|
|
||||||
case "socks5":
|
|
||||||
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error creating socks5 dialer: %w", err)
|
|
||||||
}
|
|
||||||
client.Transport = &http.Transport{
|
|
||||||
Dial: dialer.Dial,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// 设置TCP超时
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: time.Duration(common.ConnectTimeout) * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从上下文中获取代理地址
|
||||||
|
proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string)
|
||||||
|
if !ok {
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(proxyAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing proxy address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyDialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating socks5 dialer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyDialer.Dial(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var HTTPClient *http.Client
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
trans := &http.Transport{
|
||||||
|
DialContext: socks5ProxyFunc,
|
||||||
|
Proxy: proxyFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
HTTPClient = &http.Client{
|
||||||
|
Transport: trans,
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.RelayTimeout != 0 {
|
||||||
|
HTTPClient.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package requester
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -18,7 +19,6 @@ import (
|
|||||||
type HttpErrorHandler func(*http.Response) *types.OpenAIError
|
type HttpErrorHandler func(*http.Response) *types.OpenAIError
|
||||||
|
|
||||||
type HTTPRequester struct {
|
type HTTPRequester struct {
|
||||||
HTTPClient HTTPClient
|
|
||||||
requestBuilder RequestBuilder
|
requestBuilder RequestBuilder
|
||||||
CreateFormBuilder func(io.Writer) FormBuilder
|
CreateFormBuilder func(io.Writer) FormBuilder
|
||||||
ErrorHandler HttpErrorHandler
|
ErrorHandler HttpErrorHandler
|
||||||
@ -31,7 +31,6 @@ type HTTPRequester struct {
|
|||||||
// 如果 errorHandler 为 nil,那么会使用一个默认的错误处理函数。
|
// 如果 errorHandler 为 nil,那么会使用一个默认的错误处理函数。
|
||||||
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
|
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
|
||||||
return &HTTPRequester{
|
return &HTTPRequester{
|
||||||
HTTPClient: HTTPClient{},
|
|
||||||
requestBuilder: NewRequestBuilder(),
|
requestBuilder: NewRequestBuilder(),
|
||||||
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
||||||
return NewFormBuilder(body)
|
return NewFormBuilder(body)
|
||||||
@ -48,6 +47,21 @@ type requestOptions struct {
|
|||||||
|
|
||||||
type requestOption func(*requestOptions)
|
type requestOption func(*requestOptions)
|
||||||
|
|
||||||
|
func (r *HTTPRequester) getContext() context.Context {
|
||||||
|
if r.proxyAddr == "" {
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
||||||
|
if strings.HasPrefix(r.proxyAddr, "socks5://") {
|
||||||
|
return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则使用 http 代理
|
||||||
|
return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// 创建请求
|
// 创建请求
|
||||||
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||||||
args := &requestOptions{
|
args := &requestOptions{
|
||||||
@ -57,7 +71,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
|
|||||||
for _, setter := range setters {
|
for _, setter := range setters {
|
||||||
setter(args)
|
setter(args)
|
||||||
}
|
}
|
||||||
req, err := r.requestBuilder.Build(method, url, args.body, args.header)
|
req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -67,9 +81,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||||
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
|
resp, err := HTTPClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
r.HTTPClient.returnClientToPool(client)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -105,9 +117,7 @@ func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp
|
|||||||
// 发送请求 RAW
|
// 发送请求 RAW
|
||||||
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||||
// 发送请求
|
// 发送请求
|
||||||
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
|
resp, err := HTTPClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
r.HTTPClient.returnClientToPool(client)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -2,13 +2,14 @@ package requester
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestBuilder interface {
|
type RequestBuilder interface {
|
||||||
Build(method, url string, body any, header http.Header) (*http.Request, error)
|
Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type HTTPRequestBuilder struct {
|
type HTTPRequestBuilder struct {
|
||||||
@ -22,6 +23,7 @@ func NewRequestBuilder() *HTTPRequestBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *HTTPRequestBuilder) Build(
|
func (b *HTTPRequestBuilder) Build(
|
||||||
|
ctx context.Context,
|
||||||
method string,
|
method string,
|
||||||
url string,
|
url string,
|
||||||
body any,
|
body any,
|
||||||
@ -40,7 +42,7 @@ func (b *HTTPRequestBuilder) Build(
|
|||||||
bodyReader = bytes.NewBuffer(reqBytes)
|
bodyReader = bytes.NewBuffer(reqBytes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
req, err = http.NewRequest(method, url, bodyReader)
|
req, err = http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -120,6 +120,34 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
|
|||||||
ToolChoice: request.ToolChoice,
|
ToolChoice: request.ToolChoice,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果有图片的话,并且是base64编码的图片,需要把前缀去掉
|
||||||
|
if zhipuRequest.Model == "glm-4v" {
|
||||||
|
for i := range zhipuRequest.Messages {
|
||||||
|
contentList, ok := zhipuRequest.Messages[i].Content.([]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for j := range contentList {
|
||||||
|
contentMap, ok := contentList[j].(map[string]any)
|
||||||
|
if !ok || contentMap["type"] != "image_url" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imageUrl, ok := contentMap["image_url"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
url, ok := imageUrl["url"].(string)
|
||||||
|
if !ok || !strings.HasPrefix(url, "data:image/") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imageUrl["url"] = strings.Split(url, ",")[1]
|
||||||
|
contentMap["image_url"] = imageUrl
|
||||||
|
contentList[j] = contentMap
|
||||||
|
}
|
||||||
|
zhipuRequest.Messages[i].Content = contentList
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if request.Functions != nil {
|
if request.Functions != nil {
|
||||||
zhipuRequest.Tools = make([]ZhipuTool, 0, len(request.Functions))
|
zhipuRequest.Tools = make([]ZhipuTool, 0, len(request.Functions))
|
||||||
for _, function := range request.Functions {
|
for _, function := range request.Functions {
|
||||||
|
Loading…
Reference in New Issue
Block a user