300 lines
6.4 KiB
Go
300 lines
6.4 KiB
Go
package common
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"one-api/types"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"golang.org/x/net/proxy"
|
|
)
|
|
|
|
// clientPool hands out reusable *http.Client values for GetHttpClient and
// takes them back via PutHttpClient. New clients start zero-valued (default
// transport, no timeout); GetHttpClient applies per-call settings.
var clientPool = &sync.Pool{
	New: func() any {
		return &http.Client{}
	},
}
|
|
|
|
func GetHttpClient(proxyAddr string) *http.Client {
|
|
client := clientPool.Get().(*http.Client)
|
|
|
|
if RelayTimeout > 0 {
|
|
client.Timeout = time.Duration(RelayTimeout) * time.Second
|
|
}
|
|
|
|
if proxyAddr != "" {
|
|
proxyURL, err := url.Parse(proxyAddr)
|
|
if err != nil {
|
|
SysError("Error parsing proxy address: " + err.Error())
|
|
return client
|
|
}
|
|
|
|
switch proxyURL.Scheme {
|
|
case "http", "https":
|
|
client.Transport = &http.Transport{
|
|
Proxy: http.ProxyURL(proxyURL),
|
|
}
|
|
case "socks5":
|
|
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
|
if err != nil {
|
|
SysError("Error creating SOCKS5 dialer: " + err.Error())
|
|
return client
|
|
}
|
|
client.Transport = &http.Transport{
|
|
Dial: dialer.Dial,
|
|
}
|
|
default:
|
|
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
|
|
}
|
|
}
|
|
|
|
return client
|
|
|
|
}
|
|
|
|
func PutHttpClient(c *http.Client) {
|
|
clientPool.Put(c)
|
|
}
|
|
|
|
type Client struct {
|
|
requestBuilder RequestBuilder
|
|
CreateFormBuilder func(io.Writer) FormBuilder
|
|
}
|
|
|
|
func NewClient() *Client {
|
|
return &Client{
|
|
requestBuilder: NewRequestBuilder(),
|
|
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
|
return NewFormBuilder(body)
|
|
},
|
|
}
|
|
}
|
|
|
|
type (
	// requestOptions accumulates the settings that requestOption functions
	// apply before Client.NewRequest hands them to the request builder.
	requestOptions struct {
		body   any
		header http.Header
	}

	// requestOption mutates a requestOptions in place.
	requestOption func(*requestOptions)

	// Stringer is implemented by response targets that want the raw response
	// body as text: GetString returns the string the body is written into.
	Stringer interface {
		GetString() *string
	}
)
|
|
|
|
func WithBody(body any) requestOption {
|
|
return func(args *requestOptions) {
|
|
args.body = body
|
|
}
|
|
}
|
|
|
|
func WithHeader(header map[string]string) requestOption {
|
|
return func(args *requestOptions) {
|
|
for k, v := range header {
|
|
args.header.Set(k, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithContentType(contentType string) requestOption {
|
|
return func(args *requestOptions) {
|
|
args.header.Set("Content-Type", contentType)
|
|
}
|
|
}
|
|
|
|
// RequestError pairs an HTTP status code with an error value.
// NOTE(review): not referenced in this file and it does not implement the
// error interface — confirm how callers use it.
type RequestError struct {
	HTTPStatusCode int   // HTTP status code associated with Err
	Err            error // underlying error
}
|
|
|
|
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
|
// Default Options
|
|
args := &requestOptions{
|
|
body: nil,
|
|
header: make(http.Header),
|
|
}
|
|
for _, setter := range setters {
|
|
setter(args)
|
|
}
|
|
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
|
// 发送请求
|
|
client := GetHttpClient(proxyAddr)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
|
}
|
|
PutHttpClient(client)
|
|
|
|
if !outputResp {
|
|
defer resp.Body.Close()
|
|
}
|
|
|
|
// 处理响应
|
|
if IsFailureStatusCode(resp) {
|
|
return nil, HandleErrorResp(resp)
|
|
}
|
|
|
|
// 解析响应
|
|
if outputResp {
|
|
var buf bytes.Buffer
|
|
tee := io.TeeReader(resp.Body, &buf)
|
|
err = DecodeResponse(tee, response)
|
|
|
|
// 将响应体重新写入 resp.Body
|
|
resp.Body = io.NopCloser(&buf)
|
|
} else {
|
|
err = DecodeResponse(resp.Body, response)
|
|
}
|
|
if err != nil {
|
|
return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
|
}
|
|
|
|
if outputResp {
|
|
return resp, nil
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
type GeneralErrorResponse struct {
|
|
Error types.OpenAIError `json:"error"`
|
|
Message string `json:"message"`
|
|
Msg string `json:"msg"`
|
|
Err string `json:"err"`
|
|
ErrorMsg string `json:"error_msg"`
|
|
Header struct {
|
|
Message string `json:"message"`
|
|
} `json:"header"`
|
|
Response struct {
|
|
Error struct {
|
|
Message string `json:"message"`
|
|
} `json:"error"`
|
|
} `json:"response"`
|
|
}
|
|
|
|
func (e GeneralErrorResponse) ToMessage() string {
|
|
if e.Error.Message != "" {
|
|
return e.Error.Message
|
|
}
|
|
if e.Message != "" {
|
|
return e.Message
|
|
}
|
|
if e.Msg != "" {
|
|
return e.Msg
|
|
}
|
|
if e.Err != "" {
|
|
return e.Err
|
|
}
|
|
if e.ErrorMsg != "" {
|
|
return e.ErrorMsg
|
|
}
|
|
if e.Header.Message != "" {
|
|
return e.Header.Message
|
|
}
|
|
if e.Response.Error.Message != "" {
|
|
return e.Response.Error.Message
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// 处理错误响应
|
|
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
|
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
|
StatusCode: resp.StatusCode,
|
|
OpenAIError: types.OpenAIError{
|
|
Message: "",
|
|
Type: "upstream_error",
|
|
Code: "bad_response_status_code",
|
|
Param: strconv.Itoa(resp.StatusCode),
|
|
},
|
|
}
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return
|
|
}
|
|
// var errorResponse types.OpenAIErrorResponse
|
|
var errorResponse GeneralErrorResponse
|
|
err = json.Unmarshal(responseBody, &errorResponse)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if errorResponse.Error.Message != "" {
|
|
// OpenAI format error, so we override the default one
|
|
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
|
|
} else {
|
|
openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
|
|
}
|
|
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
|
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
|
|
client := GetHttpClient(proxyAddr)
|
|
resp, err := client.Do(req)
|
|
PutHttpClient(client)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return resp.Body, nil
|
|
}
|
|
|
|
// IsFailureStatusCode reports whether resp's status code is outside
// [200, 400): 1xx and any 4xx/5xx count as failures, 2xx and 3xx do not.
func IsFailureStatusCode(resp *http.Response) bool {
	code := resp.StatusCode
	success := code >= http.StatusOK && code < http.StatusBadRequest
	return !success
}
|
|
|
|
func DecodeResponse(body io.Reader, v any) error {
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
|
|
if result, ok := v.(*string); ok {
|
|
return DecodeString(body, result)
|
|
}
|
|
|
|
if stringer, ok := v.(Stringer); ok {
|
|
return DecodeString(body, stringer.GetString())
|
|
}
|
|
|
|
return json.NewDecoder(body).Decode(v)
|
|
}
|
|
|
|
// DecodeString reads body to EOF and stores its contents in *output.
// It returns any read error. When output is nil it returns an error instead
// of panicking — e.g. for a Stringer whose GetString returns nil — matching
// how JSON decoding rejects a nil target.
func DecodeString(body io.Reader, output *string) error {
	if output == nil {
		return fmt.Errorf("decode string: nil output pointer")
	}

	b, err := io.ReadAll(body)
	if err != nil {
		return err
	}
	*output = string(b)
	return nil
}
|
|
|
|
func SetEventStreamHeaders(c *gin.Context) {
|
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
c.Writer.Header().Set("Connection", "keep-alive")
|
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
|
}
|