diff --git a/.gitignore b/.gitignore
index 67acb98d..1921ddf3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,5 +8,5 @@ build
 logs
 data
 tmp/
-test/
+/test/
 .env
\ No newline at end of file
diff --git a/common/client.go b/common/client.go
deleted file mode 100644
index 3e94e92b..00000000
--- a/common/client.go
+++ /dev/null
@@ -1,299 +0,0 @@
-package common
-
-import (
-	"bytes"
-	"encoding/json"
-	"fmt"
-	"io"
-	"net/http"
-	"net/url"
-	"one-api/types"
-	"strconv"
-	"sync"
-	"time"
-
-	"github.com/gin-gonic/gin"
-	"golang.org/x/net/proxy"
-)
-
-var clientPool = &sync.Pool{
-	New: func() interface{} {
-		return &http.Client{}
-	},
-}
-
-func GetHttpClient(proxyAddr string) *http.Client {
-	client := clientPool.Get().(*http.Client)
-
-	if RelayTimeout > 0 {
-		client.Timeout = time.Duration(RelayTimeout) * time.Second
-	}
-
-	if proxyAddr != "" {
-		proxyURL, err := url.Parse(proxyAddr)
-		if err != nil {
-			SysError("Error parsing proxy address: " + err.Error())
-			return client
-		}
-
-		switch proxyURL.Scheme {
-		case "http", "https":
-			client.Transport = &http.Transport{
-				Proxy: http.ProxyURL(proxyURL),
-			}
-		case "socks5":
-			dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
-			if err != nil {
-				SysError("Error creating SOCKS5 dialer: " + err.Error())
-				return client
-			}
-			client.Transport = &http.Transport{
-				Dial: dialer.Dial,
-			}
-		default:
-			SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
-		}
-	}
-
-	return client
-
-}
-
-func PutHttpClient(c *http.Client) {
-	clientPool.Put(c)
-}
-
-type Client struct {
-	requestBuilder    RequestBuilder
-	CreateFormBuilder func(io.Writer) FormBuilder
-}
-
-func NewClient() *Client {
-	return &Client{
-		requestBuilder: NewRequestBuilder(),
-		CreateFormBuilder: func(body io.Writer) FormBuilder {
-			return NewFormBuilder(body)
-		},
-	}
-}
-
-type requestOptions struct {
-	body   any
-	header http.Header
-}
-
-type requestOption func(*requestOptions)
-
-type Stringer interface {
-	GetString() *string
-}
-
-func WithBody(body any) requestOption {
-	return func(args *requestOptions) {
-		args.body = body
-	}
-}
-
-func WithHeader(header map[string]string) requestOption {
-	return func(args *requestOptions) {
-		for k, v := range header {
-			args.header.Set(k, v)
-		}
-	}
-}
-
-func WithContentType(contentType string) requestOption {
-	return func(args *requestOptions) {
-		args.header.Set("Content-Type", contentType)
-	}
-}
-
-type RequestError struct {
-	HTTPStatusCode int
-	Err            error
-}
-
-func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
-	// Default Options
-	args := &requestOptions{
-		body:   nil,
-		header: make(http.Header),
-	}
-	for _, setter := range setters {
-		setter(args)
-	}
-	req, err := c.requestBuilder.Build(method, url, args.body, args.header)
-	if err != nil {
-		return nil, err
-	}
-
-	return req, nil
-}
-
-func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
-	// 发送请求
-	client := GetHttpClient(proxyAddr)
-	resp, err := client.Do(req)
-	if err != nil {
-		return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
-	}
-	PutHttpClient(client)
-
-	if !outputResp {
-		defer resp.Body.Close()
-	}
-
-	// 处理响应
-	if IsFailureStatusCode(resp) {
-		return nil, HandleErrorResp(resp)
-	}
-
-	// 解析响应
-	if outputResp {
-		var buf bytes.Buffer
-		tee := io.TeeReader(resp.Body, &buf)
-		err = DecodeResponse(tee, response)
-
-		// 将响应体重新写入 resp.Body
-		resp.Body = io.NopCloser(&buf)
-	} else {
-		err = DecodeResponse(resp.Body, response)
-	}
-	if err != nil {
-		return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
-	}
-
-	if outputResp {
-		return resp, nil
-	}
-
-	return nil, nil
-}
-
-type GeneralErrorResponse struct {
-	Error    types.OpenAIError `json:"error"`
-	Message  string            `json:"message"`
-	Msg      string            `json:"msg"`
-	Err      string            `json:"err"`
-	ErrorMsg string            `json:"error_msg"`
-	Header   struct {
-		Message string `json:"message"`
-	} `json:"header"`
-	Response struct {
-		Error struct {
-			Message string `json:"message"`
-		} `json:"error"`
-	} `json:"response"`
-}
-
-func (e GeneralErrorResponse) ToMessage() string {
-	if e.Error.Message != "" {
-		return e.Error.Message
-	}
-	if e.Message != "" {
-		return e.Message
-	}
-	if e.Msg != "" {
-		return e.Msg
-	}
-	if e.Err != "" {
-		return e.Err
-	}
-	if e.ErrorMsg != "" {
-		return e.ErrorMsg
-	}
-	if e.Header.Message != "" {
-		return e.Header.Message
-	}
-	if e.Response.Error.Message != "" {
-		return e.Response.Error.Message
-	}
-	return ""
-}
-
-// 处理错误响应
-func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
-	openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
-		StatusCode: resp.StatusCode,
-		OpenAIError: types.OpenAIError{
-			Message: "",
-			Type:    "upstream_error",
-			Code:    "bad_response_status_code",
-			Param:   strconv.Itoa(resp.StatusCode),
-		},
-	}
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return
-	}
-	// var errorResponse types.OpenAIErrorResponse
-	var errorResponse GeneralErrorResponse
-	err = json.Unmarshal(responseBody, &errorResponse)
-	if err != nil {
-		return
-	}
-
-	if errorResponse.Error.Message != "" {
-		// OpenAI format error, so we override the default one
-		openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
-	} else {
-		openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
-	}
-	if openAIErrorWithStatusCode.OpenAIError.Message == "" {
-		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
-	}
-
-	return
-}
-
-func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
-	client := GetHttpClient(proxyAddr)
-	resp, err := client.Do(req)
-	PutHttpClient(client)
-	if err != nil {
-		return
-	}
-
-	return resp.Body, nil
-}
-
-func IsFailureStatusCode(resp *http.Response) bool {
-	return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
-}
-
-func DecodeResponse(body io.Reader, v any) error {
-	if v == nil {
-		return nil
-	}
-
-	if result, ok := v.(*string); ok {
-		return DecodeString(body, result)
-	}
-
-	if stringer, ok := v.(Stringer); ok {
-		return DecodeString(body, stringer.GetString())
-	}
-
-	return json.NewDecoder(body).Decode(v)
-}
-
-func DecodeString(body io.Reader, output *string) error {
-	b, err := io.ReadAll(body)
-	if err != nil {
-		return err
-	}
-	*output = string(b)
-	return nil
-}
-
-func SetEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
diff --git a/common/gin.go b/common/gin.go
index 7da11b7f..7e865f67 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -37,6 +37,14 @@ func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWith
 	return StringErrorWrapper(err.Error(), code, statusCode)
 }
 
+func ErrorToOpenAIError(err error) *types.OpenAIError {
+	return &types.OpenAIError{
+		Code:    "system error",
+		Message: err.Error(),
+		Type:    "one_api_error",
+	}
+}
+
 func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
 	openAIError := types.OpenAIError{
 		Message: err,
diff --git a/common/quota.go b/common/quota.go
deleted file mode 100644
index b8c772d9..00000000
--- a/common/quota.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package common
-
-// type Quota struct {
-// 	ModelName  string
-// 	ModelRatio float64
-// 	GroupRatio float64
-// 	Ratio      float64
-// 	UserQuota  int
-// }
-
-// func CreateQuota(modelName string, userQuota int, group string) *Quota {
-// 	modelRatio := GetModelRatio(modelName)
-// 	groupRatio := GetGroupRatio(group)
-
-// 	return &Quota{
-// 		ModelName:  modelName,
-// 		ModelRatio: modelRatio,
-// 		GroupRatio: groupRatio,
-// 		Ratio:      modelRatio * groupRatio,
-// 		UserQuota:  userQuota,
-// 	}
-// }
-
-// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
-// 	if ApproximateTokenEnabled {
-// 		return int(float64(len(text)) * 0.38)
-// 	}
-// 	return len(tokenEncoder.Encode(text, nil, nil))
-// }
-
-// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
-// 	tokenEncoder := q.getTokenEncoder(model)
-// 	// Reference:
-// 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-// 	// https://github.com/pkoukk/tiktoken-go/issues/6
-// 	//
-// 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
-// 	var tokensPerMessage int
-// 	var tokensPerName int
-// 	if model == "gpt-3.5-turbo-0301" {
-// 		tokensPerMessage = 4
-// 		tokensPerName = -1 // If there's a name, the role is omitted
-// 	} else {
-// 		tokensPerMessage = 3
-// 		tokensPerName = 1
-// 	}
-// 	tokenNum := 0
-// 	for _, message := range messages {
-// 		tokenNum += tokensPerMessage
-// 		tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
-// 		tokenNum += q.getTokenNum(tokenEncoder, message.Role)
-// 		if message.Name != nil {
-// 			tokenNum += tokensPerName
-// 			tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
-// 		}
-// 	}
-// 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-// 	return tokenNum
-// }
diff --git a/common/form_builder.go b/common/requester/form_builder.go
similarity index 98%
rename from common/form_builder.go
rename to common/requester/form_builder.go
index e8d13dd2..7cc333b1 100644
--- a/common/form_builder.go
+++ b/common/requester/form_builder.go
@@ -1,4 +1,4 @@
-package common
+package requester
 
 import (
 	"fmt"
diff --git a/common/requester/http_client.go b/common/requester/http_client.go
new file mode 100644
index 00000000..af511629
--- /dev/null
+++ b/common/requester/http_client.go
@@ -0,0 +1,68 @@
+package requester
+
+import (
+	"fmt"
+	"net/http"
+	"net/url"
+	"one-api/common"
+	"sync"
+	"time"
+
+	"golang.org/x/net/proxy"
+)
+
+type HTTPClient struct{}
+
+var clientPool = &sync.Pool{
+	New: func() interface{} {
+		return &http.Client{}
+	},
+}
+
+func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
+	client := clientPool.Get().(*http.Client)
+
+	if common.RelayTimeout > 0 {
+		client.Timeout = time.Duration(common.RelayTimeout) * time.Second
+	}
+
+	if proxyAddr != "" {
+		err := h.setProxy(client, proxyAddr)
+		if err != nil {
+			common.SysError(err.Error())
+			return client
+		}
+	}
+
+	return client
+}
+
+func (h *HTTPClient) returnClientToPool(client *http.Client) {
+	clientPool.Put(client)
+}
+
+func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
+	proxyURL, err := url.Parse(proxyAddr)
+	if err != nil {
+		return fmt.Errorf("error parsing proxy address: %w", err)
+	}
+
+	switch proxyURL.Scheme {
+	case "http", "https":
+		client.Transport = &http.Transport{
+			Proxy: http.ProxyURL(proxyURL),
+		}
+	case "socks5":
+		dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
+		if err != nil {
+			return fmt.Errorf("error creating socks5 dialer: %w", err)
+		}
+		client.Transport = &http.Transport{
+			Dial: dialer.Dial,
+		}
+	default:
+		return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
+	}
+
+	return nil
+}
diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go
new file mode 100644
index 00000000..83f95151
--- /dev/null
+++ b/common/requester/http_requester.go
@@ -0,0 +1,229 @@
+package requester
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/types"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+)
+
+type HttpErrorHandler func(*http.Response) *types.OpenAIError
+
+type HTTPRequester struct {
+	HTTPClient        HTTPClient
+	requestBuilder    RequestBuilder
+	CreateFormBuilder func(io.Writer) FormBuilder
+	ErrorHandler      HttpErrorHandler
+	proxyAddr         string
+}
+
+// NewHTTPRequester creates a new HTTPRequester instance.
+// proxyAddr: the address of the proxy server.
+// errorHandler: a function that receives an *http.Response and returns a *types.OpenAIError.
+// If errorHandler is nil, HandleErrorResp falls back to a generic "bad response status code" error.
+func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
+	return &HTTPRequester{
+		HTTPClient:     HTTPClient{},
+		requestBuilder: NewRequestBuilder(),
+		CreateFormBuilder: func(body io.Writer) FormBuilder {
+			return NewFormBuilder(body)
+		},
+		ErrorHandler: errorHandler,
+		proxyAddr:    proxyAddr,
+	}
+}
+
+type requestOptions struct {
+	body   any
+	header http.Header
+}
+
+type requestOption func(*requestOptions)
+
+// build a new HTTP request
+func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
+	args := &requestOptions{
+		body:   nil,
+		header: make(http.Header),
+	}
+	for _, setter := range setters {
+		setter(args)
+	}
+	req, err := r.requestBuilder.Build(method, url, args.body, args.header)
+	if err != nil {
+		return nil, err
+	}
+
+	return req, nil
+}
+
+// send a request and decode the response
+func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
+	client := r.HTTPClient.getClientFromPool(r.proxyAddr)
+	resp, err := client.Do(req)
+	r.HTTPClient.returnClientToPool(client)
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
+	}
+
+	if !outputResp {
+		defer resp.Body.Close()
+	}
+
+	// handle failure responses
+	if r.IsFailureStatusCode(resp) {
+		return nil, HandleErrorResp(resp, r.ErrorHandler)
+	}
+
+	// decode the response
+	if outputResp {
+		var buf bytes.Buffer
+		tee := io.TeeReader(resp.Body, &buf)
+		err = DecodeResponse(tee, response)
+
+		// write the body back into resp.Body so the caller can still read it
+		resp.Body = io.NopCloser(&buf)
+	} else {
+		// use DecodeResponse here as well, so nil and *string targets are
+		// handled the same way as in the outputResp branch above
+		err = DecodeResponse(resp.Body, response)
+	}
+
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
+	}
+
+	return resp, nil
+}
+
+// send a request and return the raw response
+func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
+	// send the request
+	client := r.HTTPClient.getClientFromPool(r.proxyAddr)
+	resp, err := client.Do(req)
+	r.HTTPClient.returnClientToPool(client)
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
+	}
+
+	// handle failure responses
+	if r.IsFailureStatusCode(resp) {
+		return nil, HandleErrorResp(resp, r.ErrorHandler)
+	}
+
+	return resp, nil
+}
+
+// wrap a raw response as a typed stream reader
+func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
+	// a JSON Content-Type on a stream endpoint indicates an upstream error
+	if resp.Header.Get("Content-Type") == "application/json" {
+		return nil, HandleErrorResp(resp, requester.ErrorHandler)
+	}
+
+	return &streamReader[T]{
+		reader:        bufio.NewReader(resp.Body),
+		response:      resp,
+		handlerPrefix: handlerPrefix,
+	}, nil
+}
+
+// set the request body
+func (r *HTTPRequester) WithBody(body any) requestOption {
+	return func(args *requestOptions) {
+		args.body = body
+	}
+}
+
+// set request headers
+func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
+	return func(args *requestOptions) {
+		for k, v := range header {
+			args.header.Set(k, v)
+		}
+	}
+}
+
+// set the Content-Type header
+func (r *HTTPRequester) WithContentType(contentType string) requestOption {
+	return func(args *requestOptions) {
+		args.header.Set("Content-Type", contentType)
+	}
+}
+
+// report whether the status code indicates a failure
+func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
+	return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
+}
+
+// convert an upstream error response into an OpenAI-style error
+func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
+
+	openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
+		StatusCode: resp.StatusCode,
+		OpenAIError: types.OpenAIError{
+			Message: "",
+			Type:    "upstream_error",
+			Code:    "bad_response_status_code",
+			Param:   strconv.Itoa(resp.StatusCode),
+		},
+	}
+
+	defer resp.Body.Close()
+
+	if toOpenAIError != nil {
+		errorResponse := toOpenAIError(resp)
+
+		if errorResponse != nil && errorResponse.Message != "" {
+			openAIErrorWithStatusCode.OpenAIError = *errorResponse
+		}
+	}
+
+	if openAIErrorWithStatusCode.OpenAIError.Message == "" {
+		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
+	}
+
+	return openAIErrorWithStatusCode
+}
+
+func SetEventStreamHeaders(c *gin.Context) {
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
+
+type Stringer interface {
+	GetString() *string
+}
+
+func DecodeResponse(body io.Reader, v any) error {
+	if v == nil {
+		return nil
+	}
+
+	if result, ok := v.(*string); ok {
+		return DecodeString(body, result)
+	}
+
+	if stringer, ok := v.(Stringer); ok {
+		return DecodeString(body, stringer.GetString())
+	}
+
+	return json.NewDecoder(body).Decode(v)
+}
+
+func DecodeString(body io.Reader, output *string) error {
+	b, err := io.ReadAll(body)
+	if err != nil {
+		return err
+	}
+	*output = string(b)
+	return nil
+}
diff --git a/common/requester/http_stream_reader.go b/common/requester/http_stream_reader.go
new file mode 100644
index 00000000..06e70090
--- /dev/null
+++ b/common/requester/http_stream_reader.go
@@ -0,0 +1,79 @@
+package requester
+
+import (
+	"bufio"
+	"bytes"
+	"io"
+	"net/http"
+)
+
+// Stream handler contract:
+// 1. If the handler returns an error, that error is propagated as-is.
+// 2. If it sets isFinished=true, Recv returns io.EOF (together with response, if response is non-empty).
+// 3. If rawLine is nil or response is empty, the line is skipped.
+// 4. Otherwise, response is returned.
+type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
+
+type streamable interface {
+	// types.ChatCompletionStreamResponse | types.CompletionResponse
+	any
+}
+
+type StreamReaderInterface[T streamable] interface {
+	Recv() (*[]T, error)
+	Close()
+}
+
+type streamReader[T streamable] struct {
+	isFinished bool
+
+	reader   *bufio.Reader
+	response *http.Response
+
+	handlerPrefix HandlerPrefix[T]
+}
+
+func (stream *streamReader[T]) Recv() (response *[]T, err error) {
+	if stream.isFinished {
+		err = io.EOF
+		return
+	}
+	response, err = stream.processLines()
+	return
+}
+
+//nolint:gocognit
+func (stream *streamReader[T]) processLines() (*[]T, error) {
+	for {
+		rawLine, readErr := stream.reader.ReadBytes('\n')
+		if readErr != nil {
+			return nil, readErr
+		}
+
+		noSpaceLine := bytes.TrimSpace(rawLine)
+
+		var response []T
+		err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
+
+		if err != nil {
+			return nil, err
+		}
+
+		if stream.isFinished {
+			if len(response) > 0 {
+				return &response, io.EOF
+			}
+			return nil, io.EOF
+		}
+
+		if noSpaceLine == nil || len(response) == 0 {
+			continue
+		}
+
+		return &response, nil
+	}
+}
+
+func (stream *streamReader[T]) Close() {
+	stream.response.Body.Close()
+}
diff --git a/common/request_builder.go b/common/requester/request_builder.go
similarity index 88%
rename from common/request_builder.go
rename to common/requester/request_builder.go
index 6a97b425..88cebe1a 100644
--- a/common/request_builder.go
+++ b/common/requester/request_builder.go
@@ -1,9 +1,10 @@
-package common
+package requester
 
 import (
 	"bytes"
 	"io"
 	"net/http"
+	"one-api/common"
 )
 
 type RequestBuilder interface {
@@ -11,12 +12,12 @@ type RequestBuilder interface {
 }
 
 type HTTPRequestBuilder struct {
-	marshaller Marshaller
+	marshaller common.Marshaller
 }
 
 func NewRequestBuilder() *HTTPRequestBuilder {
 	return &HTTPRequestBuilder{
-		marshaller: &JSONMarshaller{},
+		marshaller: &common.JSONMarshaller{},
 	}
 }
diff --git a/common/requester/ws_client.go b/common/requester/ws_client.go
new file mode 100644
index 00000000..1b4cb229
--- /dev/null
+++ b/common/requester/ws_client.go
@@ -0,0 +1,53 @@
+package requester
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"net/url"
+	"one-api/common"
+	"time"
+
+	"github.com/gorilla/websocket"
+	"golang.org/x/net/proxy"
+)
+
+func GetWSClient(proxyAddr string) *websocket.Dialer {
+	dialer := &websocket.Dialer{
+		HandshakeTimeout: 5 * time.Second,
+	}
+
+	if proxyAddr != "" {
+		err := setWSProxy(dialer, proxyAddr)
+		if err != nil {
+			common.SysError(err.Error())
+			return dialer
+		}
+	}
+
+	return dialer
+}
+
+func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
+	proxyURL, err := url.Parse(proxyAddr)
+	if err != nil {
+		return fmt.Errorf("error parsing proxy address: %w", err)
+	}
+
+	switch proxyURL.Scheme {
+	case "http", "https":
+		dialer.Proxy = http.ProxyURL(proxyURL)
+	case "socks5":
+		socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
+		if err != nil {
+			return fmt.Errorf("error creating socks5 dialer: %w", err)
+		}
+		dialer.NetDial = func(network, addr string) (net.Conn, error) {
+			return socks5Proxy.Dial(network, addr)
+		}
+	default:
+		return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
+	}
+
+	return nil
+}
diff --git a/common/requester/ws_reader.go b/common/requester/ws_reader.go
new file mode 100644
index 00000000..24d91be4
--- /dev/null
+++ b/common/requester/ws_reader.go
@@ -0,0 +1,58 @@
+package requester
+
+import (
+	"io"
+
+	"github.com/gorilla/websocket"
+)
+
+type wsReader[T streamable] struct {
+	isFinished bool
+
+	reader        *websocket.Conn
+	handlerPrefix HandlerPrefix[T]
+}
+
+func (stream *wsReader[T]) Recv() (response *[]T, err error) {
+	if stream.isFinished {
+		err = io.EOF
+		return
+	}
+
+	response, err = stream.processLines()
+	return
+}
+
+func (stream *wsReader[T]) processLines() (*[]T, error) {
+	for {
+		_, msg, err := stream.reader.ReadMessage()
+		if err != nil {
+			return nil, err
+		}
+
+		var response []T
+		err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
+
+		if err != nil {
+			return nil, err
+		}
+
+		if stream.isFinished {
+			if len(response) > 0 {
+				return &response, io.EOF
+			}
+			return nil, io.EOF
+		}
+
+		if msg == nil || len(response) == 0 {
+			continue
+		}
+
+		return &response, nil
+
+	}
+}
+
+func (stream *wsReader[T]) Close() {
+	stream.reader.Close()
+}
diff --git a/common/requester/ws_requester.go b/common/requester/ws_requester.go
new file mode 100644
index 00000000..e16c8121
--- /dev/null
+++ b/common/requester/ws_requester.go
@@ -0,0 +1,54 @@
+package requester
+
+import (
+	"errors"
+	"net/http"
+	"one-api/common"
+	"one-api/types"
+
+	"github.com/gorilla/websocket"
+)
+
+type WSRequester struct {
+	WSClient *websocket.Dialer
+}
+
+func NewWSRequester(proxyAddr string) *WSRequester {
+	return &WSRequester{
+		WSClient: GetWSClient(proxyAddr),
+	}
+}
+
+func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
+	conn, resp, err := w.WSClient.Dial(url, header)
+	if err != nil {
+		return nil, err
+	}
+
+	if resp.StatusCode != http.StatusSwitchingProtocols {
+		return nil, errors.New("ws unexpected status code")
+	}
+
+	return conn, nil
+}
+
+func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
+	err := conn.WriteJSON(data)
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
+	}
+
+	return &wsReader[T]{
+		reader:        conn,
+		handlerPrefix: handlerPrefix,
+	}, nil
+}
+
+// build request headers from a map
+func (r *WSRequester) WithHeader(headers map[string]string) http.Header {
+	header := make(http.Header)
+	for k, v := range headers {
+		header.Set(k, v)
+	}
+	return header
+}
diff --git a/common/test/api.go b/common/test/api.go
new file mode 100644
index 00000000..ba93e637
--- /dev/null
+++ b/common/test/api.go
@@ -0,0 +1,55 @@
+package test
+
+import (
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"one-api/model"
+
+	"github.com/gin-gonic/gin"
+)
+
+func RequestJSONConfig() map[string]string {
+	return map[string]string{
+		"Content-Type": "application/json",
+		"Accept":       "application/json",
+	}
+}
+
+func GetContext(method, path string, headers map[string]string, body io.Reader) (*gin.Context, *httptest.ResponseRecorder) {
+	var req *http.Request
+	req, _ = http.NewRequest(method, path, body)
+	for k, v := range headers {
+		req.Header.Set(k, v)
+	}
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = req
+	return c, w
+}
+
+func GetGinRouter(method, path string, headers map[string]string, body *io.Reader) *httptest.ResponseRecorder {
+	var req *http.Request
+	r := gin.Default()
+
+	w := httptest.NewRecorder()
+	req, _ = http.NewRequest(method, path, *body)
+	for k, v := range headers {
+		req.Header.Set(k, v)
+	}
+
+	r.ServeHTTP(w, req)
+
+	return w
+}
+
+func GetChannel(channelType int, baseUrl, other, proxy, modelMapping string) model.Channel {
+	return model.Channel{
+		Type:         channelType,
+		BaseURL:      &baseUrl,
+		Other:        other,
+		Proxy:        proxy,
+		ModelMapping: &modelMapping,
+		Key:          GetTestToken(),
+	}
+}
diff --git a/common/test/chat_config.go b/common/test/chat_config.go
new file mode 100644
index 00000000..5f2fa44d
--- /dev/null
+++ b/common/test/chat_config.go
@@ -0,0 +1,132 @@
+package test
+
+import (
+	"encoding/json"
+	"one-api/types"
+	"strings"
+)
+
+func GetChatCompletionRequest(chatType, modelName, stream string) *types.ChatCompletionRequest {
+	chatJSON := GetChatRequest(chatType, modelName, stream)
+	chatCompletionRequest := &types.ChatCompletionRequest{}
+	json.NewDecoder(chatJSON).Decode(chatCompletionRequest)
+	return chatCompletionRequest
+}
+
+func GetChatRequest(chatType, modelName, stream string) *strings.Reader {
+	var chatJSON string
+	switch chatType {
+	case "image":
+		chatJSON = `{
+			"model": "` + modelName + `",
+			"messages": [
+				{
+					"role": "user",
+					"content": [
+						{
+							"type": "text",
+							"text": "What’s in this image?"
+						},
+						{
+							"type": "image_url",
+							"image_url": {
+								"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+							}
+						}
+					]
+				}
+			],
+			"max_tokens": 300,
+			"stream": ` + stream + `
+		}`
+	case "default":
+		chatJSON = `{
+			"model": "` + modelName + `",
+			"messages": [
+				{
+					"role": "system",
+					"content": "You are a helpful assistant."
+				},
+				{
+					"role": "user",
+					"content": "Hello!"
+				}
+			],
+			"stream": ` + stream + `
+		}`
+	case "function":
+		chatJSON = `{
+			"model": "` + modelName + `",
+			"stream": ` + stream + `,
+			"messages": [
+				{
+					"role": "user",
+					"content": "What is the weather like in Boston?"
+				}
+			],
+			"tools": [
+				{
+					"type": "function",
+					"function": {
+						"name": "get_current_weather",
+						"description": "Get the current weather in a given location",
+						"parameters": {
+							"type": "object",
+							"properties": {
+								"location": {
+									"type": "string",
+									"description": "The city and state, e.g. San Francisco, CA"
+								},
+								"unit": {
+									"type": "string",
+									"enum": ["celsius", "fahrenheit"]
+								}
+							},
+							"required": ["location"]
+						}
+					}
+				}
+			],
+			"tool_choice": "auto"
+		}`
+
+	case "tools":
+		chatJSON = `{
+			"model": "` + modelName + `",
+			"stream": ` + stream + `,
+			"messages": [
+				{
+					"role": "user",
+					"content": "What is the weather like in Boston?"
+				}
+			],
+			"functions": [
+				{
+					"name": "get_current_weather",
+					"description": "Get the current weather in a given location",
+					"parameters": {
+						"type": "object",
+						"properties": {
+							"location": {
+								"type": "string",
+								"description": "The city and state, e.g. San Francisco, CA"
+							},
+							"unit": {
+								"type": "string",
+								"enum": [
+									"celsius",
+									"fahrenheit"
+								]
+							}
+						},
+						"required": [
+							"location"
+						]
+					}
+				}
+			]
+		}`
+	}
+
+	return strings.NewReader(chatJSON)
+}
diff --git a/common/test/check_chat.go b/common/test/check_chat.go
new file mode 100644
index 00000000..20532f15
--- /dev/null
+++ b/common/test/check_chat.go
@@ -0,0 +1,65 @@
+package test
+
+import (
+	"one-api/types"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func CheckChat(t *testing.T, response *types.ChatCompletionResponse, modelName string, usage *types.Usage) {
+	assert.NotEmpty(t, response.ID)
+	assert.NotEmpty(t, response.Object)
+	assert.NotEmpty(t, response.Created)
+	assert.Equal(t, response.Model, modelName)
+	assert.IsType(t, []types.ChatCompletionChoice{}, response.Choices)
+	// check that at least one choice is returned
+	assert.True(t, len(response.Choices) > 0)
+	for _, choice := range response.Choices {
+		assert.NotNil(t, choice.Index)
+		assert.IsType(t, types.ChatCompletionMessage{}, choice.Message)
+		assert.NotEmpty(t, choice.Message.Role)
+		assert.NotEmpty(t, choice.FinishReason)
+
+		// check message
+		if choice.Message.Content != nil {
+			multiContents, ok := choice.Message.Content.([]types.ChatMessagePart)
+			if ok {
+				for _, content := range multiContents {
+					assert.NotEmpty(t, content.Type)
+					if content.Type == "text" {
+						assert.NotEmpty(t, content.Text)
+					} else if content.Type == "image_url" {
+						assert.IsType(t, types.ChatMessageImageURL{}, content.ImageURL)
+					}
+				}
+			} else {
+				content, ok := choice.Message.Content.(string)
+				assert.True(t, ok)
+				assert.NotEmpty(t, content)
+			}
+		} else if choice.Message.FunctionCall != nil {
+			assert.NotEmpty(t, choice.Message.FunctionCall.Name)
+			assert.Equal(t, choice.FinishReason, types.FinishReasonFunctionCall)
+		} else if choice.Message.ToolCalls != nil {
+			assert.IsType(t, []types.ChatCompletionToolCalls{}, choice.Message.ToolCalls)
+			assert.NotEmpty(t, choice.Message.ToolCalls[0].Id)
+			assert.NotEmpty(t, choice.Message.ToolCalls[0].Function)
+			assert.Equal(t, choice.Message.ToolCalls[0].Type, "function")
+
+			assert.IsType(t, types.ChatCompletionToolCallsFunction{}, choice.Message.ToolCalls[0].Function)
+			assert.NotEmpty(t, choice.Message.ToolCalls[0].Function.Name)
+
+			assert.Equal(t, choice.FinishReason, types.FinishReasonToolCalls)
+		} else {
+			assert.Fail(t, "message content is nil")
+		}
+	}
+
+	// check usage
+	assert.IsType(t, &types.Usage{}, response.Usage)
+	assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens)
+	assert.Equal(t, response.Usage.CompletionTokens, usage.CompletionTokens)
+	assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens)
+
+}
diff --git a/common/test/checks.go b/common/test/checks.go
new file mode 100644
index 00000000..04524fc8
--- /dev/null
+++ b/common/test/checks.go
@@ -0,0 +1,48 @@
+package test
+
+import (
+	"errors"
+	"testing"
+)
+
+func NoError(t *testing.T, err error, message ...string) {
+	t.Helper()
+	if err != nil {
+		t.Error(err, message)
+	}
+}
+
+func HasError(t *testing.T, err error, message ...string) {
+	t.Helper()
+	if err == nil {
+		t.Error(err, message)
+	}
+}
+
+func ErrorIs(t *testing.T, err, target error, msg ...string) {
+	t.Helper()
+	if !errors.Is(err, target) {
+		t.Fatal(msg)
+	}
+}
+
+func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
+	t.Helper()
+	if !errors.Is(err, target) {
+		t.Fatalf(format, msg)
+	}
+}
+
+func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
+	t.Helper()
+	if errors.Is(err, target) {
+		t.Fatal(msg)
+	}
+}
+
+func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
+	t.Helper()
+	if errors.Is(err, target) {
+		t.Fatalf(format, msg)
+	}
+}
diff --git a/common/test/init/init.go b/common/test/init/init.go
new file mode 100644
index 00000000..ade77570
--- /dev/null
+++ b/common/test/init/init.go
@@ -0,0 +1,7 @@
+package init
+
+import "testing"
+
+func init() {
+	testing.Init()
+}
diff --git a/common/test/server.go b/common/test/server.go
new file mode 100644
index 00000000..db504cea
--- /dev/null
+++ b/common/test/server.go
@@ -0,0 +1,63 @@
+package test
+
+import (
+	"log"
+	"net/http"
+	"net/http/httptest"
+	"regexp"
+	"strings"
+)
+
+const testAPI = "this-is-my-secure-token-do-not-steal!!"
+
+func GetTestToken() string {
+	return testAPI
+}
+
+type ServerTest struct {
+	handlers map[string]handler
+}
+type handler func(w http.ResponseWriter, r *http.Request)
+
+func NewTestServer() *ServerTest {
+	return &ServerTest{handlers: make(map[string]handler)}
+}
+
+func OpenAICheck(w http.ResponseWriter, r *http.Request) bool {
+	if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() {
+		w.WriteHeader(http.StatusUnauthorized)
+		return false
+	}
+	return true
+}
+
+func (ts *ServerTest) RegisterHandler(path string, handler handler) {
+	// to make the registered paths friendlier to a regex match in the route handler
+	// in TestServer
+	path = strings.ReplaceAll(path, "*", ".*")
+	ts.handlers[path] = handler
+}
+
+// TestServer creates a mocked OpenAI server which can pretend to handle requests during testing.
+func (ts *ServerTest) TestServer(headerCheck func(w http.ResponseWriter, r *http.Request) bool) *httptest.Server {
+	return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path)
+
+		// check auth
+		if headerCheck != nil && !headerCheck(w, r) {
+			return
+		}
+
+		// Handle /path/* routes.
+		// Note: the * is converted to a .* in RegisterHandler for proper regex handling
+		for route, handler := range ts.handlers {
+			// Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered
+			pattern, _ := regexp.Compile("^" + route + "$")
+			if pattern.MatchString(r.URL.Path) {
+				handler(w, r)
+				return
+			}
+		}
+		http.Error(w, "the resource path doesn't exist", http.StatusNotFound)
+	}))
+}
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index ed4447c0..90241d64 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -67,7 +67,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return 0, errors.New("provider not implemented")
 	}
 
-	return balanceProvider.Balance(channel)
+	return balanceProvider.Balance()
 }
diff --git a/controller/channel-test.go b/controller/channel-test.go
index b291c04c..3f675184 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
@@ -38,30 +39,36 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
 	if provider == nil {
 		return errors.New("channel not implemented"), nil
 	}
+
+	newModelName, err := provider.ModelMappingHandler(request.Model)
+	if err != nil {
+		return err, nil
+	}
+
+	request.Model = newModelName
+
 	chatProvider, ok := provider.(providers_base.ChatInterface)
 	if !ok {
 		return errors.New("channel not implemented"), nil
 	}
 
-	modelMap, err := parseModelMapping(channel.GetModelMapping())
-	if err != nil {
-		return err, nil
-	}
-	if modelMap != nil && modelMap[request.Model] != "" {
-		request.Model = modelMap[request.Model]
-	}
+	chatProvider.SetUsage(&types.Usage{})
+
+	response, openAIErrorWithStatusCode := chatProvider.CreateChatCompletion(&request)
 
-	promptTokens := common.CountTokenMessages(request.Messages, request.Model)
-	Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
 	if openAIErrorWithStatusCode != nil {
 		return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError
 	}
 
-	if Usage.CompletionTokens == 0 {
+	usage := chatProvider.GetUsage()
+
+	if usage.CompletionTokens == 0 {
 		return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
 	}
 
-	common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String()))
+	// marshal the response to a JSON string for logging
+	jsonBytes, _ := json.Marshal(response)
+	common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, string(jsonBytes)))
 
 	return nil, nil
 }
@@ -74,9 +81,9 @@ func buildTestRequest() *types.ChatCompletionRequest {
 			Content: "You just need to output 'hi' next.",
 		},
 	},
-		Model:     "",
-		MaxTokens: 1,
-		Stream:    false,
+		Model: "",
+		// MaxTokens: 1,
+		Stream: false,
 	}
 	return testRequest
 }
diff --git a/controller/quota.go b/controller/quota.go
new file mode 100644
index 00000000..e8a083a1
--- /dev/null
+++ b/controller/quota.go
@@ -0,0 +1,164 @@
+package controller
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"one-api/types"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+type QuotaInfo struct {
+	modelName         string
+	promptTokens      int
+	preConsumedTokens int
+	modelRatio        float64
+	groupRatio        float64
+	ratio             float64
+	preConsumedQuota  int
+	userId            int
+	channelId         int
+	tokenId           int
+	HandelStatus      bool
+}
+
+func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
+	quotaInfo := &QuotaInfo{
+		modelName:    modelName,
+		promptTokens: promptTokens,
+		userId:       c.GetInt("id"),
+		channelId:    c.GetInt("channel_id"),
+		tokenId:      c.GetInt("token_id"),
+		HandelStatus: false,
+	}
+	quotaInfo.initQuotaInfo(c.GetString("group"))
+
+	errWithCode := quotaInfo.preQuotaConsumption()
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	return quotaInfo, nil
+}
+
+func (q *QuotaInfo) initQuotaInfo(groupName string) {
+	modelRatio := common.GetModelRatio(q.modelName)
+	groupRatio := common.GetGroupRatio(groupName)
+	preConsumedTokens := common.PreConsumedQuota
+	ratio := modelRatio * groupRatio
+	preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
+
+	q.preConsumedTokens = preConsumedTokens
+	q.modelRatio = modelRatio
+	q.groupRatio = groupRatio
+	q.ratio = ratio
+	q.preConsumedQuota = preConsumedQuota
+
+}
+
+func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
+	userQuota, err := model.CacheGetUserQuota(q.userId)
+	if err != nil {
+		return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+
+	if userQuota < q.preConsumedQuota {
+		return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+	}
+
+	err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
+	if err != nil {
+		return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+
+	if userQuota > 100*q.preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		q.preConsumedQuota = 0
+		// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
+	}
+
+	if q.preConsumedQuota > 0 {
+		err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
+		if err != nil {
+			return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+		q.HandelStatus = true
+	}
+
+	return nil
+}
+
+func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
+	quota := 0
+	completionRatio := common.GetCompletionRatio(q.modelName)
+	promptTokens := usage.PromptTokens
+	completionTokens := usage.CompletionTokens
+	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
+	if q.ratio != 0 && quota <= 0 {
+		quota = 1
+	}
+	totalTokens := promptTokens + completionTokens
+	if totalTokens == 0 {
+		// in this case, some error must have happened
+		// we cannot just return, because we may have to return the pre-consumed quota
+		quota = 0
+	}
+	quotaDelta := quota - q.preConsumedQuota
+	err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
+	if err != nil {
+		return errors.New("error consuming token remain quota: " + err.Error())
+	}
+	err = model.CacheUpdateUserQuota(q.userId)
+	if err != nil {
+		return errors.New("error consuming token remain quota: " + err.Error())
+	}
+	if quota != 0 {
+		requestTime := 0
+		requestStartTimeValue := ctx.Value("requestStartTime")
+		if requestStartTimeValue != nil {
+			requestStartTime, ok := requestStartTimeValue.(time.Time)
+			if ok {
+				requestTime = int(time.Since(requestStartTime).Milliseconds())
+			}
+		}
+
+		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
+		model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
+		model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
+		model.UpdateChannelUsedQuota(q.channelId, quota)
+	}
+
+	return nil
+}
+
+func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
+	tokenId := c.GetInt("token_id")
+	if q.HandelStatus {
+		go func(ctx context.Context) {
+			// return pre-consumed quota
+			err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
+			if err != nil {
+				common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
+			}
+		}(c.Request.Context())
+	}
+	errorHelper(c, errWithCode)
}
+
+func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
+	tokenName := c.GetString("token_name")
+	// the request succeeded, so settle the quota asynchronously
+	go func(ctx context.Context) {
+		err := q.completedQuotaConsumption(usage, tokenName, ctx)
+		if err != nil {
+			common.LogError(ctx, err.Error())
+		}
+	}(c.Request.Context())
+}
diff --git a/controller/relay-chat.go b/controller/relay-chat.go
index 8e74c789..9880ac5e 100644
--- a/controller/relay-chat.go
+++ b/controller/relay-chat.go
@@ -1,11 +1,11 @@
 package controller
 
 import (
-	"context"
+	"fmt"
 	"math"
 	"net/http"
 	"one-api/common"
-	"one-api/model"
+	"one-api/common/requester"
 	providersBase "one-api/providers/base"
 	"one-api/types"
 
@@ -20,33 +20,18 @@ func RelayChat(c *gin.Context) {
 		return
 	}
 
-	channel, pass := fetchChannel(c, chatRequest.Model)
-	if pass {
-		return
-	}
-
 	if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
 		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
 		return
 	}
 
-	// 解析模型映射
-	var isModelMapped bool
-	modelMap, err := parseModelMapping(channel.GetModelMapping())
-	if err != nil {
-		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
-		return
-	}
-	if modelMap != nil && modelMap[chatRequest.Model] != "" {
-		chatRequest.Model = modelMap[chatRequest.Model]
-		isModelMapped = true
-	}
-
 	// 获取供应商
-	provider, pass := getProvider(c, channel, common.RelayModeChatCompletions)
-	if pass {
+	provider, modelName, fail := getProvider(c, chatRequest.Model)
+	if fail {
 		return
 	}
+
+	chatRequest.Model = modelName
+
 	chatProvider, ok := provider.(providersBase.ChatInterface)
 	if !ok {
 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -56,39 +41,42 @@ func RelayChat(c *gin.Context) {
 	// 获取Input Tokens
 	promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
 
-	var quotaInfo *QuotaInfo
-	var errWithCode *types.OpenAIErrorWithStatusCode
-	var usage *types.Usage
-	quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens)
+	usage := &types.Usage{
+		PromptTokens: promptTokens,
+	}
+	provider.SetUsage(usage)
+
+	quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
 	if errWithCode != nil {
 		errorHelper(c, errWithCode)
 		return
 	}
 
-	usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
+	if chatRequest.Stream {
+		var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
+		response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
+		if errWithCode != nil {
+			errorHelper(c, errWithCode)
+			return
+		}
+		errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
+	} else {
+		var response *types.ChatCompletionResponse
+		response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
+		if errWithCode != nil {
+			errorHelper(c, errWithCode)
+			return
+		}
+		errWithCode = responseJsonClient(c, response)
+	}
+
+	fmt.Println(usage)
 
 	// 如果报错,则退还配额
 	if errWithCode != nil {
-		tokenId := c.GetInt("token_id")
-		if quotaInfo.HandelStatus {
-			go func(ctx context.Context) {
-				// return pre-consumed quota
-				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
-				if err != nil {
-					common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-				}
-			}(c.Request.Context())
-		}
-		errorHelper(c, errWithCode)
+		quotaInfo.undo(c, errWithCode)
 		return
-	} else {
-		tokenName := c.GetString("token_name")
-		// 如果没有报错,则消费配额
-		go func(ctx context.Context) {
-			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
-			if err != nil {
-				common.LogError(ctx, err.Error())
-			}
-		}(c.Request.Context())
 	}
+
+	quotaInfo.consume(c, usage)
 }
diff --git a/controller/relay-completions.go b/controller/relay-completions.go
index 1731cb86..4af1e685 100644
--- a/controller/relay-completions.go
+++ b/controller/relay-completions.go
@@ -1,11 +1,9 @@
 package controller
 
 import (
-	"context"
 	"math"
 	"net/http"
 	"one-api/common"
-	"one-api/model"
 	providersBase "one-api/providers/base"
 	"one-api/types"
 
@@ -20,33 +18,18 @@ func RelayCompletions(c *gin.Context) {
 		return
 	}
 
-	channel, pass := fetchChannel(c, completionRequest.Model)
-	if pass {
-		return
-	}
-
 	if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
 		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
 		return
 	}
 
-	// 解析模型映射
-	var isModelMapped bool
-	modelMap, err := parseModelMapping(channel.GetModelMapping())
-	if err != nil {
-		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
-		return
-	}
-	if modelMap != nil && modelMap[completionRequest.Model] != "" {
-		completionRequest.Model = modelMap[completionRequest.Model]
-		isModelMapped = true
-	}
-
 	// 获取供应商
-	provider, pass := getProvider(c, channel, common.RelayModeCompletions)
-	if pass {
+	provider, modelName, fail := getProvider(c, completionRequest.Model)
+	if fail {
 		return
 	}
+
+	completionRequest.Model = modelName
+
 	completionProvider, ok := provider.(providersBase.CompletionInterface)
 	if !ok {
 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -56,39 +39,38 @@ func RelayCompletions(c *gin.Context) {
 	// 获取Input Tokens
 	promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
 
-	var quotaInfo *QuotaInfo
-	var errWithCode *types.OpenAIErrorWithStatusCode
-	var usage *types.Usage
-	quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens)
+	usage := &types.Usage{
+		PromptTokens: promptTokens,
+	}
+	provider.SetUsage(usage)
+
+	quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
 	if errWithCode != nil {
 		errorHelper(c, errWithCode)
 		return
 	}
 
-	usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
+	if completionRequest.Stream {
+		// use a distinct name for the call's error so the outer errWithCode
+		// is not shadowed and the refund check below still sees it
+		response, errCode := completionProvider.CreateCompletionStream(&completionRequest)
+		if errCode != nil {
+			errorHelper(c, errCode)
+			return
+		}
+		errWithCode = responseStreamClient[types.CompletionResponse](c, response)
+	} else {
+		response, errCode := completionProvider.CreateCompletion(&completionRequest)
+		if errCode != nil {
+			errorHelper(c, errCode)
+			return
+		}
+		errWithCode = responseJsonClient(c, response)
+	}
 
 	// 如果报错,则退还配额
 	if errWithCode != nil {
-		tokenId := c.GetInt("token_id")
-		if quotaInfo.HandelStatus {
-			go func(ctx context.Context) {
-				// return pre-consumed quota
-				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
-				if err != nil {
-					common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go index df3192bf..58dffc48 100644 --- a/controller/relay-embeddings.go +++ b/controller/relay-embeddings.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" "strings" @@ -24,28 +22,13 @@ func RelayEmbeddings(c *gin.Context) { return } - channel, pass := fetchChannel(c, embeddingsRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { - embeddingsRequest.Model = modelMap[embeddingsRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeEmbeddings) - if pass { + provider, modelName, fail := getProvider(c, embeddingsRequest.Model) + if fail { return } + embeddingsRequest.Model = modelName + embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -55,39 +38,29 @@ func RelayEmbeddings(c *gin.Context) { // 获取Input Tokens promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens) + response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseJsonClient(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-image-edits.go b/controller/relay-image-edits.go index 006c9520..ef301e2a 100644 --- a/controller/relay-image-edits.go +++ b/controller/relay-image-edits.go @@ -1,10 +1,8 @@ package controller import ( - "context" 
"net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" @@ -33,28 +31,13 @@ func RelayImageEdits(c *gin.Context) { imageEditRequest.Size = "1024x1024" } - channel, pass := fetchChannel(c, imageEditRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[imageEditRequest.Model] != "" { - imageEditRequest.Model = modelMap[imageEditRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeImagesEdits) - if pass { + provider, modelName, fail := getProvider(c, imageEditRequest.Model) + if fail { return } + imageEditRequest.Model = modelName + imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -68,39 +51,29 @@ func RelayImageEdits(c *gin.Context) { return } - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens) + response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseJsonClient(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-image-generations.go b/controller/relay-image-generations.go index 4c7f30c0..4332274d 100644 --- a/controller/relay-image-generations.go +++ b/controller/relay-image-generations.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" @@ -36,28 +34,13 @@ func RelayImageGenerations(c *gin.Context) { imageRequest.Quality = "standard" } - channel, pass := fetchChannel(c, imageRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[imageRequest.Model] != "" { - imageRequest.Model = modelMap[imageRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations) - if pass { + 
provider, modelName, fail := getProvider(c, imageRequest.Model) + if fail { return } + imageRequest.Model = modelName + imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -71,39 +54,29 @@ func RelayImageGenerations(c *gin.Context) { return } - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) + response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseJsonClient(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-image-variationsy.go b/controller/relay-image-variationsy.go index 019b431f..2c9069ee 100644 --- a/controller/relay-image-variationsy.go +++ b/controller/relay-image-variationsy.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" @@ -28,28 +26,13 @@ func RelayImageVariations(c *gin.Context) { imageEditRequest.Size = "1024x1024" } - channel, pass := fetchChannel(c, imageEditRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[imageEditRequest.Model] != "" { - imageEditRequest.Model = modelMap[imageEditRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeImagesVariations) - if pass { + provider, modelName, fail := getProvider(c, imageEditRequest.Model) + if fail { return } + imageEditRequest.Model = modelName + imageVariations, ok := provider.(providersBase.ImageVariationsInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -63,39 +46,29 @@ func RelayImageVariations(c *gin.Context) { return } - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, 
imageEditRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens) + response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseJsonClient(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-moderations.go b/controller/relay-moderations.go index 5feccdb4..136b6fdd 100644 --- a/controller/relay-moderations.go +++ b/controller/relay-moderations.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" @@ -24,28 +22,13 @@ func RelayModerations(c *gin.Context) { moderationRequest.Model = "text-moderation-stable" } - channel, pass := fetchChannel(c, moderationRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[moderationRequest.Model] != "" { - moderationRequest.Model = modelMap[moderationRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeModerations) - if pass { + provider, modelName, fail := getProvider(c, moderationRequest.Model) + if fail { return } + moderationRequest.Model = modelName + moderationProvider, ok := provider.(providersBase.ModerationInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -55,39 +38,29 @@ func RelayModerations(c *gin.Context) { // 获取Input Tokens promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model) - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens) + response, errWithCode := moderationProvider.CreateModeration(&moderationRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseJsonClient(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, 
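The deleted error path above collapses into quotaInfo.undo and the deleted success path into quotaInfo.consume. Their bodies are not part of this hunk; the sketch below is a plausible reconstruction assembled from the removed inline blocks. Note that completedQuotaConsumption is itself deleted later in this diff, so the real consume presumably inlines equivalent settlement logic; treat this as a sketch only.

```go
// Not part of this diff: a plausible reconstruction of quotaInfo.undo and
// quotaInfo.consume, built from the inline blocks deleted above. It would
// live alongside QuotaInfo in relay-utils.go with its existing imports.
func (q *QuotaInfo) undo(c *gin.Context, _ *types.OpenAIErrorWithStatusCode) {
	tokenId := c.GetInt("token_id")
	if q.HandelStatus { // quota was actually pre-consumed
		go func(ctx context.Context) {
			// return the pre-consumed quota
			if err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota); err != nil {
				common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
			}
		}(c.Request.Context())
	}
}

func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
	tokenName := c.GetString("token_name")
	go func(ctx context.Context) {
		// settle the final bill; completedQuotaConsumption is removed later
		// in this diff, so the real method presumably inlines this logic
		if err := q.completedQuotaConsumption(usage, tokenName, ctx); err != nil {
			common.LogError(ctx, err.Error())
		}
	}(c.Request.Context())
}
```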
-quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-speech.go b/controller/relay-speech.go index e5ace14c..ec7fd7ce 100644 --- a/controller/relay-speech.go +++ b/controller/relay-speech.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" @@ -20,28 +18,13 @@ func RelaySpeech(c *gin.Context) { return } - channel, pass := fetchChannel(c, speechRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[speechRequest.Model] != "" { - speechRequest.Model = modelMap[speechRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech) - if pass { + provider, modelName, fail := getProvider(c, speechRequest.Model) + if fail { return } + speechRequest.Model = modelName + speechProvider, ok := provider.(providersBase.SpeechInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -51,39 +34,29 @@ func RelaySpeech(c *gin.Context) { // 获取Input Tokens promptTokens := len(speechRequest.Input) - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens) + response, errWithCode := speechProvider.CreateSpeech(&speechRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseMultipart(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-transcriptions.go b/controller/relay-transcriptions.go index b08174a1..a6005963 100644 --- a/controller/relay-transcriptions.go +++ b/controller/relay-transcriptions.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" 
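Worth calling out in these hunks: providers no longer return a *types.Usage. The controller allocates one, seeds PromptTokens, and hands the pointer over via SetUsage; the provider then writes through it (the ali handler later in this diff does *p.Usage = *openaiResponse.Usage). A tiny runnable illustration of that shared-pointer contract, with stand-in types rather than the real ones:

```go
package main

import "fmt"

type Usage struct{ PromptTokens, CompletionTokens, TotalTokens int }

// BaseProvider-style holder: controller and provider share one pointer.
type BaseProvider struct{ Usage *Usage }

func (p *BaseProvider) SetUsage(u *Usage) { p.Usage = u }

func (p *BaseProvider) handleResponse(inputTokens, outputTokens int) {
	// Overwrite through the shared pointer, as the ali handler does
	// with *p.Usage = *openaiResponse.Usage.
	*p.Usage = Usage{
		PromptTokens:     inputTokens,
		CompletionTokens: outputTokens,
		TotalTokens:      inputTokens + outputTokens,
	}
}

func main() {
	usage := &Usage{PromptTokens: 14} // the controller's pre-count
	p := &BaseProvider{}
	p.SetUsage(usage)
	p.handleResponse(14, 19)
	fmt.Println(usage.TotalTokens) // 33; the controller sees the update
}
```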
providersBase "one-api/providers/base" "one-api/types" @@ -20,28 +18,13 @@ func RelayTranscriptions(c *gin.Context) { return } - channel, pass := fetchChannel(c, audioRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[audioRequest.Model] != "" { - audioRequest.Model = modelMap[audioRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription) - if pass { + provider, modelName, fail := getProvider(c, audioRequest.Model) + if fail { return } + audioRequest.Model = modelName + transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -51,39 +34,29 @@ func RelayTranscriptions(c *gin.Context) { // 获取Input Tokens promptTokens := 0 - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens) + response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseCustom(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-translations.go b/controller/relay-translations.go index fcdada36..c13935eb 100644 --- a/controller/relay-translations.go +++ b/controller/relay-translations.go @@ -1,10 +1,8 @@ package controller import ( - "context" "net/http" "one-api/common" - "one-api/model" providersBase "one-api/providers/base" "one-api/types" @@ -20,28 +18,13 @@ func RelayTranslations(c *gin.Context) { return } - channel, pass := fetchChannel(c, audioRequest.Model) - if pass { - return - } - - // 解析模型映射 - var isModelMapped bool - modelMap, err := parseModelMapping(channel.GetModelMapping()) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if modelMap != nil && modelMap[audioRequest.Model] != "" { - audioRequest.Model = modelMap[audioRequest.Model] - isModelMapped = true - } - // 获取供应商 - provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation) - if pass { + provider, modelName, fail := getProvider(c, audioRequest.Model) + if fail { return } + audioRequest.Model 
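Transcriptions and translations relay the upstream reply verbatim through responseCustom and a types.AudioResponseWrapper. Judging from the response.Headers and response.Body accesses in relay-utils.go further down, the wrapper is just headers plus a raw body; the sketch below assumes exactly that (field names and types are inferred, not confirmed) and mirrors responseCustom with plain net/http instead of gin:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// Assumed shape of types.AudioResponseWrapper, inferred from how
// responseCustom uses response.Headers and response.Body below.
type AudioResponseWrapper struct {
	Headers map[string]string
	Body    []byte
}

// writeWrapped mirrors responseCustom with plain net/http instead of gin.
func writeWrapped(w http.ResponseWriter, resp *AudioResponseWrapper) error {
	for k, v := range resp.Headers {
		w.Header().Set(k, v)
	}
	w.WriteHeader(http.StatusOK)
	_, err := w.Write(resp.Body)
	return err
}

func main() {
	rec := httptest.NewRecorder()
	_ = writeWrapped(rec, &AudioResponseWrapper{
		Headers: map[string]string{"Content-Type": "text/plain"},
		Body:    []byte("hello transcription"),
	})
	fmt.Println(rec.Code, rec.Body.String())
}
```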
= modelName + translationProvider, ok := provider.(providersBase.TranslationInterface) if !ok { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") @@ -51,39 +34,29 @@ func RelayTranslations(c *gin.Context) { // 获取Input Tokens promptTokens := 0 - var quotaInfo *QuotaInfo - var errWithCode *types.OpenAIErrorWithStatusCode - var usage *types.Usage - quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens) + usage := &types.Usage{ + PromptTokens: promptTokens, + } + provider.SetUsage(usage) + + quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens) if errWithCode != nil { errorHelper(c, errWithCode) return } - usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens) + response, errWithCode := translationProvider.CreateTranslation(&audioRequest) + if errWithCode != nil { + errorHelper(c, errWithCode) + return + } + errWithCode = responseCustom(c, response) // 如果报错,则退还配额 if errWithCode != nil { - tokenId := c.GetInt("token_id") - if quotaInfo.HandelStatus { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - errorHelper(c, errWithCode) + quotaInfo.undo(c, errWithCode) return - } else { - tokenName := c.GetString("token_name") - // 如果没有报错,则消费配额 - go func(ctx context.Context) { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }(c.Request.Context()) } + + quotaInfo.consume(c, usage) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 21776ce1..ec2e6ede 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,25 +1,47 @@ package controller import ( - "context" "encoding/json" "errors" "fmt" - "math" + "io" "net/http" "one-api/common" + "one-api/common/requester" "one-api/model" "one-api/providers" providersBase "one-api/providers/base" "one-api/types" "reflect" "strconv" - "time" "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" ) +func getProvider(c *gin.Context, modelName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) { + channel, fail := fetchChannel(c, modelName) + if fail { + return + } + + provider = providers.GetProvider(channel, c) + if provider == nil { + common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found") + fail = true + return + } + + newModelName, err := provider.ModelMappingHandler(modelName) + if err != nil { + common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) + fail = true + return + } + + return +} + func GetValidFieldName(err error, obj interface{}) string { getObj := reflect.TypeOf(obj) if errs, ok := err.(validator.ValidationErrors); ok { @@ -32,17 +54,17 @@ func GetValidFieldName(err error, obj interface{}) string { return err.Error() } -func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) { +func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) { channelId, ok := c.Get("channelId") if ok { - channel, pass = fetchChannelById(c, channelId.(int)) - if pass { + channel, fail = fetchChannelById(c, channelId.(int)) + if fail { return } } - channel, pass = fetchChannelByModel(c, modelName) - if pass { + channel, fail = fetchChannelByModel(c, modelName) + if fail { return } 
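getProvider's ModelMappingHandler call replaces the deleted parseModelMapping, and its body is not shown in this diff. Below is a plausible standalone sketch of the same semantics taken from the removed helper: the channel stores a JSON object mapping requested model names to upstream ones, with "" or "{}" meaning no mapping.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Plausible sketch of what ModelMappingHandler does, reusing the
// semantics of the deleted parseModelMapping; not the verbatim body.
func mapModel(mapping, modelName string) (string, error) {
	if mapping == "" || mapping == "{}" {
		return modelName, nil // mapping disabled
	}
	modelMap := make(map[string]string)
	if err := json.Unmarshal([]byte(mapping), &modelMap); err != nil {
		return "", err
	}
	if mapped := modelMap[modelName]; mapped != "" {
		return mapped, nil
	}
	return modelName, nil // no entry for this model
}

func main() {
	out, _ := mapModel(`{"gpt-4":"gpt-4-0613"}`, "gpt-4")
	fmt.Println(out) // gpt-4-0613
}
```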
@@ -86,21 +108,6 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool return channel, false } -func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) { - provider := providers.GetProvider(channel, c) - if provider == nil { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found") - return nil, true - } - - if !provider.SupportAPI(relayMode) { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API") - return nil, true - } - - return provider, false -} - func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { if !common.AutomaticDisableChannelEnabled { return false @@ -130,138 +137,81 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { return true } -func parseModelMapping(modelMapping string) (map[string]string, error) { - if modelMapping == "" || modelMapping == "{}" { - return nil, nil - } - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) +func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode { + // 将data转换为 JSON + responseBody, err := json.Marshal(data) if err != nil { - return nil, err - } - return modelMap, nil -} - -type QuotaInfo struct { - modelName string - promptTokens int - preConsumedTokens int - modelRatio float64 - groupRatio float64 - ratio float64 - preConsumedQuota int - userId int - channelId int - tokenId int - HandelStatus bool -} - -func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) { - quotaInfo := &QuotaInfo{ - modelName: modelName, - promptTokens: promptTokens, - userId: c.GetInt("id"), - channelId: c.GetInt("channel_id"), - tokenId: c.GetInt("token_id"), - HandelStatus: false, - } - quotaInfo.initQuotaInfo(c.GetString("group")) - - errWithCode := quotaInfo.preQuotaConsumption() - if errWithCode != nil { - return nil, errWithCode + return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) } - return quotaInfo, nil -} - -func (q *QuotaInfo) initQuotaInfo(groupName string) { - modelRatio := common.GetModelRatio(q.modelName) - groupRatio := common.GetGroupRatio(groupName) - preConsumedTokens := common.PreConsumedQuota - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio) - - q.preConsumedTokens = preConsumedTokens - q.modelRatio = modelRatio - q.groupRatio = groupRatio - q.ratio = ratio - q.preConsumedQuota = preConsumedQuota - - return -} - -func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { - userQuota, err := model.CacheGetUserQuota(q.userId) + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(http.StatusOK) + _, err = c.Writer.Write(responseBody) if err != nil { - return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - - if userQuota < q.preConsumedQuota { - return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - - err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota) - if err != nil { - return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - - if userQuota > 100*q.preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - q.preConsumedQuota = 0 - // 
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) - } - - if q.preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota) - if err != nil { - return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } - q.HandelStatus = true + return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) } return nil } -func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { - quota := 0 - completionRatio := common.GetCompletionRatio(q.modelName) - promptTokens := usage.PromptTokens - completionTokens := usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio)) - if q.ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - q.preConsumedQuota - err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta) - if err != nil { - return errors.New("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(q.userId) - if err != nil { - return errors.New("error consuming token remain quota: " + err.Error()) - } - if quota != 0 { - requestTime := 0 - requestStartTimeValue := ctx.Value("requestStartTime") - if requestStartTimeValue != nil { - requestStartTime, ok := requestStartTimeValue.(time.Time) - if ok { - requestTime = int(time.Since(requestStartTime).Milliseconds()) +func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode { + requester.SetEventStreamHeaders(c) + defer stream.Close() + + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + if response != nil && len(*response) > 0 { + for _, streamResponse := range *response { + responseBody, _ := json.Marshal(streamResponse) + c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)}) + } } + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + break } - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio) - model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime) - model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) - model.UpdateChannelUsedQuota(q.channelId, quota) + if err != nil { + c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()}) + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + break + } + + for _, streamResponse := range *response { + responseBody, _ := json.Marshal(streamResponse) + c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)}) + } + } + + return nil +} + +func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode { + defer resp.Body.Close() + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + + c.Writer.WriteHeader(resp.StatusCode) + + _, err := io.Copy(c.Writer, resp.Body) + if err != nil { + return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) + } + + return nil +} + +func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types.OpenAIErrorWithStatusCode { + for k, v := range response.Headers { + c.Writer.Header().Set(k, 
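responseStreamClient above leans on two contract details of requester.StreamReaderInterface[T]: Recv returns a batch of decoded chunks rather than one chunk, and io.EOF can arrive together with a final batch that still has to be flushed before "data: [DONE]". A runnable stand-in that exercises exactly that contract; streamReader and drain are illustrative, not the real types:

```go
package main

import (
	"errors"
	"fmt"
	"io"
)

// Minimal stand-in for requester.StreamReaderInterface[T].
type streamReader[T any] struct {
	batches [][]T
	i       int
}

func (s *streamReader[T]) Recv() (*[]T, error) {
	if s.i >= len(s.batches) {
		return nil, io.EOF
	}
	b := s.batches[s.i]
	s.i++
	if s.i == len(s.batches) {
		// mimic the contract: the final batch may arrive with io.EOF
		return &b, io.EOF
	}
	return &b, nil
}

// drain mirrors responseStreamClient's loop shape.
func drain[T any](s *streamReader[T]) {
	for {
		batch, err := s.Recv()
		if errors.Is(err, io.EOF) {
			if batch != nil {
				for _, chunk := range *batch {
					fmt.Printf("data: %v\n", chunk) // real code json.Marshals each chunk
				}
			}
			fmt.Println("data: [DONE]")
			return
		}
		if err != nil {
			fmt.Println("data:", err)
			fmt.Println("data: [DONE]")
			return
		}
		for _, chunk := range *batch {
			fmt.Printf("data: %v\n", chunk)
		}
	}
}

func main() {
	drain(&streamReader[string]{
		batches: [][]string{{`{"delta":"Hel"}`}, {`{"delta":"lo"}`, `{"delta":"!"}`}},
	})
}
```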
v) + } + c.Writer.WriteHeader(http.StatusOK) + + _, err := c.Writer.Write(response.Body) + if err != nil { + return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) } return nil diff --git a/providers/aigc2d/balance.go b/providers/aigc2d/balance.go index 070950c3..297bd703 100644 --- a/providers/aigc2d/balance.go +++ b/providers/aigc2d/balance.go @@ -2,29 +2,26 @@ package aigc2d import ( "errors" - "one-api/common" - "one-api/model" "one-api/providers/base" ) -func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) { +func (p *Aigc2dProvider) Balance() (float64, error) { fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") headers := p.GetRequestHeaders() - client := common.NewClient() - req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers)) if err != nil { return 0, err } // 发送请求 var response base.BalanceResponse - _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) + _, errWithCode := p.Requester.SendRequest(req, &response, false) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } - channel.UpdateBalance(response.TotalAvailable) + p.Channel.UpdateBalance(response.TotalAvailable) return response.TotalAvailable, nil } diff --git a/providers/aigc2d/base.go b/providers/aigc2d/base.go index b4654b3b..1499e987 100644 --- a/providers/aigc2d/base.go +++ b/providers/aigc2d/base.go @@ -1,20 +1,19 @@ package aigc2d import ( + "one-api/model" "one-api/providers/base" "one-api/providers/openai" - - "github.com/gin-gonic/gin" ) type Aigc2dProviderFactory struct{} -func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { - return &Aigc2dProvider{ - OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"), - } -} - type Aigc2dProvider struct { *openai.OpenAIProvider } + +func (f Aigc2dProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &Aigc2dProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.aigc2d.com"), + } +} diff --git a/providers/aiproxy/balance.go b/providers/aiproxy/balance.go index 66f96d2b..830520cc 100644 --- a/providers/aiproxy/balance.go +++ b/providers/aiproxy/balance.go @@ -3,24 +3,21 @@ package aiproxy import ( "errors" "fmt" - "one-api/common" - "one-api/model" ) -func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) { +func (p *AIProxyProvider) Balance() (float64, error) { fullRequestURL := "https://aiproxy.io/api/report/getUserOverview" headers := make(map[string]string) - headers["Api-Key"] = channel.Key + headers["Api-Key"] = p.Channel.Key - client := common.NewClient() - req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers)) if err != nil { return 0, err } // 发送请求 var response AIProxyUserOverviewResponse - _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) + _, errWithCode := p.Requester.SendRequest(req, &response, false) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } @@ -29,7 +26,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) { return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) } - channel.UpdateBalance(response.Data.TotalPoints) + 
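The balance methods in these hunks now go through the provider's own Requester, which binds the channel's proxy once at construction instead of threading proxyAddr through every call site. A simplified stand-in for requester.NewHTTPRequester / NewRequest / SendRequest; the real versions take variadic options and return an OpenAI-style error wrapper, which this sketch compresses to bare essentials:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
)

// Stand-in requester: the proxy is bound once at construction.
type HTTPRequester struct{ client *http.Client }

func NewHTTPRequester(proxyAddr string) *HTTPRequester {
	transport := &http.Transport{}
	if proxyAddr != "" {
		if proxyURL, err := url.Parse(proxyAddr); err == nil {
			transport.Proxy = http.ProxyURL(proxyURL)
		}
	}
	return &HTTPRequester{client: &http.Client{Transport: transport}}
}

func (r *HTTPRequester) NewRequest(method, rawURL string, headers map[string]string) (*http.Request, error) {
	req, err := http.NewRequest(method, rawURL, nil)
	if err != nil {
		return nil, err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	return req, nil
}

// SendRequest performs the call and decodes JSON into response.
func (r *HTTPRequester) SendRequest(req *http.Request, response any) error {
	resp, err := r.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	return json.NewDecoder(resp.Body).Decode(response)
}

func main() {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, `{"total_available": 12.5}`)
	}))
	defer ts.Close()

	r := NewHTTPRequester("") // no proxy
	req, _ := r.NewRequest("GET", ts.URL, map[string]string{"Api-Key": "sk-test"})
	var out struct {
		TotalAvailable float64 `json:"total_available"`
	}
	_ = r.SendRequest(req, &out)
	fmt.Println(out.TotalAvailable)
}
```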
p.Channel.UpdateBalance(response.Data.TotalPoints) return response.Data.TotalPoints, nil } diff --git a/providers/aiproxy/base.go b/providers/aiproxy/base.go index 8b0d0ff3..57fef229 100644 --- a/providers/aiproxy/base.go +++ b/providers/aiproxy/base.go @@ -1,20 +1,19 @@ package aiproxy import ( + "one-api/model" "one-api/providers/base" "one-api/providers/openai" - - "github.com/gin-gonic/gin" ) type AIProxyProviderFactory struct{} -func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface { - return &AIProxyProvider{ - OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"), - } -} - type AIProxyProvider struct { *openai.OpenAIProvider } + +func (f AIProxyProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &AIProxyProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.aiproxy.io"), + } +} diff --git a/providers/ali/ali_test.go b/providers/ali/ali_test.go new file mode 100644 index 00000000..c5cbc26f --- /dev/null +++ b/providers/ali/ali_test.go @@ -0,0 +1,24 @@ +package ali_test + +import ( + "net/http" + "one-api/common" + "one-api/common/test" + "one-api/model" +) + +func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.TestServer(func(w http.ResponseWriter, r *http.Request) bool { + return test.OpenAICheck(w, r) + }) + ts.Start() + teardown = ts.Close + + baseUrl = ts.URL + return +} + +func getAliChannel(baseUrl string) model.Channel { + return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "") +} diff --git a/providers/ali/base.go b/providers/ali/base.go index 4db8c040..bbbdd11a 100644 --- a/providers/ali/base.go +++ b/providers/ali/base.go @@ -1,32 +1,66 @@ package ali import ( + "encoding/json" "fmt" + "net/http" "strings" + "one-api/common/requester" + "one-api/model" "one-api/providers/base" - - "github.com/gin-gonic/gin" + "one-api/types" ) // 定义供应商工厂 type AliProviderFactory struct{} +type AliProvider struct { + base.BaseProvider +} + // 创建 AliProvider // https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation -func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface { +func (f AliProviderFactory) Create(channel *model.Channel) base.ProviderInterface { return &AliProvider{ BaseProvider: base.BaseProvider{ - BaseURL: "https://dashscope.aliyuncs.com", - ChatCompletions: "/api/v1/services/aigc/text-generation/generation", - Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding", - Context: c, + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle), }, } } -type AliProvider struct { - base.BaseProvider +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://dashscope.aliyuncs.com", + ChatCompletions: "/api/v1/services/aigc/text-generation/generation", + Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding", + } +} + +// 请求错误处理 +func requestErrorHandle(resp *http.Response) *types.OpenAIError { + aliError := &AliError{} + err := json.NewDecoder(resp.Body).Decode(aliError) + if err != nil { + return nil + } + + return errorHandle(aliError) } + +// 错误处理 +func errorHandle(aliError *AliError) *types.OpenAIError { + if aliError.Code == "" { + return nil + } + return &types.OpenAIError{ + Message: aliError.Message, + Type: aliError.Code, + Param: aliError.RequestId, + Code: aliError.Code, + } } func (p *AliProvider) GetFullRequestURL(requestURL 
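(The requestErrorHandle hunk above originally decoded into a nil *AliError pointer, which makes json.Decode fail and silently swallows every upstream error; it is corrected here to decode into an allocated struct.) The ali provider splits error handling in two layers: requestErrorHandle decodes the wire body, and errorHandle maps the typed AliError onto an OpenAIError, with an empty Code meaning "not an error payload". A self-contained check of that mapping, using local copies of the structs and the error fixture from the tests further down:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the shapes involved, for illustration only.
type AliError struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
}

type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"`
}

// Same shape as the provider's errorHandle: an empty Code means the body
// was not an error payload, so the caller treats the response as success.
func errorHandle(e *AliError) *OpenAIError {
	if e.Code == "" {
		return nil
	}
	return &OpenAIError{Message: e.Message, Type: e.Code, Param: e.RequestId, Code: e.Code}
}

func main() {
	body := `{"code":"InvalidParameter","message":"Role must be user or assistant","request_id":"4883ee8d"}`
	aliErr := &AliError{} // decode into a non-nil value, never a nil pointer
	_ = json.Unmarshal([]byte(body), aliErr)
	fmt.Printf("%+v\n", errorHandle(aliErr))
}
```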
string, modelName string) string { diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 485f6172..b38b3c0f 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -1,51 +1,116 @@ package ali import ( - "bufio" "encoding/json" - "io" "net/http" "one-api/common" + "one-api/common/requester" "one-api/types" "strings" ) -// 阿里云响应处理 -func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if aliResponse.Code != "" { - errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: types.OpenAIError{ - Message: aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - } - - return - } - - OpenAIResponse = types.ChatCompletionResponse{ - ID: aliResponse.RequestId, - Object: "chat.completion", - Created: common.GetTimestamp(), - Model: aliResponse.Model, - Choices: aliResponse.Output.ToChatCompletionChoices(), - Usage: &types.Usage{ - PromptTokens: aliResponse.Usage.InputTokens, - CompletionTokens: aliResponse.Usage.OutputTokens, - TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, - }, - } - - return +type aliStreamHandler struct { + Usage *types.Usage + Request *types.ChatCompletionRequest + lastStreamResponse string } const AliEnableSearchModelSuffix = "-internet" -// 获取聊天请求体 -func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { +func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getAliChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + aliResponse := &AliChatResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, aliResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return p.convertToChatOpenai(aliResponse, request) +} + +func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getAliChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := &aliStreamHandler{ + Usage: p.Usage, + Request: request, + } + + return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) +} + +func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + + // 获取请求头 + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + headers["X-DashScope-SSE"] = "enable" + } + + aliRequest := convertFromChatOpenai(request) + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +// 转换为OpenAI聊天请求体 +func (p *AliProvider) 
convertToChatOpenai(response *AliChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.AliError) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, + } + return + } + + openaiResponse = &types.ChatCompletionResponse{ + ID: response.RequestId, + Object: "chat.completion", + Created: common.GetTimestamp(), + Model: request.Model, + Choices: response.Output.ToChatCompletionChoices(), + Usage: &types.Usage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + } + + *p.Usage = *openaiResponse.Usage + + return +} + +// 阿里云聊天请求体 +func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] @@ -96,163 +161,68 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) * } } -// 聊天 -func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody := p.getChatRequestBody(request) - - fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) - headers := p.GetRequestHeaders() - if request.Stream { - headers["Accept"] = "text/event-stream" - headers["X-DashScope-SSE"] = "enable" +// 转换为OpenAI聊天流式请求体 +func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { + // 如果rawLine 前缀不为data:,则直接返回 + if !strings.HasPrefix(string(*rawLine), "data:") { + *rawLine = nil + return nil } - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + // 去除前缀 + *rawLine = (*rawLine)[5:] + + var aliResponse AliChatResponse + err := json.Unmarshal(*rawLine, &aliResponse) if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return common.ErrorToOpenAIError(err) } - if request.Stream { - usage, errWithCode = p.sendStreamRequest(req, request.Model) - if errWithCode != nil { - return - } - - if usage == nil { - usage = &types.Usage{ - PromptTokens: 0, - CompletionTokens: 0, - TotalTokens: 0, - } - } - - } else { - aliResponse := &AliChatResponse{ - Model: request.Model, - } - errWithCode = p.SendRequest(req, aliResponse, false) - if errWithCode != nil { - return - } - - usage = &types.Usage{ - PromptTokens: aliResponse.Usage.InputTokens, - CompletionTokens: aliResponse.Usage.OutputTokens, - TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, - } + error := errorHandle(&aliResponse.AliError) + if error != nil { + return error } - return + + return h.convertToOpenaiStream(&aliResponse, response) + } -// 阿里云响应转OpenAI响应 -func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse { - // chatChoice := aliResponse.Output.ToChatCompletionChoices() - // jsonBody, _ := json.MarshalIndent(chatChoice, "", " ") - // fmt.Println("requestBody:", string(jsonBody)) +func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error { + 
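handlerStream above is the callback handed to requester.RequestStream: it sees every raw SSE line, blanks *rawLine to skip non-data frames, decodes data payloads, and appends OpenAI-style chunks to *response. A stripped-down runnable version of that contract; string payloads stand in for the typed chunks, and the isFinished flag is omitted:

```go
package main

import (
	"fmt"
	"strings"
)

// Sketch of the RequestStream callback contract: blank *rawLine to skip
// a frame, or append decoded chunks to *response.
func handleLine(rawLine *[]byte, response *[]string) error {
	line := string(*rawLine)
	if !strings.HasPrefix(line, "data:") {
		*rawLine = nil // id:, event:, and comment frames are skipped
		return nil
	}
	payload := line[len("data:"):]
	// The real handler json.Unmarshals payload into AliChatResponse here.
	*response = append(*response, payload)
	return nil
}

func main() {
	frames := []string{"id:1", "event:result", ":HTTP_STATUS/200", `data:{"output":{}}`}
	var chunks []string
	for _, f := range frames {
		raw := []byte(f)
		_ = handleLine(&raw, &chunks)
	}
	fmt.Println(chunks) // [{"output":{}}]
}
```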
content := aliResponse.Output.Choices[0].Message.StringContent() + var choice types.ChatCompletionStreamChoice choice.Index = aliResponse.Output.Choices[0].Index - choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent() - // fmt.Println("choice.Delta.Content:", chatChoice[0].Message) - if aliResponse.Output.Choices[0].FinishReason != "null" { - finishReason := aliResponse.Output.Choices[0].FinishReason - choice.FinishReason = &finishReason + choice.Delta.Content = strings.TrimPrefix(content, h.lastStreamResponse) + if aliResponse.Output.Choices[0].FinishReason != "" { + if aliResponse.Output.Choices[0].FinishReason != "null" { + finishReason := aliResponse.Output.Choices[0].FinishReason + choice.FinishReason = &finishReason + } } - response := types.ChatCompletionStreamResponse{ + if aliResponse.Output.FinishReason != "" { + if aliResponse.Output.FinishReason != "null" { + finishReason := aliResponse.Output.FinishReason + choice.FinishReason = &finishReason + } + } + + h.lastStreamResponse = content + streamResponse := types.ChatCompletionStreamResponse{ ID: aliResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), - Model: aliResponse.Model, + Model: h.Request.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } - return &response -} -// 发送流请求 -func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - defer req.Body.Close() - - usage = &types.Usage{} - // 发送请求 - client := common.GetHttpClient(p.Channel.Proxy) - resp, err := client.Do(req) - if err != nil { - return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) - } - common.PutHttpClient(client) - - if common.IsFailureStatusCode(resp) { - return nil, common.HandleErrorResp(resp) + if aliResponse.Usage.OutputTokens != 0 { + h.Usage.PromptTokens = aliResponse.Usage.InputTokens + h.Usage.CompletionTokens = aliResponse.Usage.OutputTokens + h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } - defer resp.Body.Close() + *response = append(*response, streamResponse) - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() - common.SetEventStreamHeaders(p.Context) - lastResponseText := "" - index := 0 - p.Context.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var aliResponse AliChatResponse - err := json.Unmarshal([]byte(data), &aliResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if aliResponse.Usage.OutputTokens != 0 { - usage.PromptTokens = aliResponse.Usage.InputTokens - usage.CompletionTokens = aliResponse.Usage.OutputTokens - usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens - } - aliResponse.Model = model - aliResponse.Output.Choices[0].Index = index - index++ - response := p.streamResponseAli2OpenAI(&aliResponse) - 
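The TrimPrefix against lastStreamResponse above deserves a second look: DashScope streams the cumulative assistant text on every event, while OpenAI chunks carry only the increment, so the handler diffs each snapshot against the previous one. A worked example using the fixture strings from the tests below:

```go
package main

import (
	"fmt"
	"strings"
)

// DashScope sends cumulative snapshots; OpenAI deltas are computed by
// trimming the previous snapshot off the front of the new one.
func main() {
	last := ""
	for _, snapshot := range []string{"你好!", "你好!有什么我可以帮助你的吗?"} {
		delta := strings.TrimPrefix(snapshot, last)
		fmt.Printf("delta: %q\n", delta)
		last = snapshot
	}
	// Output:
	// delta: "你好!"
	// delta: "有什么我可以帮助你的吗?"
}
```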
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) - lastResponseText = aliResponse.Output.Choices[0].Message.StringContent() - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - - return + return nil } diff --git a/providers/ali/chat_test.go b/providers/ali/chat_test.go new file mode 100644 index 00000000..c2a4f28a --- /dev/null +++ b/providers/ali/chat_test.go @@ -0,0 +1,330 @@ +package ali_test + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common/test" + _ "one-api/common/test/init" + "one-api/providers" + providers_base "one-api/providers/base" + "one-api/types" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func getChatProvider(url string, context *gin.Context) providers_base.ChatInterface { + channel := getAliChannel(url) + provider := providers.GetProvider(&channel, context) + chatProvider, _ := provider.(providers_base.ChatInterface) + + return chatProvider +} + +func TestChatCompletions(t *testing.T) { + url, server, teardown := setupAliTestServer() + context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil) + defer teardown() + server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionEndpoint) + + chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false") + + chatProvider := getChatProvider(url, context) + usage := &types.Usage{} + chatProvider.SetUsage(usage) + response, errWithCode := chatProvider.CreateChatCompletion(chatRequest) + + assert.Nil(t, errWithCode) + assert.IsType(t, &types.Usage{}, usage) + assert.Equal(t, 33, usage.TotalTokens) + assert.Equal(t, 14, usage.PromptTokens) + assert.Equal(t, 19, usage.CompletionTokens) + + // 转换成JSON字符串 + responseBody, err := json.Marshal(response) + if err != nil { + fmt.Println(err) + assert.Fail(t, "json marshal error") + } + fmt.Println(string(responseBody)) + + test.CheckChat(t, response, "qwen-turbo", usage) +} + +func TestChatCompletionsError(t *testing.T) { + url, server, teardown := setupAliTestServer() + context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil) + defer teardown() + server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionErrorEndpoint) + + chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false") + + chatProvider := getChatProvider(url, context) + _, err := chatProvider.CreateChatCompletion(chatRequest) + usage := chatProvider.GetUsage() + + assert.NotNil(t, err) + assert.Nil(t, usage) + assert.Equal(t, "InvalidParameter", err.Code) +} + +// func TestChatCompletionsStream(t *testing.T) { +// url, server, teardown := setupAliTestServer() +// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil) +// defer teardown() +// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamEndpoint) + +// channel := getAliChannel(url) +// provider := providers.GetProvider(&channel, context) +// chatProvider, _ := provider.(providers_base.ChatInterface) +// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true") + 
+// usage := &types.Usage{} +// chatProvider.SetUsage(usage) +// response, errWithCode := chatProvider.CreateChatCompletionStream(chatRequest) +// assert.Nil(t, errWithCode) + +// assert.IsType(t, &types.Usage{}, usage) +// assert.Equal(t, 16, usage.TotalTokens) +// assert.Equal(t, 8, usage.PromptTokens) +// assert.Equal(t, 8, usage.CompletionTokens) + +// streamResponseCheck(t, w.Body.String()) +// } + +// func TestChatCompletionsStreamError(t *testing.T) { +// url, server, teardown := setupAliTestServer() +// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil) +// defer teardown() +// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamErrorEndpoint) + +// channel := getAliChannel(url) +// provider := providers.GetProvider(&channel, context) +// chatProvider, _ := provider.(providers_base.ChatInterface) +// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true") + +// usage, err := chatProvider.ChatAction(chatRequest, 0) + +// // 打印 context 写入的内容 +// fmt.Println(w.Body.String()) + +// assert.NotNil(t, err) +// assert.Nil(t, usage) +// } + +// func TestChatImageCompletions(t *testing.T) { +// url, server, teardown := setupAliTestServer() +// context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil) +// defer teardown() +// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionEndpoint) + +// channel := getAliChannel(url) +// provider := providers.GetProvider(&channel, context) +// chatProvider, _ := provider.(providers_base.ChatInterface) +// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "false") + +// usage, err := chatProvider.ChatAction(chatRequest, 0) + +// assert.Nil(t, err) +// assert.IsType(t, &types.Usage{}, usage) +// assert.Equal(t, 1306, usage.TotalTokens) +// assert.Equal(t, 1279, usage.PromptTokens) +// assert.Equal(t, 27, usage.CompletionTokens) +// } + +// func TestChatImageCompletionsStream(t *testing.T) { +// url, server, teardown := setupAliTestServer() +// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil) +// defer teardown() +// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionStreamEndpoint) + +// channel := getAliChannel(url) +// provider := providers.GetProvider(&channel, context) +// chatProvider, _ := provider.(providers_base.ChatInterface) +// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "true") + +// usage, err := chatProvider.ChatAction(chatRequest, 0) + +// fmt.Println(w.Body.String()) + +// assert.Nil(t, err) +// assert.IsType(t, &types.Usage{}, usage) +// assert.Equal(t, 1342, usage.TotalTokens) +// assert.Equal(t, 1279, usage.PromptTokens) +// assert.Equal(t, 63, usage.CompletionTokens) +// streamResponseCheck(t, w.Body.String()) +// } + +func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + response := `{"output":{"choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"您好!我可以帮您查询最近的公园,请问您现在所在的位置是哪里呢?"}}]},"usage":{"total_tokens":33,"output_tokens":19,"input_tokens":14},"request_id":"2479f818-9717-9b0b-9769-0d26e873a3f6"}` + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, response) +} + +func 
handleChatCompletionErrorEndpoint(w http.ResponseWriter, r *http.Request) { + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + response := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"4883ee8d-f095-94ff-a94a-5ce0a94bc81f"}` + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, response) +} + +func handleChatCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) { + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + // 检测头部是否有X-DashScope-SSE: enable + if r.Header.Get("X-DashScope-SSE") != "enable" { + http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest) + } + + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("id:1\n")...) + dataBytes = append(dataBytes, []byte("event:result\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...) + //nolint:lll + data := `{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":10,"input_tokens":8,"output_tokens":2},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("id:2\n")...) + dataBytes = append(dataBytes, []byte("event:result\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...) + //nolint:lll + data = `{"output":{"choices":[{"message":{"content":"有什么我可以帮助你的吗?","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("id:3\n")...) + dataBytes = append(dataBytes, []byte("event:result\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...) + //nolint:lll + data = `{"output":{"choices":[{"message":{"content":"","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) + + _, err := w.Write(dataBytes) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func handleChatCompletionStreamErrorEndpoint(w http.ResponseWriter, r *http.Request) { + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + // 检测头部是否有X-DashScope-SSE: enable + if r.Header.Get("X-DashScope-SSE") != "enable" { + http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest) + } + + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("id:1\n")...) + dataBytes = append(dataBytes, []byte("event:error\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/400\n")...) + //nolint:lll + data := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"6b932ba9-41bd-9ad3-b430-24bc1e125880"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) 
+ + _, err := w.Write(dataBytes) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func handleChatImageCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + response := `{"output":{"finish_reason":"stop","choices":[{"message":{"role":"assistant","content":[{"text":"这张照片展示的是一个海滩的场景,但是并没有明确指出具体的位置。可以看到海浪和日落背景下的沙滩景色。"}]}}]},"usage":{"output_tokens":27,"input_tokens":1279,"image_tokens":1247},"request_id":"a360d53b-b993-927f-9a68-bef6b2b2042e"}` + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, response) +} + +func handleChatImageCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) { + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + // 检测头部是否有X-DashScope-SSE: enable + if r.Header.Get("X-DashScope-SSE") != "enable" { + http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest) + } + + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("id:1\n")...) + dataBytes = append(dataBytes, []byte("event:result\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...) + //nolint:lll + data := `{"output":{"choices":[{"message":{"content":[{"text":"这张"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":1,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("id:2\n")...) + dataBytes = append(dataBytes, []byte("event:result\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...) + //nolint:lll + data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":2,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("id:3\n")...) + dataBytes = append(dataBytes, []byte("event:result\n")...) + dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...) + //nolint:lll + data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片展示的是一个海滩的场景,具体来说是在日落时分。由于没有明显的地标或建筑物等特征可以辨认出具体的地点信息,所以无法确定这是哪个地方的海滩。但是根据图像中的元素和环境特点,我们可以推测这可能是一个位于沿海地区的沙滩海岸线。"}],"role":"assistant"}}],"finish_reason":"stop"},"usage":{"input_tokens":1279,"output_tokens":63,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}` + dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...) 
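These mock endpoints build each DashScope event by hand; the wire format is always an id line, an event line, a :HTTP_STATUS comment, one data line, and a blank separator. A small helper in the same spirit, possibly handy if more fixtures are added (sseFrame is illustrative, not part of the patch):

```go
package main

import (
	"fmt"
	"strings"
)

// sseFrame assembles one DashScope-style SSE event, mirroring the
// byte-by-byte construction in the test handlers above.
func sseFrame(id int, event, status, payload string) string {
	var b strings.Builder
	fmt.Fprintf(&b, "id:%d\n", id)
	fmt.Fprintf(&b, "event:%s\n", event)
	fmt.Fprintf(&b, ":HTTP_STATUS/%s\n", status)
	fmt.Fprintf(&b, "data:%s\n\n", payload)
	return b.String()
}

func main() {
	fmt.Print(sseFrame(1, "result", "200", `{"output":{}}`))
}
```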
+ + _, err := w.Write(dataBytes) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func streamResponseCheck(t *testing.T, response string) { + // 以换行符分割response + lines := strings.Split(response, "\n\n") + // 如果最后一行为空,则删除最后一行 + if lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + + // 循环遍历每一行 + for _, line := range lines { + if line == "" { + continue + } + // assert判断 是否以data: 开头 + assert.True(t, strings.HasPrefix(line, "data: ")) + } + + // 检测最后一行是否以data: [DONE] 结尾 + assert.True(t, strings.HasSuffix(lines[len(lines)-1], "data: [DONE]")) + // 检测倒数第二行是否存在 `"finish_reason":"stop"` + assert.True(t, strings.Contains(lines[len(lines)-2], `"finish_reason":"stop"`)) +} diff --git a/providers/ali/embeddings.go b/providers/ali/embeddings.go index 1320eb74..fdbe6ca6 100644 --- a/providers/ali/embeddings.go +++ b/providers/ali/embeddings.go @@ -6,40 +6,37 @@ import ( "one-api/types" ) -// 嵌入请求处理 -func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) { - if aliResponse.Code != "" { - return nil, &types.OpenAIErrorWithStatusCode{ - OpenAIError: types.OpenAIError{ - Message: aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - } +func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + + // 获取请求头 + headers := p.GetRequestHeaders() + + aliRequest := convertFromEmbeddingOpenai(request) + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + defer req.Body.Close() + + aliResponse := &AliEmbeddingResponse{} + + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, aliResponse, false) + if errWithCode != nil { + return nil, errWithCode } - openAIEmbeddingResponse := &types.EmbeddingResponse{ - Object: "list", - Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)), - Model: "text-embedding-v1", - Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens}, - } - - for _, item := range aliResponse.Output.Embeddings { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ - Object: `embedding`, - Index: item.TextIndex, - Embedding: item.Embedding, - }) - } - - return openAIEmbeddingResponse, nil + return p.convertToEmbeddingOpenai(aliResponse, request) } -// 获取嵌入请求体 -func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest { +func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *AliEmbeddingRequest { return &AliEmbeddingRequest{ Model: "text-embedding-v1", Input: struct { @@ -50,24 +47,36 @@ func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) } } -func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody := p.getEmbeddingsRequestBody(request) - fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) - headers := p.GetRequestHeaders() - - client := 
common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - aliEmbeddingResponse := &AliEmbeddingResponse{} - errWithCode = p.SendRequest(req, aliEmbeddingResponse, false) - if errWithCode != nil { +func (p *AliProvider) convertToEmbeddingOpenai(response *AliEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.AliError) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, + } return } - usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens} - return usage, nil + openaiResponse = &types.EmbeddingResponse{ + Object: "list", + Data: make([]types.Embedding, 0, len(response.Output.Embeddings)), + Model: request.Model, + Usage: &types.Usage{ + PromptTokens: response.Usage.TotalTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + + for _, item := range response.Output.Embeddings { + openaiResponse.Data = append(openaiResponse.Data, types.Embedding{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + + *p.Usage = *openaiResponse.Usage + + return } diff --git a/providers/ali/type.go b/providers/ali/type.go index 47d3ad51..4cd0b329 100644 --- a/providers/ali/type.go +++ b/providers/ali/type.go @@ -52,7 +52,8 @@ type AliChoice struct { } type AliOutput struct { - Choices []types.ChatCompletionChoice `json:"choices"` + Choices []types.ChatCompletionChoice `json:"choices"` + FinishReason string `json:"finish_reason,omitempty"` } func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice { @@ -70,7 +71,6 @@ func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice { type AliChatResponse struct { Output AliOutput `json:"output"` Usage AliUsage `json:"usage"` - Model string `json:"model,omitempty"` AliError } diff --git a/providers/api2d/balance.go b/providers/api2d/balance.go index 67f9d8ae..4f7e3cef 100644 --- a/providers/api2d/balance.go +++ b/providers/api2d/balance.go @@ -2,7 +2,6 @@ package api2d import ( "errors" - "one-api/common" "one-api/model" "one-api/providers/base" ) @@ -11,15 +10,14 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") headers := p.GetRequestHeaders() - client := common.NewClient() - req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers)) if err != nil { return 0, err } // 发送请求 var response base.BalanceResponse - _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) + _, errWithCode := p.Requester.SendRequest(req, &response, false) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/api2d/base.go b/providers/api2d/base.go index ca9ab256..5eefe366 100644 --- a/providers/api2d/base.go +++ b/providers/api2d/base.go @@ -1,18 +1,17 @@ package api2d import ( + "one-api/model" "one-api/providers/base" "one-api/providers/openai" - - "github.com/gin-gonic/gin" ) type Api2dProviderFactory struct{} // 创建 Api2dProvider -func (f Api2dProviderFactory) 
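The embeddings conversion above pins the upstream model to text-embedding-v1 and nests the inputs under input.texts. A standalone look at the JSON this produces; note the json tags here are assumed from DashScope's documented request shape, since the diff elides them:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the Ali embedding request shape built by
// convertFromEmbeddingOpenai; the json tags are assumptions.
type AliEmbeddingRequest struct {
	Model string `json:"model"`
	Input struct {
		Texts []string `json:"texts"`
	} `json:"input"`
}

func main() {
	req := AliEmbeddingRequest{Model: "text-embedding-v1"}
	req.Input.Texts = []string{"hello", "world"}
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
	// {"model":"text-embedding-v1","input":{"texts":["hello","world"]}}
}
```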
Create(c *gin.Context) base.ProviderInterface { +func (f Api2dProviderFactory) Create(channel *model.Channel) base.ProviderInterface { return &Api2dProvider{ - OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"), + OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://oa.api2d.net"), } } diff --git a/providers/api2gpt/balance.go b/providers/api2gpt/balance.go index 1288e8a8..370ee882 100644 --- a/providers/api2gpt/balance.go +++ b/providers/api2gpt/balance.go @@ -2,7 +2,6 @@ package api2gpt import ( "errors" - "one-api/common" "one-api/model" "one-api/providers/base" ) @@ -11,15 +10,14 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) { fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") headers := p.GetRequestHeaders() - client := common.NewClient() - req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers)) if err != nil { return 0, err } // 发送请求 var response base.BalanceResponse - _, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) + _, errWithCode := p.Requester.SendRequest(req, &response, false) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } diff --git a/providers/api2gpt/base.go b/providers/api2gpt/base.go index c502108a..9c3afa7f 100644 --- a/providers/api2gpt/base.go +++ b/providers/api2gpt/base.go @@ -1,17 +1,16 @@ package api2gpt import ( + "one-api/model" "one-api/providers/base" "one-api/providers/openai" - - "github.com/gin-gonic/gin" ) type Api2gptProviderFactory struct{} -func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface { +func (f Api2gptProviderFactory) Create(channel *model.Channel) base.ProviderInterface { return &Api2gptProvider{ - OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"), + OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.api2gpt.com"), } } diff --git a/providers/azure/base.go b/providers/azure/base.go index 597dd2a8..f3c7a202 100644 --- a/providers/azure/base.go +++ b/providers/azure/base.go @@ -1,30 +1,23 @@ package azure import ( + "one-api/common/requester" + "one-api/model" "one-api/providers/base" "one-api/providers/openai" - - "github.com/gin-gonic/gin" ) type AzureProviderFactory struct{} // 创建 AzureProvider -func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface { +func (f AzureProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + config := getAzureConfig() return &AzureProvider{ OpenAIProvider: openai.OpenAIProvider{ BaseProvider: base.BaseProvider{ - BaseURL: "", - Completions: "/completions", - ChatCompletions: "/chat/completions", - Embeddings: "/embeddings", - AudioTranscriptions: "/audio/transcriptions", - AudioTranslations: "/audio/translations", - ImagesGenerations: "/images/generations", - // ImagesEdit: "/images/edit", - // ImagesVariations: "/images/variations", - Context: c, - // AudioSpeech: "/audio/speech", + Config: config, + Channel: channel, + Requester: requester.NewHTTPRequester(channel.Proxy, openai.RequestErrorHandle), }, IsAzure: true, BalanceAction: false, @@ -32,6 +25,18 @@ func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface { } } +func getAzureConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "", + Completions: "/completions", + ChatCompletions: "/chat/completions", + Embeddings: "/embeddings", + AudioTranscriptions: 
"/audio/transcriptions", + AudioTranslations: "/audio/translations", + ImagesGenerations: "/images/generations", + } +} + type AzureProvider struct { openai.OpenAIProvider } diff --git a/providers/azure/image_generations.go b/providers/azure/image_generations.go index 294f66fe..a2e143ce 100644 --- a/providers/azure/image_generations.go +++ b/providers/azure/image_generations.go @@ -10,13 +10,60 @@ import ( "time" ) -func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Status == "canceled" || c.Status == "failed" { +func (p *AzureProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { + if !openai.IsWithinRange(request.Model, request.N) { + return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) + } + + req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + var response *types.ImageResponse + var resp *http.Response + if request.Model == "dall-e-2" { + imageAzureResponse := &ImageAzureResponse{} + resp, errWithCode = p.Requester.SendRequest(req, imageAzureResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + response, errWithCode = p.ResponseAzureImageHandler(resp, imageAzureResponse) + if errWithCode != nil { + return nil, errWithCode + } + } else { + var openaiResponse openai.OpenAIProviderImageResponse + _, errWithCode = p.Requester.SendRequest(req, &openaiResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + // 检测是否错误 + openaiErr := openai.ErrorHandle(&openaiResponse.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode + } + response = &openaiResponse.ImageResponse + } + + p.Usage.TotalTokens = p.Usage.PromptTokens + + return response, nil + +} + +func (p *AzureProvider) ResponseAzureImageHandler(resp *http.Response, azure *ImageAzureResponse) (OpenAIResponse *types.ImageResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + if azure.Status == "canceled" || azure.Status == "failed" { errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ - Message: c.Error.Message, + Message: azure.Error.Message, Type: "one_api_error", - Code: c.Error.Code, + Code: azure.Error.Code, }, StatusCode: resp.StatusCode, } @@ -28,8 +75,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons return nil, common.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError) } - client := common.NewClient() - req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header)) + req, err := p.Requester.NewRequest("GET", operation_location, p.Requester.WithHeader(p.GetRequestHeaders())) if err != nil { return nil, common.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError) } @@ -38,7 +84,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons for i := 0; i < 3; i++ { // 休眠 2 秒 time.Sleep(2 * time.Second) - _, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy) + _, errWithCode = p.Requester.SendRequest(req, &getImageAzureResponse, false) fmt.Println("getImageAzureResponse", getImageAzureResponse) if errWithCode != nil 
{ return @@ -47,57 +93,17 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" { return nil, &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ - Message: c.Error.Message, + Message: getImageAzureResponse.Error.Message, Type: "get_images_request_failed", - Code: c.Error.Code, + Code: getImageAzureResponse.Error.Code, }, StatusCode: resp.StatusCode, } } if getImageAzureResponse.Status == "succeeded" { - return getImageAzureResponse.Result, nil + return &getImageAzureResponse.Result, nil } } return nil, common.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError) } - -func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody, err := p.GetRequestBody(&request, isModelMapped) - if err != nil { - return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - - fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - if request.Model == "dall-e-2" { - imageAzureResponse := &ImageAzureResponse{ - Header: headers, - Proxy: p.Channel.Proxy, - } - errWithCode = p.SendRequest(req, imageAzureResponse, false) - } else { - openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{} - errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) - } - - if errWithCode != nil { - return - } - - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, - } - - return -} diff --git a/providers/azureSpeech/base.go b/providers/azureSpeech/base.go index 9f88a69a..af029484 100644 --- a/providers/azureSpeech/base.go +++ b/providers/azureSpeech/base.go @@ -1,21 +1,23 @@ package azureSpeech import ( + "one-api/common/requester" + "one-api/model" "one-api/providers/base" - - "github.com/gin-gonic/gin" ) // 定义供应商工厂 type AzureSpeechProviderFactory struct{} // 创建 AliProvider -func (f AzureSpeechProviderFactory) Create(c *gin.Context) base.ProviderInterface { +func (f AzureSpeechProviderFactory) Create(channel *model.Channel) base.ProviderInterface { return &AzureSpeechProvider{ BaseProvider: base.BaseProvider{ - BaseURL: "", - AudioSpeech: "/cognitiveservices/v1", - Context: c, + Config: base.ProviderConfig{ + AudioSpeech: "/cognitiveservices/v1", + }, + Channel: channel, + Requester: requester.NewHTTPRequester(channel.Proxy, nil), }, } } diff --git a/providers/azureSpeech/speech.go b/providers/azureSpeech/speech.go index 8f215915..d6653ed4 100644 --- a/providers/azureSpeech/speech.go +++ b/providers/azureSpeech/speech.go @@ -55,9 +55,12 @@ func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest) } -func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model) +func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) 
(*http.Response, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeAudioSpeech) + if errWithCode != nil { + return nil, errWithCode + } + fullRequestURL := p.GetFullRequestURL(url, request.Model) headers := p.GetRequestHeaders() responseFormatr := outputFormatMap[request.ResponseFormat] if responseFormatr == "" { @@ -67,22 +70,19 @@ func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, is requestBody := p.getRequestBody(request) - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(requestBody), p.Requester.WithHeader(headers)) if err != nil { return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } + defer req.Body.Close() - errWithCode = p.SendRequestRaw(req) + var resp *http.Response + resp, errWithCode = p.Requester.SendRequestRaw(req) if errWithCode != nil { - return + return nil, errWithCode } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, - } + p.Usage.TotalTokens = p.Usage.PromptTokens - return + return resp, nil } diff --git a/providers/baichuan/base.go b/providers/baichuan/base.go index a0e5fb8a..671c1932 100644 --- a/providers/baichuan/base.go +++ b/providers/baichuan/base.go @@ -1,10 +1,10 @@ package baichuan import ( + "one-api/common/requester" + "one-api/model" "one-api/providers/base" "one-api/providers/openai" - - "github.com/gin-gonic/gin" ) // 定义供应商工厂 @@ -12,19 +12,26 @@ type BaichuanProviderFactory struct{} // 创建 BaichuanProvider // https://platform.baichuan-ai.com/docs/api -func (f BaichuanProviderFactory) Create(c *gin.Context) base.ProviderInterface { +func (f BaichuanProviderFactory) Create(channel *model.Channel) base.ProviderInterface { return &BaichuanProvider{ OpenAIProvider: openai.OpenAIProvider{ BaseProvider: base.BaseProvider{ - BaseURL: "https://api.baichuan-ai.com", - ChatCompletions: "/v1/chat/completions", - Embeddings: "/v1/embeddings", - Context: c, + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(channel.Proxy, openai.RequestErrorHandle), }, }, } } +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://api.baichuan-ai.com", + ChatCompletions: "/v1/chat/completions", + Embeddings: "/v1/embeddings", + } +} + type BaichuanProvider struct { openai.OpenAIProvider } diff --git a/providers/baichuan/chat.go b/providers/baichuan/chat.go index d439c4d9..39dc3894 100644 --- a/providers/baichuan/chat.go +++ b/providers/baichuan/chat.go @@ -3,31 +3,61 @@ package baichuan import ( "net/http" "one-api/common" + "one-api/common/requester" "one-api/providers/openai" "one-api/types" "strings" ) -func (baichuanResponse *BaichuanChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if baichuanResponse.Error.Message != "" { +func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + requestBody := p.getChatRequestBody(request) + req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, requestBody) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + response := &BaichuanChatResponse{} + // 发送请求 + 
_, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode + } + + // 检测是否错误 + openaiErr := openai.ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: baichuanResponse.Error, - StatusCode: resp.StatusCode, + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, } - - return + return nil, errWithCode } - OpenAIResponse = types.ChatCompletionResponse{ - ID: baichuanResponse.ID, - Object: baichuanResponse.Object, - Created: baichuanResponse.Created, - Model: baichuanResponse.Model, - Choices: baichuanResponse.Choices, - Usage: baichuanResponse.Usage, + *p.Usage = *response.Usage + + return &response.ChatCompletionResponse, nil +} + +func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode } - return + chatHandler := openai.OpenAIStreamHandler{ + Usage: p.Usage, + ModelName: request.Model, + } + + return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream) } // 获取聊天请求体 @@ -55,46 +85,3 @@ func (p *BaichuanProvider) getChatRequestBody(request *types.ChatCompletionReque TopK: request.N, } } - -// 聊天 -func (p *BaichuanProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody := p.getChatRequestBody(request) - - fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) - headers := p.GetRequestHeaders() - if request.Stream { - headers["Accept"] = "text/event-stream" - } - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - if request.Stream { - openAIProviderChatStreamResponse := &openai.OpenAIProviderChatStreamResponse{} - var textResponse string - errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse) - if errWithCode != nil { - return - } - - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: common.CountTokenText(textResponse, request.Model), - TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model), - } - - } else { - baichuanResponse := &BaichuanChatResponse{} - errWithCode = p.SendRequest(req, baichuanResponse, false) - if errWithCode != nil { - return - } - - usage = baichuanResponse.Usage - } - return -} diff --git a/providers/baidu/base.go b/providers/baidu/base.go index 7dea85b6..c17be3ed 100644 --- a/providers/baidu/base.go +++ b/providers/baidu/base.go @@ -4,35 +4,65 @@ import ( "encoding/json" "errors" "fmt" - "one-api/common" + "net/http" + "one-api/common/requester" + "one-api/model" "one-api/providers/base" + "one-api/types" "strings" "sync" "time" - - "github.com/gin-gonic/gin" ) // 定义供应商工厂 type BaiduProviderFactory struct{} -// 创建 BaiduProvider +var baiduTokenStore sync.Map -func (f BaiduProviderFactory) Create(c 
*gin.Context) base.ProviderInterface {
+// Create BaiduProvider
+type BaiduProvider struct {
+	base.BaseProvider
+}
+
+func (f BaiduProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &BaiduProvider{
 		BaseProvider: base.BaseProvider{
-			BaseURL:         "https://aip.baidubce.com",
-			ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
-			Embeddings:      "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
-			Context:         c,
+			Config:    getConfig(),
+			Channel:   channel,
+			Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
 		},
 	}
 }

-var baiduTokenStore sync.Map
+func getConfig() base.ProviderConfig {
+	return base.ProviderConfig{
+		BaseURL:         "https://aip.baidubce.com",
+		ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
+		Embeddings:      "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
+	}
+}

-type BaiduProvider struct {
-	base.BaseProvider
+// requestErrorHandle maps a failed upstream HTTP response to an OpenAIError.
+// Decode into a fresh struct: decoding into a nil pointer would panic.
+func requestErrorHandle(resp *http.Response) *types.OpenAIError {
+	baiduError := &BaiduError{}
+	err := json.NewDecoder(resp.Body).Decode(baiduError)
+	if err != nil {
+		return nil
+	}
+
+	return errorHandle(baiduError)
+}
+
+// errorHandle converts a BaiduError into an OpenAIError, or nil when the
+// response carries no error.
+func errorHandle(baiduError *BaiduError) *types.OpenAIError {
+	if baiduError.ErrorMsg == "" {
+		return nil
+	}
+	return &types.OpenAIError{
+		Message: baiduError.ErrorMsg,
+		Type:    "baidu_error",
+		Code:    baiduError.ErrorCode,
+	}
 }

 // 获取完整请求 URL
@@ -92,32 +122,21 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo
 		return nil, errors.New("invalid baidu apikey")
 	}

-	client := common.NewClient()
-	url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
+	url := fmt.Sprintf(p.Config.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])

 	var headers = map[string]string{
 		"Content-Type": "application/json",
 		"Accept":       "application/json",
 	}

-	req, err := client.NewRequest("POST", url, common.WithHeader(headers))
+	req, err := p.Requester.NewRequest("POST", url, p.Requester.WithHeader(headers))
 	if err != nil {
 		return nil, err
 	}
-
-	httpClient := common.GetHttpClient(p.Channel.Proxy)
-	resp, err := httpClient.Do(req)
-	if err != nil {
-		return nil, err
-	}
-	common.PutHttpClient(httpClient)
-
-	defer resp.Body.Close()
-
 	var accessToken BaiduAccessToken
-	err = json.NewDecoder(resp.Body).Decode(&accessToken)
-	if err != nil {
-		return nil, err
+	_, errWithCode := p.Requester.SendRequest(req, &accessToken, false)
+	if errWithCode != nil {
+		return nil, errors.New(errWithCode.OpenAIError.Message)
 	}

 	if accessToken.Error != "" {
 		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go
index f3ded14c..68a60ef8 100644
--- a/providers/baidu/chat.go
+++ b/providers/baidu/chat.go
@@ -1,79 +1,155 @@
 package baidu

 import (
-	"bufio"
 	"encoding/json"
-	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/providers/base"
+	"one-api/common/requester"
 	"one-api/types"
 	"strings"
 )

-func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
-	if baiduResponse.ErrorMsg != "" {
-		return nil, &types.OpenAIErrorWithStatusCode{
-			OpenAIError: types.OpenAIError{
-				Message: baiduResponse.ErrorMsg,
-				Type:    "baidu_error",
-				Param:   "",
-				Code:    baiduResponse.ErrorCode,
-			},
-			StatusCode: resp.StatusCode,
+type baiduStreamHandler struct {
+	Usage   *types.Usage
+	Request *types.ChatCompletionRequest
+}
+
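+// CreateChatCompletion implements the non-stream chat path: it sends the
+// converted request through the channel-scoped Requester and maps the Baidu
+// payload back to an OpenAI-style ChatCompletionResponse.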
+func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getBaiduChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + baiduResponse := &BaiduChatResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, baiduResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return p.convertToChatOpenai(baiduResponse, request) +} + +func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getBaiduChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := &baiduStreamHandler{ + Usage: p.Usage, + Request: request, + } + + return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) +} + +func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) + } + + // 获取请求头 + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + baiduRequest := convertFromChatOpenai(request) + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(baiduRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +func (p *BaiduProvider) convertToChatOpenai(response *BaiduChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.BaiduError) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, } + return } choice := types.ChatCompletionChoice{ Index: 0, Message: types.ChatCompletionMessage{ Role: "assistant", - // Content: baiduResponse.Result, }, - FinishReason: base.StopFinishReason, + FinishReason: types.FinishReasonStop, } - if baiduResponse.FunctionCall != nil { - if baiduResponse.FunctionCate == "tool" { + if response.FunctionCall != nil { + if request.Tools != nil { choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{ { - Id: baiduResponse.Id, + Id: response.Id, Type: "function", - Function: *baiduResponse.FunctionCall, + Function: response.FunctionCall, }, } - choice.FinishReason = &base.StopFinishReasonToolFunction + choice.FinishReason = types.FinishReasonToolCalls } else { - choice.Message.FunctionCall = baiduResponse.FunctionCall - choice.FinishReason = &base.StopFinishReasonCallFunction + choice.Message.FunctionCall = response.FunctionCall + choice.FinishReason = types.FinishReasonFunctionCall } } else { - choice.Message.Content = baiduResponse.Result + choice.Message.Content = response.Result } - OpenAIResponse = 
types.ChatCompletionResponse{ - ID: baiduResponse.Id, + openaiResponse = &types.ChatCompletionResponse{ + ID: response.Id, Object: "chat.completion", - Created: baiduResponse.Created, + Model: request.Model, + Created: response.Created, Choices: []types.ChatCompletionChoice{choice}, - Usage: baiduResponse.Usage, + Usage: response.Usage, } + *p.Usage = *openaiResponse.Usage + return } -func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest { +func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == "system" { + if message.Role == types.ChatMessageRoleSystem { messages = append(messages, BaiduMessage{ - Role: "user", + Role: types.ChatMessageRoleUser, Content: message.StringContent(), }) messages = append(messages, BaiduMessage{ - Role: "assistant", + Role: types.ChatMessageRoleAssistant, Content: "Okay", }) + } else if message.Role == types.ChatMessageRoleFunction { + messages = append(messages, BaiduMessage{ + Role: types.ChatMessageRoleAssistant, + Content: "Okay", + }) + messages = append(messages, BaiduMessage{ + Role: types.ChatMessageRoleUser, + Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(), + }) } else { messages = append(messages, BaiduMessage{ Role: message.Role, @@ -101,154 +177,82 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) return baiduChatRequest } -func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody := p.getChatRequestBody(request) - fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) - if fullRequestURL == "" { - return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) +// 转换为OpenAI聊天流式请求体 +func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { + // 如果rawLine 前缀不为data:,则直接返回 + if !strings.HasPrefix(string(*rawLine), "data: ") { + *rawLine = nil + return nil } - headers := p.GetRequestHeaders() - if request.Stream { - headers["Accept"] = "text/event-stream" - } + // 去除前缀 + *rawLine = (*rawLine)[6:] - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + var baiduResponse BaiduChatStreamResponse + err := json.Unmarshal(*rawLine, &baiduResponse) if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return common.ErrorToOpenAIError(err) } - if request.Stream { - usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate()) - if errWithCode != nil { - return - } - - } else { - baiduChatRequest := &BaiduChatResponse{ - Model: request.Model, - FunctionCate: request.GetFunctionCate(), - } - errWithCode = p.SendRequest(req, baiduChatRequest, false) - if errWithCode != nil { - return - } - - usage = baiduChatRequest.Usage + if baiduResponse.IsEnd { + *isFinished = true } - return + + return h.convertToOpenaiStream(&baiduResponse, response) } -func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse { - var choice types.ChatCompletionStreamChoice +func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse 
*BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error { + choice := types.ChatCompletionStreamChoice{ + Index: 0, + Delta: types.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + }, + } if baiduResponse.FunctionCall != nil { - if baiduResponse.FunctionCate == "tool" { + if h.Request.Tools != nil { choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{ { Id: baiduResponse.Id, Type: "function", - Function: *baiduResponse.FunctionCall, + Function: baiduResponse.FunctionCall, }, } - choice.FinishReason = &base.StopFinishReasonToolFunction + choice.FinishReason = types.FinishReasonToolCalls } else { choice.Delta.FunctionCall = baiduResponse.FunctionCall - choice.FinishReason = &base.StopFinishReasonCallFunction + choice.FinishReason = types.FinishReasonFunctionCall } } else { choice.Delta.Content = baiduResponse.Result if baiduResponse.IsEnd { - choice.FinishReason = &base.StopFinishReason + choice.FinishReason = types.FinishReasonStop } } - response := types.ChatCompletionStreamResponse{ + chatCompletion := types.ChatCompletionStreamResponse{ ID: baiduResponse.Id, Object: "chat.completion.chunk", Created: baiduResponse.Created, - Model: baiduResponse.Model, - Choices: []types.ChatCompletionStreamChoice{choice}, + Model: h.Request.Model, } - return &response -} - -func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - defer req.Body.Close() - - usage = &types.Usage{} - // 发送请求 - client := common.GetHttpClient(p.Channel.Proxy) - resp, err := client.Do(req) - if err != nil { - return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) - } - common.PutHttpClient(client) - - if common.IsFailureStatusCode(resp) { - return nil, common.HandleErrorResp(resp) - } - - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - data = data[6:] - dataChan <- data - } - stopChan <- true - }() - common.SetEventStreamHeaders(p.Context) - p.Context.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var baiduResponse BaiduChatStreamResponse - baiduResponse.FunctionCate = functionCate - err := json.Unmarshal([]byte(data), &baiduResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if baiduResponse.Usage.TotalTokens != 0 { - usage.TotalTokens = baiduResponse.Usage.TotalTokens - usage.PromptTokens = baiduResponse.Usage.PromptTokens - usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens - } - baiduResponse.Model = model - response := p.streamResponseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - 
return false - } - }) - - return usage, nil + + if baiduResponse.FunctionCall == nil { + chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice} + *response = append(*response, chatCompletion) + } else { + choices := choice.ConvertOpenaiStream() + for _, choice := range choices { + chatCompletionCopy := chatCompletion + chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} + *response = append(*response, chatCompletionCopy) + } + } + + h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens + h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens + h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens + + return nil } diff --git a/providers/baidu/embeddings.go b/providers/baidu/embeddings.go index 1b4fc5cb..5e13215e 100644 --- a/providers/baidu/embeddings.go +++ b/providers/baidu/embeddings.go @@ -6,33 +6,63 @@ import ( "one-api/types" ) -func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest { +func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) + } + + // 获取请求头 + headers := p.GetRequestHeaders() + + aliRequest := convertFromEmbeddingOpenai(request) + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + defer req.Body.Close() + + baiduResponse := &BaiduEmbeddingResponse{} + + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, baiduResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return p.convertToEmbeddingOpenai(baiduResponse, request) +} + +func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *BaiduEmbeddingRequest { return &BaiduEmbeddingRequest{ Input: request.ParseInput(), } } -func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if baiduResponse.ErrorMsg != "" { - return nil, &types.OpenAIErrorWithStatusCode{ - OpenAIError: types.OpenAIError{ - Message: baiduResponse.ErrorMsg, - Type: "baidu_error", - Param: "", - Code: baiduResponse.ErrorCode, - }, - StatusCode: resp.StatusCode, +func (p *BaiduProvider) convertToEmbeddingOpenai(response *BaiduEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.BaiduError) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, } + return } openAIEmbeddingResponse := &types.EmbeddingResponse{ Object: "list", - Data: make([]types.Embedding, 0, len(baiduResponse.Data)), - Model: "text-embedding-v1", - Usage: &baiduResponse.Usage, + Data: make([]types.Embedding, 0, len(response.Data)), + Model: request.Model, + Usage: &response.Usage, } - for _, item := range baiduResponse.Data { + for _, item := range response.Data { openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ Object: item.Object, 
Index: item.Index, @@ -40,30 +70,7 @@ func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response }) } + *p.Usage = response.Usage + return openAIEmbeddingResponse, nil } - -func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody := p.getEmbeddingsRequestBody(request) - fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) - if fullRequestURL == "" { - return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) - } - - headers := p.GetRequestHeaders() - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - baiduEmbeddingResponse := &BaiduEmbeddingResponse{} - errWithCode = p.SendRequest(req, baiduEmbeddingResponse, false) - if errWithCode != nil { - return - } - usage = &baiduEmbeddingResponse.Usage - - return usage, nil -} diff --git a/providers/base/common.go b/providers/base/common.go index ef80a397..710b49cf 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -3,9 +3,9 @@ package base import ( "encoding/json" "fmt" - "io" "net/http" "one-api/common" + "one-api/common/requester" "one-api/model" "one-api/types" "strings" @@ -13,11 +13,7 @@ import ( "github.com/gin-gonic/gin" ) -var StopFinishReason = "stop" -var StopFinishReasonToolFunction = "tool_calls" -var StopFinishReasonCallFunction = "function_call" - -type BaseProvider struct { +type ProviderConfig struct { BaseURL string Completions string ChatCompletions string @@ -29,8 +25,15 @@ type BaseProvider struct { ImagesGenerations string ImagesEdit string ImagesVariations string - Context *gin.Context - Channel *model.Channel +} + +type BaseProvider struct { + OriginalModel string + Usage *types.Usage + Config ProviderConfig + Context *gin.Context + Channel *model.Channel + Requester *requester.HTTPRequester } // 获取基础URL @@ -39,11 +42,7 @@ func (p *BaseProvider) GetBaseURL() string { return p.Channel.GetBaseURL() } - return p.BaseURL -} - -func (p *BaseProvider) SetChannel(channel *model.Channel) { - p.Channel = channel + return p.Config.BaseURL } // 获取完整请求URL @@ -62,104 +61,85 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) { } } -// 发送请求 -func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { - defer req.Body.Close() - - resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy) - if openAIErrorWithStatusCode != nil { - return - } - - defer resp.Body.Close() - - openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp) - if openAIErrorWithStatusCode != nil { - return - } - - if rawOutput { - for k, v := range resp.Header { - p.Context.Writer.Header().Set(k, v[0]) - } - - p.Context.Writer.WriteHeader(resp.StatusCode) - _, err := io.Copy(p.Context.Writer, resp.Body) - if err != nil { - return common.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - } else { - jsonResponse, err := json.Marshal(openAIResponse) - if err != nil { - return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) - } - 
p.Context.Writer.Header().Set("Content-Type", "application/json") - p.Context.Writer.WriteHeader(resp.StatusCode) - _, err = p.Context.Writer.Write(jsonResponse) - - if err != nil { - return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) - } - } - - return nil +func (p *BaseProvider) GetUsage() *types.Usage { + return p.Usage } -func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { - defer req.Body.Close() - - // 发送请求 - client := common.GetHttpClient(p.Channel.Proxy) - resp, err := client.Do(req) - if err != nil { - return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) - } - common.PutHttpClient(client) - - defer resp.Body.Close() - - // 处理响应 - if common.IsFailureStatusCode(resp) { - return common.HandleErrorResp(resp) - } - - for k, v := range resp.Header { - p.Context.Writer.Header().Set(k, v[0]) - } - - p.Context.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(p.Context.Writer, resp.Body) - if err != nil { - return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) - } - - return nil +func (p *BaseProvider) SetUsage(usage *types.Usage) { + p.Usage = usage } -func (p *BaseProvider) SupportAPI(relayMode int) bool { +func (p *BaseProvider) SetContext(c *gin.Context) { + p.Context = c +} + +func (p *BaseProvider) SetOriginalModel(ModelName string) { + p.OriginalModel = ModelName +} + +func (p *BaseProvider) GetOriginalModel() string { + return p.OriginalModel +} + +func (p *BaseProvider) GetChannel() *model.Channel { + return p.Channel +} + +func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) { + p.OriginalModel = modelName + + modelMapping := p.Channel.GetModelMapping() + + if modelMapping == "" || modelMapping == "{}" { + return modelName, nil + } + + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return "", err + } + + if modelMap[modelName] != "" { + return modelMap[modelName], nil + } + + return modelName, nil +} + +func (p *BaseProvider) GetAPIUri(relayMode int) string { switch relayMode { case common.RelayModeChatCompletions: - return p.ChatCompletions != "" + return p.Config.ChatCompletions case common.RelayModeCompletions: - return p.Completions != "" + return p.Config.Completions case common.RelayModeEmbeddings: - return p.Embeddings != "" + return p.Config.Embeddings case common.RelayModeAudioSpeech: - return p.AudioSpeech != "" + return p.Config.AudioSpeech case common.RelayModeAudioTranscription: - return p.AudioTranscriptions != "" + return p.Config.AudioTranscriptions case common.RelayModeAudioTranslation: - return p.AudioTranslations != "" + return p.Config.AudioTranslations case common.RelayModeModerations: - return p.Moderation != "" + return p.Config.Moderation case common.RelayModeImagesGenerations: - return p.ImagesGenerations != "" + return p.Config.ImagesGenerations case common.RelayModeImagesEdits: - return p.ImagesEdit != "" + return p.Config.ImagesEdit case common.RelayModeImagesVariations: - return p.ImagesVariations != "" + return p.Config.ImagesVariations default: - return false + return "" } } + +func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types.OpenAIErrorWithStatusCode) { + url = p.GetAPIUri(relayMode) + if url == "" { + err = common.StringErrorWrapper("The API interface is not supported", "unsupported_api", http.StatusNotImplemented) + return + } + + return 
+} diff --git a/providers/base/handler.go b/providers/base/handler.go new file mode 100644 index 00000000..fb7688d8 --- /dev/null +++ b/providers/base/handler.go @@ -0,0 +1,7 @@ +package base + +import "one-api/types" + +type BaseHandler struct { + Usage *types.Usage +} diff --git a/providers/base/interface.go b/providers/base/interface.go index 584e12b0..84889acb 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -2,84 +2,108 @@ package base import ( "net/http" + "one-api/common/requester" "one-api/model" "one-api/types" + + "github.com/gin-gonic/gin" ) +type Requestable interface { + types.CompletionRequest | types.ChatCompletionRequest | types.EmbeddingRequest | types.ModerationRequest | types.SpeechAudioRequest | types.AudioRequest | types.ImageRequest | types.ImageEditRequest +} + // 基础接口 type ProviderInterface interface { - GetBaseURL() string - GetFullRequestURL(requestURL string, modelName string) string - GetRequestHeaders() (headers map[string]string) - SupportAPI(relayMode int) bool - SetChannel(channel *model.Channel) + // 获取基础URL + // GetBaseURL() string + // 获取完整请求URL + // GetFullRequestURL(requestURL string, modelName string) string + // 获取请求头 + // GetRequestHeaders() (headers map[string]string) + // 获取用量 + GetUsage() *types.Usage + // 设置用量 + SetUsage(usage *types.Usage) + // 设置Context + SetContext(c *gin.Context) + // 设置原始模型 + SetOriginalModel(ModelName string) + // 获取原始模型 + GetOriginalModel() string + + // SupportAPI(relayMode int) bool + GetChannel() *model.Channel + ModelMappingHandler(modelName string) (string, error) } // 完成接口 type CompletionInterface interface { ProviderInterface - CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) + CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode) + CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode) } // 聊天接口 type ChatInterface interface { ProviderInterface - ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) + CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) + CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) } // 嵌入接口 type EmbeddingsInterface interface { ProviderInterface - EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) + CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) } // 审查接口 type ModerationInterface interface { ProviderInterface - ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) + CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) } // 文字转语音接口 type SpeechInterface interface { ProviderInterface - SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) + CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, 
*types.OpenAIErrorWithStatusCode)
 }

 // 语音转文字接口
 type TranscriptionsInterface interface {
 	ProviderInterface
-	TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+	CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
 }

 // 语音翻译接口
 type TranslationInterface interface {
 	ProviderInterface
-	TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+	CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
 }

 // 图片生成接口
 type ImageGenerationsInterface interface {
 	ProviderInterface
-	ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+	CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
 }

 // 图片编辑接口
 type ImageEditsInterface interface {
 	ProviderInterface
-	ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+	CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
 }

 type ImageVariationsInterface interface {
 	ProviderInterface
-	ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+	CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
 }

 // 余额接口
 type BalanceInterface interface {
-	Balance(channel *model.Channel) (float64, error)
+	Balance() (float64, error)
 }

-type ProviderResponseHandler interface {
-	// 响应处理函数
-	ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
-}
+// type ProviderResponseHandler interface {
+// 	// 响应处理函数
+// 	ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
+// }
diff --git a/providers/claude/base.go b/providers/claude/base.go
index 59b819eb..c96333a7 100644
--- a/providers/claude/base.go
+++ b/providers/claude/base.go
@@ -1,20 +1,23 @@
 package claude

 import (
+	"encoding/json"
+	"net/http"
+	"one-api/common/requester"
+	"one-api/model"
 	"one-api/providers/base"
-
-	"github.com/gin-gonic/gin"
+	"one-api/types"
 )

 type ClaudeProviderFactory struct{}

 // 创建 ClaudeProvider
-func (f ClaudeProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f ClaudeProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &ClaudeProvider{
 		BaseProvider: base.BaseProvider{
-			BaseURL:         "https://api.anthropic.com",
-			ChatCompletions: "/v1/complete",
-			Context:         c,
+			Config:    getConfig(),
+			Channel:   channel,
+			Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
 		},
 	}
 }
@@ -23,6 +26,36 @@ type ClaudeProvider struct {
 	base.BaseProvider
 }

+func getConfig() base.ProviderConfig {
+	return base.ProviderConfig{
+		BaseURL:         "https://api.anthropic.com",
+		ChatCompletions: "/v1/complete",
+	}
+}
+
+// requestErrorHandle maps a failed upstream HTTP response to an OpenAIError.
+// Decode into a fresh struct: decoding into a nil pointer would panic.
+func requestErrorHandle(resp *http.Response) *types.OpenAIError {
+	claudeError := &ClaudeResponseError{}
+	err := json.NewDecoder(resp.Body).Decode(claudeError)
+	if err != nil {
+		return nil
+	}
+
+	return errorHandle(claudeError)
+}
+
+// errorHandle converts a ClaudeResponseError into an OpenAIError, or nil when
+// the response carries no error.
+func errorHandle(claudeError *ClaudeResponseError) *types.OpenAIError {
+	if claudeError.Error.Type == "" {
+		return nil
+	}
+	return &types.OpenAIError{
+		Message: claudeError.Error.Message,
+		Type:    claudeError.Error.Type,
+		Code:    claudeError.Error.Type,
+	}
+}
+
 // 获取请求头
 func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
 	headers = make(map[string]string)
@@ -41,9 +74,9 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
 func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
-		return "stop"
+		return types.FinishReasonStop
 	case "max_tokens":
-		return "length"
+		return types.FinishReasonLength
 	default:
 		return reason
 	}
 }
diff --git a/providers/claude/chat.go b/providers/claude/chat.go
index d2094926..f7f5d0d2 100644
--- a/providers/claude/chat.go
+++ b/providers/claude/chat.go
@@ -1,56 +1,86 @@
 package claude

 import (
-	"bufio"
 	"encoding/json"
 	"fmt"
-	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/common/requester"
 	"one-api/types"
 	"strings"
 )

-func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
-	if claudeResponse.Error.Type != "" {
-		return nil, &types.OpenAIErrorWithStatusCode{
-			OpenAIError: types.OpenAIError{
-				Message: claudeResponse.Error.Message,
-				Type:    claudeResponse.Error.Type,
-				Param:   "",
-				Code:    claudeResponse.Error.Type,
-			},
-			StatusCode: resp.StatusCode,
-		}
-	}
-
-	choice := types.ChatCompletionChoice{
-		Index: 0,
-		Message: types.ChatCompletionMessage{
-			Role:    "assistant",
-			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
-			Name:    nil,
-		},
-		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
-	}
-	fullTextResponse := types.ChatCompletionResponse{
-		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
-		Object:  "chat.completion",
-		Created: common.GetTimestamp(),
-		Choices: []types.ChatCompletionChoice{choice},
-		Model:   claudeResponse.Model,
-	}
-
-	completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
-	claudeResponse.Usage.CompletionTokens = completionTokens
-	claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens
-
-	fullTextResponse.Usage = claudeResponse.Usage
-
-	return fullTextResponse, nil
+type claudeStreamHandler struct {
+	Usage   *types.Usage
+	Request *types.ChatCompletionRequest
 }

-func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) {
+func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	claudeResponse := &ClaudeResponse{}
+	// Send the request
+	_, errWithCode = p.Requester.SendRequest(req, claudeResponse, false)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	return p.convertToChatOpenai(claudeResponse, request)
+}
+
+func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	// Send the request
+	resp, errWithCode := p.Requester.SendRequestRaw(req)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	chatHandler := &claudeStreamHandler{
+		Usage:   p.Usage,
+		Request: request,
+	}
+
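+	// requester.RequestStream feeds each SSE line of the raw response through
+	// handlerStream below and yields OpenAI-style stream chunks.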
+ return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) +} + +func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError) + } + + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + claudeRequest := convertFromChatOpenai(request) + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(claudeRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest { claudeRequest := ClaudeRequest{ Model: request.Model, Prompt: "", @@ -80,138 +110,84 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest return &claudeRequest } -func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody := p.getChatRequestBody(request) - fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) - headers := p.GetRequestHeaders() - if request.Stream { - headers["Accept"] = "text/event-stream" +func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.ClaudeResponseError) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, + } + return } - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: strings.TrimPrefix(response.Completion, " "), + Name: nil, + }, + FinishReason: stopReasonClaude2OpenAI(response.StopReason), + } + openaiResponse = &types.ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []types.ChatCompletionChoice{choice}, + Model: response.Model, } - if request.Stream { - var responseText string - errWithCode, responseText = p.sendStreamRequest(req) - if errWithCode != nil { - return - } + completionTokens := common.CountTokenText(response.Completion, response.Model) + response.Usage.CompletionTokens = completionTokens + response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: common.CountTokenText(responseText, request.Model), - } - usage.TotalTokens = promptTokens + usage.CompletionTokens + openaiResponse.Usage = response.Usage - } else { - var claudeResponse = &ClaudeResponse{ - Usage: &types.Usage{ - 
PromptTokens: promptTokens,
-			},
-		}
-		errWithCode = p.SendRequest(req, claudeResponse, false)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = claudeResponse.Usage
-	}
-	return
+	*p.Usage = *response.Usage
+	return openaiResponse, nil
 }

-func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse {
+// handlerStream converts one Claude SSE line into OpenAI-style stream chunks.
+func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
+	// Return early if the line does not carry a completion payload
+	if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
+		*rawLine = nil
+		return nil
+	}
+
+	// Strip the "data: " prefix
+	*rawLine = (*rawLine)[6:]
+
+	// Unmarshal into a fresh struct: unmarshalling into a nil pointer would fail.
+	claudeResponse := &ClaudeResponse{}
+	err := json.Unmarshal(*rawLine, claudeResponse)
+	if err != nil {
+		return common.ErrorToOpenAIError(err)
+	}
+
+	if claudeResponse.StopReason == "stop_sequence" {
+		*isFinished = true
+	}
+
+	return h.convertToOpenaiStream(claudeResponse, response)
+}
+
+func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error {
 	var choice types.ChatCompletionStreamChoice
 	choice.Delta.Content = claudeResponse.Completion
 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 	if finishReason != "null" {
 		choice.FinishReason = &finishReason
 	}
-	var response types.ChatCompletionStreamResponse
-	response.Object = "chat.completion.chunk"
-	response.Model = claudeResponse.Model
-	response.Choices = []types.ChatCompletionStreamChoice{choice}
-	return &response
-}
-
-func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
-	defer req.Body.Close()
-
-	// 发送请求
-	client := common.GetHttpClient(p.Channel.Proxy)
-	resp, err := client.Do(req)
-	if err != nil {
-		return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
-	}
-	common.PutHttpClient(client)
-
-	if common.IsFailureStatusCode(resp) {
-		return common.HandleErrorResp(resp), ""
+	chatCompletion := types.ChatCompletionStreamResponse{
+		Object:  "chat.completion.chunk",
+		Model:   h.Request.Model,
+		Choices: []types.ChatCompletionStreamChoice{choice},
 	}

-	defer resp.Body.Close()
+	*response = append(*response, chatCompletion)

-	responseText := ""
-	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-	createdTime := common.GetTimestamp()
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
-			return i + 4, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if !strings.HasPrefix(data, "event: completion") {
-				continue
-			}
-			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
-	common.SetEventStreamHeaders(p.Context)
-	p.Context.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			// some implementations may add \r at the end of data
-			data = strings.TrimSuffix(data, "\r")
-			var claudeResponse ClaudeResponse
-			err := json.Unmarshal([]byte(data), &claudeResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			responseText += claudeResponse.Completion
-			response := 
-			response.ID = responseId
-			response.Created = createdTime
-			jsonStr, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-			return true
-		case <-stopChan:
-			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
+	h.Usage.CompletionTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
+	h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
 
-	return nil, responseText
+	return nil
 }
diff --git a/providers/claude/type.go b/providers/claude/type.go
index 8a920c73..4cc0eefd 100644
--- a/providers/claude/type.go
+++ b/providers/claude/type.go
@@ -23,10 +23,13 @@ type ClaudeRequest struct {
 	Stream bool `json:"stream,omitempty"`
 }
 
+type ClaudeResponseError struct {
+	Error ClaudeError `json:"error,omitempty"`
+}
 type ClaudeResponse struct {
 	Completion string       `json:"completion"`
 	StopReason string       `json:"stop_reason"`
 	Model      string       `json:"model"`
-	Error      ClaudeError  `json:"error"`
 	Usage      *types.Usage `json:"usage,omitempty"`
+	ClaudeResponseError
 }
diff --git a/providers/closeai/balance.go b/providers/closeai/balance.go
index 82a99432..d163c123 100644
--- a/providers/closeai/balance.go
+++ b/providers/closeai/balance.go
@@ -2,7 +2,6 @@ package closeai
 
 import (
 	"errors"
-	"one-api/common"
 	"one-api/model"
 )
 
@@ -10,15 +9,14 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
 	headers := p.GetRequestHeaders()
 
-	client := common.NewClient()
-	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
+	req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
 	if err != nil {
 		return 0, err
 	}
 
 	// 发送请求
 	var response OpenAICreditGrants
-	_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
+	_, errWithCode := p.Requester.SendRequest(req, &response, false)
 	if errWithCode != nil {
 		return 0, errors.New(errWithCode.OpenAIError.Message)
 	}
diff --git a/providers/closeai/base.go b/providers/closeai/base.go
index c5387ec5..2ef00cfa 100644
--- a/providers/closeai/base.go
+++ b/providers/closeai/base.go
@@ -1,18 +1,17 @@
 package closeai
 
 import (
+	"one-api/model"
 	"one-api/providers/base"
 	"one-api/providers/openai"
-
-	"github.com/gin-gonic/gin"
 )
 
 type CloseaiProviderFactory struct{}
 
 // 创建 CloseaiProvider
-func (f CloseaiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f CloseaiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &CloseaiProxyProvider{
-		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
+		OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.closeai-proxy.xyz"),
 	}
 }
diff --git a/providers/gemini/base.go b/providers/gemini/base.go
index 26d71c7d..aafcc376 100644
--- a/providers/gemini/base.go
+++ b/providers/gemini/base.go
@@ -1,22 +1,25 @@
 package gemini
 
 import (
+	"encoding/json"
 	"fmt"
+	"net/http"
+	"one-api/common/requester"
+	"one-api/model"
 	"one-api/providers/base"
+	"one-api/types"
 	"strings"
-
-	"github.com/gin-gonic/gin"
 )
 
 type GeminiProviderFactory struct{}
 
 // 创建 ClaudeProvider
-func (f GeminiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f GeminiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &GeminiProvider{
 		BaseProvider: base.BaseProvider{
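// Two pieces of error plumbing are introduced in this diff. First, the upstream
// error payload is factored into its own struct (ClaudeResponseError here, and
// Gemini/PaLM equivalents below) and embedded into the success response, so a
// single decode fills both halves and errorHandle can be called on any
// response. Second, each factory now passes a provider-specific handler to
// requester.NewHTTPRequester, which presumably invokes it on non-2xx responses
// to translate upstream error bodies into OpenAI-style errors. The embedding
// relies only on encoding/json's promotion of embedded fields:
//
//	var resp ClaudeResponse
//	_ = json.Unmarshal([]byte(`{"error":{"type":"overloaded_error","message":"busy"}}`), &resp)
//	// resp.Error.Message == "busy", promoted from the embedded ClaudeResponseError,
//	// exactly as if Error were declared on ClaudeResponse itself.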
"https://generativelanguage.googleapis.com", - ChatCompletions: "/", - Context: c, + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle), }, } } @@ -25,6 +28,37 @@ type GeminiProvider struct { base.BaseProvider } +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://generativelanguage.googleapis.com", + ChatCompletions: "/", + } +} + +// 请求错误处理 +func requestErrorHandle(resp *http.Response) *types.OpenAIError { + var geminiError *GeminiErrorResponse + err := json.NewDecoder(resp.Body).Decode(geminiError) + if err != nil { + return nil + } + + return errorHandle(geminiError) +} + +// 错误处理 +func errorHandle(geminiError *GeminiErrorResponse) *types.OpenAIError { + if geminiError.Error.Message == "" { + return nil + } + return &types.OpenAIError{ + Message: geminiError.Error.Message, + Type: "gemini_error", + Param: geminiError.Error.Status, + Code: geminiError.Error.Code, + } +} + func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") version := "v1" diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index 57de101b..acaedb97 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -1,14 +1,12 @@ package gemini import ( - "bufio" "encoding/json" "fmt" - "io" "net/http" "one-api/common" "one-api/common/image" - "one-api/providers/base" + "one-api/common/requester" "one-api/types" "strings" ) @@ -17,50 +15,78 @@ const ( GeminiVisionMaxImageNum = 16 ) -func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if len(response.Candidates) == 0 { - return nil, &types.OpenAIErrorWithStatusCode{ - OpenAIError: types.OpenAIError{ - Message: "No candidates returned", - Type: "server_error", - Param: "", - Code: 500, - }, - StatusCode: resp.StatusCode, - } - } - - fullTextResponse := &types.ChatCompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), - Object: "chat.completion", - Created: common.GetTimestamp(), - Model: response.Model, - Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), - } - for i, candidate := range response.Candidates { - choice := types.ChatCompletionChoice{ - Index: i, - Message: types.ChatCompletionMessage{ - Role: "assistant", - Content: "", - }, - FinishReason: base.StopFinishReason, - } - if len(candidate.Content.Parts) > 0 { - choice.Message.Content = candidate.Content.Parts[0].Text - } - fullTextResponse.Choices = append(fullTextResponse.Choices, choice) - } - - completionTokens := common.CountTokenText(response.GetResponseText(), response.Model) - response.Usage.CompletionTokens = completionTokens - response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens - - return fullTextResponse, nil +type geminiStreamHandler struct { + Usage *types.Usage + Request *types.ChatCompletionRequest } -// Setting safety to the lowest possible values since Gemini is already powerless enough -func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) { +func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + 
geminiChatResponse := &GeminiChatResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, geminiChatResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return p.convertToChatOpenai(geminiChatResponse, request) +} + +func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := &geminiStreamHandler{ + Usage: p.Usage, + Request: request, + } + + return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) +} + +func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url := "generateContent" + if request.Stream { + url = "streamGenerateContent?alt=sse" + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + + // 获取请求头 + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + geminiRequest, errWithCode := convertFromChatOpenai(request) + if errWithCode != nil { + return nil, errWithCode + } + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(geminiRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatRequest, *types.OpenAIErrorWithStatusCode) { geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(request.Messages)), SafetySettings: []GeminiChatSafetySettings{ @@ -160,144 +186,99 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest return &geminiRequest, nil } -func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, errWithCode := p.getChatRequestBody(request) - if errWithCode != nil { +func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.GeminiErrorResponse) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, + } return } - fullRequestURL := p.GetFullRequestURL("generateContent", request.Model) - headers := p.GetRequestHeaders() - if request.Stream { - headers["Accept"] = "text/event-stream" + + openaiResponse = &types.ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Model: request.Model, + Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), } - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - if request.Stream { - 
var responseText string - errWithCode, responseText = p.sendStreamRequest(req, request.Model) - if errWithCode != nil { - return - } - - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: common.CountTokenText(responseText, request.Model), - } - usage.TotalTokens = promptTokens + usage.CompletionTokens - - } else { - var geminiResponse = &GeminiChatResponse{ - Model: request.Model, - Usage: &types.Usage{ - PromptTokens: promptTokens, + for i, candidate := range response.Candidates { + choice := types.ChatCompletionChoice{ + Index: i, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: "", }, + FinishReason: types.FinishReasonStop, } - errWithCode = p.SendRequest(req, geminiResponse, false) - if errWithCode != nil { - return + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[0].Text } - - usage = geminiResponse.Usage + openaiResponse.Choices = append(openaiResponse.Choices, choice) } + + completionTokens := common.CountTokenText(response.GetResponseText(), response.Model) + + p.Usage.CompletionTokens = completionTokens + p.Usage.TotalTokens = p.Usage.PromptTokens + completionTokens + openaiResponse.Usage = p.Usage + return - } -// func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse { -// var choice types.ChatCompletionStreamChoice -// choice.Delta.Content = geminiResponse.GetResponseText() -// choice.FinishReason = &base.StopFinishReason -// var response types.ChatCompletionStreamResponse -// response.Object = "chat.completion.chunk" -// response.Model = "gemini" -// response.Choices = []types.ChatCompletionStreamChoice{choice} -// return &response -// } +// 转换为OpenAI聊天流式请求体 +func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { + // 如果rawLine 前缀不为data:,则直接返回 + if !strings.HasPrefix(string(*rawLine), "data: ") { + *rawLine = nil + return nil + } -func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) { - defer req.Body.Close() + // 去除前缀 + *rawLine = (*rawLine)[6:] - // 发送请求 - client := common.GetHttpClient(p.Channel.Proxy) - resp, err := client.Do(req) + var geminiResponse GeminiChatResponse + err := json.Unmarshal(*rawLine, &geminiResponse) if err != nil { - return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" - } - common.PutHttpClient(client) - - if common.IsFailureStatusCode(resp) { - return common.HandleErrorResp(resp), "" + return common.ErrorToOpenAIError(err) } - defer resp.Body.Close() + error := errorHandle(&geminiResponse.GeminiErrorResponse) + if error != nil { + return error + } - responseText := "" - dataChan := make(chan string) - stopChan := make(chan bool) - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - go func() { - for scanner.Scan() { - data := scanner.Text() - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "\"text\": \"") { - continue - } - data = strings.TrimPrefix(data, "\"text\": \"") - data = strings.TrimSuffix(data, "\"") - dataChan <- data - } - stopChan <- true - }() - common.SetEventStreamHeaders(p.Context) - 
p.Context.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			// this is used to prevent annoying \ related format bug
-			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
-			type dummyStruct struct {
-				Content string `json:"content"`
-			}
-			var dummy dummyStruct
-			json.Unmarshal([]byte(data), &dummy)
-			responseText += dummy.Content
-			var choice types.ChatCompletionStreamChoice
-			choice.Delta.Content = dummy.Content
-			response := types.ChatCompletionStreamResponse{
-				ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
-				Object:  "chat.completion.chunk",
-				Created: common.GetTimestamp(),
-				Model:   model,
-				Choices: []types.ChatCompletionStreamChoice{choice},
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
+	return h.convertToOpenaiStream(&geminiResponse, response)
 
-	return nil, responseText
+}
+
+func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
+	choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
+
+	for i, candidate := range geminiResponse.Candidates {
+		choice := types.ChatCompletionStreamChoice{
+			Index:        i,
+			FinishReason: types.FinishReasonStop,
+		}
+		if len(candidate.Content.Parts) > 0 {
+			choice.Delta.Content = candidate.Content.Parts[0].Text
+		}
+		choices = append(choices, choice)
+	}
+
+	streamResponse := types.ChatCompletionStreamResponse{
+		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   h.Request.Model,
+		Choices: choices,
+	}
+
+	*response = append(*response, streamResponse)
+
+	h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
+	h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
+
+	return nil
 }
diff --git a/providers/gemini/type.go b/providers/gemini/type.go
index f6476e11..fc063bcd 100644
--- a/providers/gemini/type.go
+++ b/providers/gemini/type.go
@@ -42,11 +42,22 @@ type GeminiChatGenerationConfig struct {
 	StopSequences []string `json:"stopSequences,omitempty"`
 }
 
+type GeminiError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  string `json:"status"`
+}
+
+type GeminiErrorResponse struct {
+	Error GeminiError `json:"error,omitempty"`
+}
+
 type GeminiChatResponse struct {
 	Candidates     []GeminiChatCandidate    `json:"candidates"`
 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
 	Usage          *types.Usage             `json:"usage,omitempty"`
 	Model          string                   `json:"model,omitempty"`
+	GeminiErrorResponse
 }
 
 type GeminiChatCandidate struct {
diff --git a/providers/openai/balance.go b/providers/openai/balance.go
index 3620e870..b17e8998 100644
--- a/providers/openai/balance.go
+++ b/providers/openai/balance.go
@@ -3,7 +3,6 @@ package openai
 import (
 	"errors"
 	"fmt"
-	"one-api/common"
 	"one-api/model"
 	"time"
 )
@@ -16,15 +15,14 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
 	fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "")
 	headers := p.GetRequestHeaders()
 
-	client := common.NewClient()
-	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
+	req, err := p.Requester.NewRequest("GET", 
fullRequestURL, p.Requester.WithHeader(headers)) if err != nil { return 0, err } // 发送请求 var subscription OpenAISubscriptionResponse - _, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy) + _, errWithCode := p.Requester.SendRequest(req, &subscription, false) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) } @@ -37,12 +35,15 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) { } fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "") - req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + req, err = p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers)) if err != nil { return 0, err } usage := OpenAIUsageResponse{} - _, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy) + _, errWithCode = p.Requester.SendRequest(req, &usage, false) + if errWithCode != nil { + return 0, errWithCode + } balance := subscription.HardLimitUSD - usage.TotalUsage/100 channel.UpdateBalance(balance) diff --git a/providers/openai/base.go b/providers/openai/base.go index 5536871e..9a295e75 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -1,69 +1,100 @@ package openai import ( - "bufio" - "bytes" "encoding/json" "fmt" - "io" "net/http" "one-api/common" + "one-api/common/requester" + "one-api/model" "one-api/types" "strings" "one-api/providers/base" - - "github.com/gin-gonic/gin" ) type OpenAIProviderFactory struct{} -// 创建 OpenAIProvider -func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface { - openAIProvider := CreateOpenAIProvider(c, "") - openAIProvider.BalanceAction = true - return openAIProvider -} - type OpenAIProvider struct { base.BaseProvider IsAzure bool BalanceAction bool } +// 创建 OpenAIProvider +func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + openAIProvider := CreateOpenAIProvider(channel, "https://api.openai.com") + openAIProvider.BalanceAction = true + return openAIProvider +} + // 创建 OpenAIProvider // https://platform.openai.com/docs/api-reference/introduction -func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { - if baseURL == "" { - baseURL = "https://api.openai.com" - } +func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider { + config := getOpenAIConfig(baseURL) return &OpenAIProvider{ BaseProvider: base.BaseProvider{ - BaseURL: baseURL, - Completions: "/v1/completions", - ChatCompletions: "/v1/chat/completions", - Embeddings: "/v1/embeddings", - Moderation: "/v1/moderations", - AudioSpeech: "/v1/audio/speech", - AudioTranscriptions: "/v1/audio/transcriptions", - AudioTranslations: "/v1/audio/translations", - ImagesGenerations: "/v1/images/generations", - ImagesEdit: "/v1/images/edits", - ImagesVariations: "/v1/images/variations", - Context: c, + Config: config, + Channel: channel, + Requester: requester.NewHTTPRequester(channel.Proxy, RequestErrorHandle), }, IsAzure: false, BalanceAction: true, } } +func getOpenAIConfig(baseURL string) base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: baseURL, + Completions: "/v1/completions", + ChatCompletions: "/v1/chat/completions", + Embeddings: "/v1/embeddings", + Moderation: "/v1/moderations", + AudioSpeech: "/v1/audio/speech", + AudioTranscriptions: "/v1/audio/transcriptions", + AudioTranslations: "/v1/audio/translations", + ImagesGenerations: "/v1/images/generations", + ImagesEdit: 
"/v1/images/edits", + ImagesVariations: "/v1/images/variations", + } +} + +// 请求错误处理 +func RequestErrorHandle(resp *http.Response) *types.OpenAIError { + var errorResponse *types.OpenAIErrorResponse + err := json.NewDecoder(resp.Body).Decode(errorResponse) + if err != nil { + return nil + } + + return ErrorHandle(errorResponse) +} + +// 错误处理 +func ErrorHandle(openaiError *types.OpenAIErrorResponse) *types.OpenAIError { + if openaiError.Error.Message == "" { + return nil + } + return &openaiError.Error +} + // 获取完整请求 URL func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") if p.IsAzure { apiVersion := p.Channel.Other + // 以-分割,检测modelName 最后一个元素是否为4位数字,必须是数字,如果是则删除modelName最后一个元素 + modelNameSlice := strings.Split(modelName, "-") + lastModelNameSlice := modelNameSlice[len(modelNameSlice)-1] + modelNum := common.String2Int(lastModelNameSlice) + if modelNum > 999 && modelNum < 10000 { + modelName = strings.TrimSuffix(modelName, "-"+lastModelNameSlice) + } + // 检测模型是是否包含 . 如果有则直接去掉 + modelName = strings.Replace(modelName, ".", "", -1) + if modelName == "dall-e-2" { // 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本 // 已经没有dall-e-2了,所以暂时写死 @@ -72,10 +103,6 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) } - // 检测模型是是否包含 . 如果有则直接去掉 - if strings.Contains(requestURL, ".") { - requestURL = strings.Replace(requestURL, ".", "", -1) - } } if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { @@ -102,89 +129,21 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) { return headers } -// 获取请求体 -func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) { - if isModelMapped { - jsonStr, err := json.Marshal(request) - if err != nil { - return nil, err - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = p.Context.Request.Body +func (p *OpenAIProvider) GetRequestTextBody(relayMode int, ModelName string, request any) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(relayMode) + if errWithCode != nil { + return nil, errWithCode } - return -} + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, ModelName) -// 发送流式请求 -func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { - defer req.Body.Close() - - client := common.GetHttpClient(p.Channel.Proxy) - resp, err := client.Do(req) + // 获取请求头 + headers := p.GetRequestHeaders() + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers)) if err != nil { - return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" - } - common.PutHttpClient(client) - - if common.IsFailureStatusCode(resp) { - return common.HandleErrorResp(resp), "" + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return 
len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - if data[:6] != "data: " && data[:6] != "[DONE]" { - continue - } - dataChan <- data - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - err := json.Unmarshal([]byte(data), response) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue // just ignore the error - } - responseText += response.responseStreamHandler() - } - } - stopChan <- true - }() - common.SetEventStreamHeaders(p.Context) - p.Context.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - p.Context.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false - } - }) - - return nil, responseText + return req, nil } diff --git a/providers/openai/chat.go b/providers/openai/chat.go index 142ed58f..3186c889 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -1,82 +1,101 @@ package openai import ( + "encoding/json" "net/http" "one-api/common" + "one-api/common/requester" "one-api/types" + "strings" ) -func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Error.Type != "" { +type OpenAIStreamHandler struct { + Usage *types.Usage + ModelName string +} + +func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + response := &OpenAIProviderChatResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode + } + + // 检测是否错误 + openaiErr := ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: c.Error, - StatusCode: resp.StatusCode, + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, } - return - } - return nil, nil -} - -func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) { - for _, choice := range c.Choices { - responseText += choice.Delta.Content + return nil, errWithCode } - return + *p.Usage = *response.Usage + + return &response.ChatCompletionResponse, nil } -func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, err := p.GetRequestBody(&request, isModelMapped) +func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := 
OpenAIStreamHandler{ + Usage: p.Usage, + ModelName: request.Model, + } + + return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream) +} + +func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { + // 如果rawLine 前缀不为data:,则直接返回 + if !strings.HasPrefix(string(*rawLine), "data: ") { + *rawLine = nil + return nil + } + + // 去除前缀 + *rawLine = (*rawLine)[6:] + + // 如果等于 DONE 则结束 + if string(*rawLine) == "[DONE]" { + *isFinished = true + return nil + } + + var openaiResponse OpenAIProviderChatStreamResponse + err := json.Unmarshal(*rawLine, &openaiResponse) if err != nil { - return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + return common.ErrorToOpenAIError(err) } - fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) - headers := p.GetRequestHeaders() - if request.Stream && headers["Accept"] == "" { - headers["Accept"] = "text/event-stream" + error := ErrorHandle(&openaiResponse.OpenAIErrorResponse) + if error != nil { + return error } - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } + countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName) + h.Usage.CompletionTokens += countTokenText + h.Usage.TotalTokens += countTokenText - if request.Stream { - openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{} - var textResponse string - errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse) - if errWithCode != nil { - return - } + *response = append(*response, openaiResponse.ChatCompletionStreamResponse) - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: common.CountTokenText(textResponse, request.Model), - TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model), - } - - } else { - openAIProviderChatResponse := &OpenAIProviderChatResponse{} - errWithCode = p.SendRequest(req, openAIProviderChatResponse, true) - if errWithCode != nil { - return - } - - usage = openAIProviderChatResponse.Usage - - if usage.TotalTokens == 0 { - completionTokens := 0 - for _, choice := range openAIProviderChatResponse.Choices { - completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model) - } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - } - } - return + return nil } diff --git a/providers/openai/completion.go b/providers/openai/completion.go index 7064db0c..81bbd505 100644 --- a/providers/openai/completion.go +++ b/providers/openai/completion.go @@ -1,82 +1,96 @@ package openai import ( + "encoding/json" "net/http" "one-api/common" + "one-api/common/requester" "one-api/types" + "strings" ) -func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Error.Type != "" { +func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request) + if 
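// OpenAI stream chunks carry no usage object, so HandlerChatStream above
// rebuilds usage incrementally: each delta's text is tokenized and added to
// CompletionTokens and TotalTokens, assuming the relay layer seeded Usage with
// the prompt token counts before the stream started. The same accounting in
// isolation (seed values are illustrative):
//
//	usage := &types.Usage{PromptTokens: 12, TotalTokens: 12} // seeded by the caller (assumed)
//	for _, delta := range []string{"Hel", "lo"} {
//		n := common.CountTokenText(delta, "gpt-3.5-turbo")
//		usage.CompletionTokens += n
//		usage.TotalTokens += n
//	}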
errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	response := &OpenAIProviderCompletionResponse{}
+	// Send the request
+	_, errWithCode = p.Requester.SendRequest(req, response, false)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	// Check for an error payload
+	openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
+	if openaiErr != nil {
 		errWithCode = &types.OpenAIErrorWithStatusCode{
-			OpenAIError: c.Error,
-			StatusCode:  resp.StatusCode,
+			OpenAIError: *openaiErr,
+			StatusCode:  http.StatusBadRequest,
 		}
-		return
-	}
-	return nil, nil
-}
-
-func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
-	for _, choice := range c.Choices {
-		responseText += choice.Text
+		return nil, errWithCode
 	}
 
-	return
+	*p.Usage = *response.Usage
+
+	return &response.CompletionResponse, nil
 }
 
-func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	requestBody, err := p.GetRequestBody(&request, isModelMapped)
+func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	// Send the request
+	resp, errWithCode := p.Requester.SendRequestRaw(req)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	chatHandler := OpenAIStreamHandler{
+		Usage:     p.Usage,
+		ModelName: request.Model,
+	}
+
+	return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
+}
+
+func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
+	// Skip lines without the "data: " prefix
+	if !strings.HasPrefix(string(*rawLine), "data: ") {
+		*rawLine = nil
+		return nil
+	}
+
+	// Strip the prefix
+	*rawLine = (*rawLine)[6:]
+
+	// A [DONE] sentinel ends the stream
+	if string(*rawLine) == "[DONE]" {
+		*isFinished = true
+		return nil
+	}
+
+	var openaiResponse OpenAIProviderCompletionResponse
+	err := json.Unmarshal(*rawLine, &openaiResponse)
 	if err != nil {
-		return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+		return common.ErrorToOpenAIError(err)
 	}
 
-	fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
-	headers := p.GetRequestHeaders()
-	if request.Stream && headers["Accept"] == "" {
-		headers["Accept"] = "text/event-stream"
+	error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
+	if error != nil {
+		return error
 	}
 
-	client := common.NewClient()
-	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
-	if err != nil {
-		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
-	}
+	countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+	h.Usage.CompletionTokens += countTokenText
+	h.Usage.TotalTokens += countTokenText
 
-	openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
-	if request.Stream {
-		// TODO
-		var textResponse string
-		errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = &types.Usage{
-			PromptTokens: 
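// GetRequestTextBody resolves the relay mode through GetSupportedAPIUri, which
// is not shown in this diff. Based on the ProviderConfig fields above, its
// shape is presumably a lookup that fails for endpoints a provider does not
// configure; a hypothetical sketch, not the repository's actual code:
//
//	func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (string, *types.OpenAIErrorWithStatusCode) {
//		uri := ""
//		switch relayMode {
//		case common.RelayModeChatCompletions:
//			uri = p.Config.ChatCompletions
//		case common.RelayModeCompletions:
//			uri = p.Config.Completions
//		}
//		if uri == "" {
//			return "", common.StringErrorWrapper("not_implemented", "interface_not_supported", http.StatusNotImplemented)
//		}
//		return uri, nil
//	}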
promptTokens, - CompletionTokens: common.CountTokenText(textResponse, request.Model), - TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model), - } - - } else { - errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true) - if errWithCode != nil { - return - } - - usage = openAIProviderCompletionResponse.Usage - - if usage.TotalTokens == 0 { - completionTokens := 0 - for _, choice := range openAIProviderCompletionResponse.Choices { - completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model) - } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - } - } - return + return nil } diff --git a/providers/openai/embeddings.go b/providers/openai/embeddings.go index 4c484ea8..aa7b48a0 100644 --- a/providers/openai/embeddings.go +++ b/providers/openai/embeddings.go @@ -6,40 +6,30 @@ import ( "one-api/types" ) -func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Error.Type != "" { - errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: c.Error, - StatusCode: resp.StatusCode, - } - return - } - return nil, nil -} - -func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody, err := p.GetRequestBody(&request, isModelMapped) - if err != nil { - return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - - fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{} - errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true) +func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request) if errWithCode != nil { - return + return nil, errWithCode + } + defer req.Body.Close() + + response := &OpenAIProviderEmbeddingsResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode } - usage = openAIProviderEmbeddingsResponse.Usage + openaiErr := ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode + } - return + *p.Usage = *response.Usage + + return &response.EmbeddingResponse, nil } diff --git a/providers/openai/image_edits.go b/providers/openai/image_edits.go index 60a3fa7c..376459bc 100644 --- a/providers/openai/image_edits.go +++ b/providers/openai/image_edits.go @@ -5,28 +5,71 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/requester" "one-api/types" ) -func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - 
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model) +func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + response := &OpenAIProviderImageResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode + } + + openaiErr := ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode + } + + p.Usage.TotalTokens = p.Usage.PromptTokens + + return &response.ImageResponse, nil +} + +func (p *OpenAIProvider) getRequestImageBody(relayMode int, ModelName string, request *types.ImageEditRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(relayMode) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, ModelName) + + // 获取请求头 headers := p.GetRequestHeaders() - - client := common.NewClient() - - var formBody bytes.Buffer + // 创建请求 var req *http.Request var err error - if isModelMapped { - builder := client.CreateFormBuilder(&formBody) + if p.OriginalModel != request.Model { + var formBody bytes.Buffer + builder := p.Requester.CreateFormBuilder(&formBody) if err := imagesEditsMultipartForm(request, builder); err != nil { return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) } - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) + req, err = p.Requester.NewRequest( + http.MethodPost, + fullRequestURL, + p.Requester.WithBody(&formBody), + p.Requester.WithHeader(headers), + p.Requester.WithContentType(builder.FormDataContentType())) req.ContentLength = int64(formBody.Len()) - } else { - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) + req, err = p.Requester.NewRequest( + http.MethodPost, + fullRequestURL, + p.Requester.WithBody(p.Context.Request.Body), + p.Requester.WithHeader(headers), + p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type"))) req.ContentLength = p.Context.Request.ContentLength } @@ -34,22 +77,10 @@ func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isMod return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{} - errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) - if errWithCode != nil { - return - } - - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, - } - - return + return req, nil } -func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error { +func imagesEditsMultipartForm(request *types.ImageEditRequest, b requester.FormBuilder) error { err := b.CreateFormFile("image", request.Image) if err != nil { return fmt.Errorf("creating form image: %w", err) diff --git 
a/providers/openai/image_generations.go b/providers/openai/image_generations.go index f0ba865b..53d4796e 100644 --- a/providers/openai/image_generations.go +++ b/providers/openai/image_generations.go @@ -6,53 +6,41 @@ import ( "one-api/types" ) -func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Error.Type != "" { - errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: c.Error, - StatusCode: resp.StatusCode, - } - return - } - return nil, nil -} - -func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - if !isWithinRange(request.Model, request.N) { +func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { + if !IsWithinRange(request.Model, request.N) { return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) } - requestBody, err := p.GetRequestBody(&request, isModelMapped) - if err != nil { - return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - - fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{} - errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) + req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request) if errWithCode != nil { - return + return nil, errWithCode + } + defer req.Body.Close() + + response := &OpenAIProviderImageResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, + // 检测是否错误 + openaiErr := ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode } - return + p.Usage.TotalTokens = p.Usage.PromptTokens + + return &response.ImageResponse, nil + } -func isWithinRange(element string, value int) bool { +func IsWithinRange(element string, value int) bool { if _, ok := common.DalleGenerationImageAmounts[element]; !ok { return false } diff --git a/providers/openai/image_variations.go b/providers/openai/image_variations.go index f7d65c31..4160d28d 100644 --- a/providers/openai/image_variations.go +++ b/providers/openai/image_variations.go @@ -1,49 +1,35 @@ package openai import ( - "bytes" "net/http" "one-api/common" "one-api/types" ) -func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - - var formBody bytes.Buffer - var req *http.Request - var err error - if isModelMapped { - 
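// getRequestImageBody above (and getRequestAudioBody later in this diff)
// chooses between two request bodies: when the relay layer remapped the model
// (p.OriginalModel != request.Model), the multipart form is rebuilt so the new
// model name lands in the form fields; otherwise the client's original body
// and Content-Type are forwarded untouched, which avoids buffering uploads.
// The rebuild path, condensed from the code above:
//
//	var formBody bytes.Buffer
//	builder := p.Requester.CreateFormBuilder(&formBody)
//	if err := imagesEditsMultipartForm(request, builder); err != nil {
//		return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
//	}
//	req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL,
//		p.Requester.WithBody(&formBody),
//		p.Requester.WithHeader(headers),
//		p.Requester.WithContentType(builder.FormDataContentType()))
//	req.ContentLength = int64(formBody.Len())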
builder := client.CreateFormBuilder(&formBody) - if err := imagesEditsMultipartForm(request, builder); err != nil { - return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) - } - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) - req.ContentLength = int64(formBody.Len()) - - } else { - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) - req.ContentLength = p.Context.Request.ContentLength - } - - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{} - errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) +func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request) if errWithCode != nil { - return + return nil, errWithCode + } + defer req.Body.Close() + + response := &OpenAIProviderImageResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, + openaiErr := ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode } - return + p.Usage.TotalTokens = p.Usage.PromptTokens + + return &response.ImageResponse, nil } diff --git a/providers/openai/moderation.go b/providers/openai/moderation.go index 35739082..0cbd1bad 100644 --- a/providers/openai/moderation.go +++ b/providers/openai/moderation.go @@ -6,44 +6,31 @@ import ( "one-api/types" ) -func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Error.Type != "" { - errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: c.Error, - StatusCode: resp.StatusCode, - } - return - } - return nil, nil -} +func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) { -func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody, err := p.GetRequestBody(&request, isModelMapped) - if err != nil { - return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - - fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - openAIProviderModerationResponse := &OpenAIProviderModerationResponse{} - errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true) + req, errWithCode 
:= p.GetRequestTextBody(common.RelayModeModerations, request.Model, request) if errWithCode != nil { - return + return nil, errWithCode + } + defer req.Body.Close() + + response := &OpenAIProviderModerationResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, + openaiErr := ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode } - return + p.Usage.TotalTokens = p.Usage.PromptTokens + + return &response.ModerationResponse, nil } diff --git a/providers/openai/speech.go b/providers/openai/speech.go index b4c3eb07..1df76cad 100644 --- a/providers/openai/speech.go +++ b/providers/openai/speech.go @@ -3,35 +3,29 @@ package openai import ( "net/http" "one-api/common" + "one-api/common/requester" "one-api/types" ) -func (p *OpenAIProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - - requestBody, err := p.GetRequestBody(&request, isModelMapped) - if err != nil { - return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - - fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - errWithCode = p.SendRequestRaw(req) +func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request) if errWithCode != nil { - return + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + var resp *http.Response + resp, errWithCode = p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode } - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: 0, - TotalTokens: promptTokens, + if resp.Header.Get("Content-Type") == "application/json" { + return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler) } - return + p.Usage.TotalTokens = p.Usage.PromptTokens + + return resp, nil } diff --git a/providers/openai/transcriptions.go b/providers/openai/transcriptions.go index e55fa903..8be9aeb4 100644 --- a/providers/openai/transcriptions.go +++ b/providers/openai/transcriptions.go @@ -4,48 +4,99 @@ import ( "bufio" "bytes" "fmt" + "io" "net/http" "one-api/common" + "one-api/common/requester" "one-api/types" "regexp" "strconv" "strings" ) -func (c *OpenAIProviderTranscriptionsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - if c.Error.Type != "" { - errWithCode = &types.OpenAIErrorWithStatusCode{ - OpenAIError: c.Error, - StatusCode: resp.StatusCode, - } - return +func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request) + if errWithCode != 
nil { + return nil, errWithCode } - return nil, nil + defer req.Body.Close() + + var textResponse string + var resp *http.Response + var err error + audioResponseWrapper := &types.AudioResponseWrapper{} + if hasJSONResponse(request) { + openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{} + resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true) + if errWithCode != nil { + return nil, errWithCode + } + textResponse = openAIProviderTranscriptionsResponse.Text + } else { + openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse) + resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true) + if errWithCode != nil { + return nil, errWithCode + } + textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat) + } + + defer resp.Body.Close() + + audioResponseWrapper.Headers = map[string]string{ + "Content-Type": resp.Header.Get("Content-Type"), + } + + audioResponseWrapper.Body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + + completionTokens := common.CountTokenText(textResponse, request.Model) + + p.Usage.CompletionTokens = completionTokens + p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens + + return audioResponseWrapper, nil } -func (c *OpenAIProviderTranscriptionsTextResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { - return nil, nil +func hasJSONResponse(request *types.AudioRequest) bool { + return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json" } -func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - fullRequestURL := p.GetFullRequestURL(p.AudioTranscriptions, request.Model) +func (p *OpenAIProvider) getRequestAudioBody(relayMode int, ModelName string, request *types.AudioRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(relayMode) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, ModelName) + + // 获取请求头 headers := p.GetRequestHeaders() - - client := common.NewClient() - - var formBody bytes.Buffer + // 创建请求 var req *http.Request var err error - if isModelMapped { - builder := client.CreateFormBuilder(&formBody) + if p.OriginalModel != request.Model { + var formBody bytes.Buffer + builder := p.Requester.CreateFormBuilder(&formBody) if err := audioMultipartForm(request, builder); err != nil { return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) } - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) + req, err = p.Requester.NewRequest( + http.MethodPost, + fullRequestURL, + p.Requester.WithBody(&formBody), + p.Requester.WithHeader(headers), + p.Requester.WithContentType(builder.FormDataContentType())) req.ContentLength = int64(formBody.Len()) - } else { - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), 
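// CreateTranscriptions branches on the requested response format:
// "" (which defaults to json), "json", and "verbose_json" are decoded into
// OpenAIProviderTranscriptionsResponse, while plain-text formats (such as srt
// or vtt) come back as a raw body that getTextContent reduces to bare text for
// token counting. The dispatch hinges on hasJSONResponse:
//
//	hasJSONResponse(&types.AudioRequest{ResponseFormat: ""})             // true
//	hasJSONResponse(&types.AudioRequest{ResponseFormat: "verbose_json"}) // true
//	hasJSONResponse(&types.AudioRequest{ResponseFormat: "srt"})          // false -> text path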
common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) + req, err = p.Requester.NewRequest( + http.MethodPost, + fullRequestURL, + p.Requester.WithBody(p.Context.Request.Body), + p.Requester.WithHeader(headers), + p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type"))) req.ContentLength = p.Context.Request.ContentLength } @@ -53,37 +104,10 @@ func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isMod return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - var textResponse string - if hasJSONResponse(request) { - openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{} - errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true) - if errWithCode != nil { - return - } - textResponse = openAIProviderTranscriptionsResponse.Text - } else { - openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse) - errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true) - if errWithCode != nil { - return - } - textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat) - } - - completionTokens := common.CountTokenText(textResponse, request.Model) - usage = &types.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - return + return req, nil } -func hasJSONResponse(request *types.AudioRequest) bool { - return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json" -} - -func audioMultipartForm(request *types.AudioRequest, b common.FormBuilder) error { +func audioMultipartForm(request *types.AudioRequest, b requester.FormBuilder) error { err := b.CreateFormFile("file", request.File) if err != nil { diff --git a/providers/openai/translations.go b/providers/openai/translations.go index 3114be0a..1bc4b581 100644 --- a/providers/openai/translations.go +++ b/providers/openai/translations.go @@ -1,60 +1,53 @@ package openai import ( - "bytes" + "io" "net/http" "one-api/common" "one-api/types" ) -func (p *OpenAIProvider) TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - fullRequestURL := p.GetFullRequestURL(p.AudioTranslations, request.Model) - headers := p.GetRequestHeaders() - - client := common.NewClient() - - var formBody bytes.Buffer - var req *http.Request - var err error - if isModelMapped { - builder := client.CreateFormBuilder(&formBody) - if err := audioMultipartForm(request, builder); err != nil { - return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) - } - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) - req.ContentLength = int64(formBody.Len()) - - } else { - req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) - req.ContentLength = p.Context.Request.ContentLength - } - - if err != nil { - return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) +func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, 
+	req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request)
+	if errWithCode != nil {
+		return nil, errWithCode
 	}
+	defer req.Body.Close()
 
 	var textResponse string
+	var resp *http.Response
+	var err error
+	audioResponseWrapper := &types.AudioResponseWrapper{}
 	if hasJSONResponse(request) {
 		openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
-		errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
+		resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
 		if errWithCode != nil {
-			return
+			return nil, errWithCode
 		}
 		textResponse = openAIProviderTranscriptionsResponse.Text
 	} else {
 		openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
-		errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
+		resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
 		if errWithCode != nil {
-			return
+			return nil, errWithCode
 		}
 		textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
 	}
+	defer resp.Body.Close()
+
+	audioResponseWrapper.Headers = map[string]string{
+		"Content-Type": resp.Header.Get("Content-Type"),
+	}
+
+	audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
 
 	completionTokens := common.CountTokenText(textResponse, request.Model)
-	usage = &types.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
-	}
-	return
+
+	p.Usage.CompletionTokens = completionTokens
+	p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
+
+	return audioResponseWrapper, nil
 }
diff --git a/providers/openai/type.go b/providers/openai/type.go
index b17e513b..0670dca9 100644
--- a/providers/openai/type.go
+++ b/providers/openai/type.go
@@ -12,11 +12,27 @@ type OpenAIProviderChatStreamResponse struct {
 	types.OpenAIErrorResponse
 }
 
+func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) {
+	for _, choice := range c.Choices {
+		responseText += choice.Delta.Content
+	}
+
+	return
+}
+
 type OpenAIProviderCompletionResponse struct {
 	types.CompletionResponse
 	types.OpenAIErrorResponse
 }
 
+func (c *OpenAIProviderCompletionResponse) getResponseText() (responseText string) {
+	for _, choice := range c.Choices {
+		responseText += choice.Text
+	}
+
+	return
+}
+
 type OpenAIProviderEmbeddingsResponse struct {
 	types.EmbeddingResponse
 	types.OpenAIErrorResponse
@@ -38,7 +54,7 @@ func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string {
 	return (*string)(a)
 }
 
-type OpenAIProviderImageResponseResponse struct {
+type OpenAIProviderImageResponse struct {
 	types.ImageResponse
 	types.OpenAIErrorResponse
 }
diff --git a/providers/openaisb/balance.go b/providers/openaisb/balance.go
index d67b03f4..3873a3b3 100644
--- a/providers/openaisb/balance.go
+++ b/providers/openaisb/balance.go
@@ -3,7 +3,6 @@ package openaisb
 import (
 	"errors"
 	"fmt"
-	"one-api/common"
 	"one-api/model"
 	"strconv"
 )
@@ -13,15 +12,14 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
 	fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
 	headers := p.GetRequestHeaders()
 
-	client := common.NewClient()
-	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
+	req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
 	if err != nil {
 		return 0, err
 	}
 
 	// Send the request
 	var response OpenAISBUsageResponse
-	_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
-	if err != nil {
+	_, errWithCode := p.Requester.SendRequest(req, &response, false)
+	// Check errWithCode here, not err: SendRequest reports failures via errWithCode
+	if errWithCode != nil {
 		return 0, errors.New(errWithCode.OpenAIError.Message)
 	}
diff --git a/providers/openaisb/base.go b/providers/openaisb/base.go
index c770d3e4..6e6841b1 100644
--- a/providers/openaisb/base.go
+++ b/providers/openaisb/base.go
@@ -1,18 +1,17 @@
 package openaisb
 
 import (
+	"one-api/model"
 	"one-api/providers/base"
 	"one-api/providers/openai"
-
-	"github.com/gin-gonic/gin"
 )
 
 type OpenaiSBProviderFactory struct{}
 
 // Create an OpenaiSBProvider
-func (f OpenaiSBProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f OpenaiSBProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &OpenaiSBProvider{
-		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
+		OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.openai-sb.com"),
 	}
 }
diff --git a/providers/palm/base.go b/providers/palm/base.go
index f500b418..a314e655 100644
--- a/providers/palm/base.go
+++ b/providers/palm/base.go
@@ -1,22 +1,25 @@
 package palm
 
 import (
+	"encoding/json"
 	"fmt"
+	"net/http"
+	"one-api/common/requester"
+	"one-api/model"
 	"one-api/providers/base"
+	"one-api/types"
 	"strings"
-
-	"github.com/gin-gonic/gin"
 )
 
 type PalmProviderFactory struct{}
 
 // Create a PalmProvider
-func (f PalmProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f PalmProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &PalmProvider{
 		BaseProvider: base.BaseProvider{
-			BaseURL:         "https://generativelanguage.googleapis.com",
-			ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
-			Context:         c,
+			Config:    getConfig(),
+			Channel:   channel,
+			Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
 		},
 	}
 }
@@ -25,6 +28,37 @@ type PalmProvider struct {
 	base.BaseProvider
 }
 
+func getConfig() base.ProviderConfig {
+	return base.ProviderConfig{
+		BaseURL:         "https://generativelanguage.googleapis.com",
+		ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
+	}
+}
+
+// Request error handling
+func requestErrorHandle(resp *http.Response) *types.OpenAIError {
+	// Decode into an initialized struct; decoding into a nil pointer always fails
+	palmError := &PaLMErrorResponse{}
+	err := json.NewDecoder(resp.Body).Decode(palmError)
+	if err != nil {
+		return nil
+	}
+
+	return errorHandle(palmError)
+}
+
+// Error handling
+func errorHandle(palmError *PaLMErrorResponse) *types.OpenAIError {
+	if palmError.Error.Code == 0 {
+		return nil
+	}
+	return &types.OpenAIError{
+		Message: palmError.Error.Message,
+		Type:    "palm_error",
+		Param:   palmError.Error.Status,
+		Code:    palmError.Error.Code,
+	}
+}
+
 // Get request headers
 func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
 	headers = make(map[string]string)
diff --git a/providers/palm/chat.go b/providers/palm/chat.go
index 81dd1777..f7fd4fdd 100644
--- a/providers/palm/chat.go
+++ b/providers/palm/chat.go
@@ -3,30 +3,98 @@ package palm
 
 import (
 	"encoding/json"
 	"fmt"
-	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/providers/base"
+	"one-api/common/requester"
 	"one-api/types"
+	"strings"
 )
 
-func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
-	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
-		return nil, &types.OpenAIErrorWithStatusCode{
-			OpenAIError: types.OpenAIError{
-				Message: palmResponse.Error.Message,
-				Type:    palmResponse.Error.Status,
-				Param:   "",
-				Code:    palmResponse.Error.Code,
-			},
-			StatusCode: resp.StatusCode,
-		}
+type palmStreamHandler struct {
+	Usage   *types.Usage
+	Request *types.ChatCompletionRequest
+}
+
+func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	palmResponse := &PaLMChatResponse{}
+	// Send the request
+	_, errWithCode = p.Requester.SendRequest(req, palmResponse, false)
+	if errWithCode != nil {
+		return nil, errWithCode
 	}
 
-	fullTextResponse := types.ChatCompletionResponse{
-		Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)),
+	return p.convertToChatOpenai(palmResponse, request)
+}
+
+func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
 	}
-	for i, candidate := range palmResponse.Candidates {
+	defer req.Body.Close()
+
+	// Send the request
+	resp, errWithCode := p.Requester.SendRequestRaw(req)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	chatHandler := &palmStreamHandler{
+		Usage:   p.Usage,
+		Request: request,
+	}
+
+	return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
+}
+
+func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
+	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	// Build the request URL
+	fullRequestURL := p.GetFullRequestURL(url, request.Model)
+	if fullRequestURL == "" {
+		return nil, common.ErrorWrapper(nil, "invalid_palm_config", http.StatusInternalServerError)
+	}
+
+	// Build the request headers
+	headers := p.GetRequestHeaders()
+	if request.Stream {
+		headers["Accept"] = "text/event-stream"
+	}
+
+	palmRequest := convertFromChatOpenai(request)
+	// Create the request
+	req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(palmRequest), p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+
+	return req, nil
+}
+
+func (p *PalmProvider) convertToChatOpenai(response *PaLMChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
+	error := errorHandle(&response.PaLMErrorResponse)
+	if error != nil {
+		errWithCode = &types.OpenAIErrorWithStatusCode{
+			OpenAIError: *error,
+			StatusCode:  http.StatusBadRequest,
+		}
+		return
+	}
+
+	openaiResponse = &types.ChatCompletionResponse{
+		Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
+		Model:   request.Model,
+	}
+	for i, candidate := range response.Candidates {
 		choice := types.ChatCompletionChoice{
 			Index: i,
 			Message: types.ChatCompletionMessage{
@@ -35,20 +103,21 @@ func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (Open
 			},
 			FinishReason: "stop",
 		}
-		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+		openaiResponse.Choices = append(openaiResponse.Choices, choice)
 	}
-	completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
-	palmResponse.Usage.CompletionTokens = completionTokens
-	palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
+	// Guard against an empty candidate list before indexing into it
+	if len(response.Candidates) > 0 {
+		completionTokens := common.CountTokenText(response.Candidates[0].Content, request.Model)
+		response.Usage.CompletionTokens = completionTokens
+		response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
+	}
 
-	fullTextResponse.Usage = palmResponse.Usage
-	fullTextResponse.Model = palmResponse.Model
+	openaiResponse.Usage = response.Usage
 
-	return fullTextResponse, nil
+	*p.Usage = *response.Usage
+
+	return
 }
 
-func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
+func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatRequest {
 	palmRequest := PaLMChatRequest{
 		Prompt: PaLMPrompt{
 			Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
@@ -72,132 +141,51 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest)
 	return &palmRequest
 }
 
-func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	requestBody := p.getChatRequestBody(request)
-	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
-	headers := p.GetRequestHeaders()
-	if request.Stream {
-		headers["Accept"] = "text/event-stream"
+// Convert to the OpenAI chat streaming response format
+func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
+	// If the line does not start with "data: ", skip it
+	if !strings.HasPrefix(string(*rawLine), "data: ") {
+		*rawLine = nil
+		return nil
 	}
 
-	client := common.NewClient()
-	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
+	// Strip the prefix
+	*rawLine = (*rawLine)[6:]
+
+	var palmChatResponse PaLMChatResponse
+	err := json.Unmarshal(*rawLine, &palmChatResponse)
 	if err != nil {
-		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return common.ErrorToOpenAIError(err)
 	}
 
-	if request.Stream {
-		var responseText string
-		errWithCode, responseText = p.sendStreamRequest(req)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = &types.Usage{
-			PromptTokens:     promptTokens,
-			CompletionTokens: common.CountTokenText(responseText, request.Model),
-		}
-		usage.TotalTokens = promptTokens + usage.CompletionTokens
-
-	} else {
-		var palmChatResponse = &PaLMChatResponse{
-			Model: request.Model,
-			Usage: &types.Usage{
-				PromptTokens: promptTokens,
-			},
-		}
-		errWithCode = p.SendRequest(req, palmChatResponse, false)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = palmChatResponse.Usage
+	error := errorHandle(&palmChatResponse.PaLMErrorResponse)
+	if error != nil {
+		return error
 	}
-	return
+
+	return h.convertToOpenaiStream(&palmChatResponse, response)
 }
 
-func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
+func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error {
 	var choice types.ChatCompletionStreamChoice
-	if len(palmResponse.Candidates) > 0 {
-		choice.Delta.Content = palmResponse.Candidates[0].Content
+	if len(palmChatResponse.Candidates) > 0 {
+		choice.Delta.Content = palmChatResponse.Candidates[0].Content
 	}
-	choice.FinishReason = &base.StopFinishReason
-	var response types.ChatCompletionStreamResponse
-	response.Object = "chat.completion.chunk"
-	response.Model = "palm2"
-	response.Choices = []types.ChatCompletionStreamChoice{choice}
-	return &response
-}
-
-func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
-	defer req.Body.Close()
-
-	// Send the request
-	client := common.GetHttpClient(p.Channel.Proxy)
-	resp, err := client.Do(req)
-	if err != nil {
-		return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
-	}
-	common.PutHttpClient(client)
-
-	if common.IsFailureStatusCode(resp) {
-		return common.HandleErrorResp(resp), ""
-	}
-
-	defer resp.Body.Close()
-
-	responseText := ""
-	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-	createdTime := common.GetTimestamp()
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		responseBody, err := io.ReadAll(resp.Body)
-		if err != nil {
-			common.SysError("error reading stream response: " + err.Error())
-			stopChan <- true
-			return
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			common.SysError("error closing stream response: " + err.Error())
-			stopChan <- true
-			return
-		}
-		var palmResponse PaLMChatResponse
-		err = json.Unmarshal(responseBody, &palmResponse)
-		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
-			stopChan <- true
-			return
-		}
-		fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse)
-		fullTextResponse.ID = responseId
-		fullTextResponse.Created = createdTime
-		if len(palmResponse.Candidates) > 0 {
-			responseText = palmResponse.Candidates[0].Content
-		}
-		jsonResponse, err := json.Marshal(fullTextResponse)
-		if err != nil {
-			common.SysError("error marshalling stream response: " + err.Error())
-			stopChan <- true
-			return
-		}
-		dataChan <- string(jsonResponse)
-		stopChan <- true
-	}()
-	common.SetEventStreamHeaders(p.Context)
-	p.Context.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			p.Context.Render(-1, common.CustomEvent{Data: "data: " + data})
-			return true
-		case <-stopChan:
-			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
-
-	return nil, responseText
+	choice.FinishReason = types.FinishReasonStop
+
+	streamResponse := types.ChatCompletionStreamResponse{
+		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion.chunk",
+		Model:   h.Request.Model,
+		Choices: []types.ChatCompletionStreamChoice{choice},
+		Created: common.GetTimestamp(),
+	}
+
+	*response = append(*response, streamResponse)
+
+	// choice.Delta.Content is already guarded against empty Candidates above
+	h.Usage.CompletionTokens += common.CountTokenText(choice.Delta.Content, h.Request.Model)
+	h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
+
+	return nil
 }
diff --git a/providers/palm/type.go b/providers/palm/type.go
index 76eadded..ead9a768 100644
--- a/providers/palm/type.go
+++ b/providers/palm/type.go
@@ -30,11 +30,15 @@ type PaLMError struct {
 	Status  string `json:"status"`
 }
 
+type PaLMErrorResponse struct {
+	Error PaLMError `json:"error,omitempty"`
+}
+
 type PaLMChatResponse struct {
 	Candidates []PaLMChatMessage             `json:"candidates"`
 	Messages   []types.ChatCompletionMessage `json:"messages"`
 	Filters    []PaLMFilter                  `json:"filters"`
-	Error      PaLMError                     `json:"error"`
 	Usage      *types.Usage                  `json:"usage,omitempty"`
 	Model      string                        `json:"model,omitempty"`
+	PaLMErrorResponse
 }
diff --git a/providers/providers.go b/providers/providers.go
index c5de8094..02122e42 100644
--- a/providers/providers.go
+++ b/providers/providers.go
@@ -28,7 +28,7 @@ import (
 
 // Provider factory interface
 type ProviderFactory interface {
-	Create(c *gin.Context) base.ProviderInterface
+	Create(channel *model.Channel) base.ProviderInterface
 }
 
 // Global provider factory registry
@@ -71,11 +71,11 @@ func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface
 			return nil
 		}
 
-		provider = openai.CreateOpenAIProvider(c, baseURL)
+		provider = openai.CreateOpenAIProvider(channel, baseURL)
 	} else {
-		provider = factory.Create(c)
+		provider = factory.Create(channel)
 	}
 
-	provider.SetChannel(channel)
+	provider.SetContext(c)
 
 	return provider
 }
diff --git a/providers/tencent/base.go b/providers/tencent/base.go
index 7c259f9f..135d8ecf 100644
--- a/providers/tencent/base.go
+++ b/providers/tencent/base.go
@@ -4,25 +4,28 @@ import (
 	"crypto/hmac"
 	"crypto/sha1"
 	"encoding/base64"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"net/http"
+	"one-api/common/requester"
+	"one-api/model"
 	"one-api/providers/base"
+	"one-api/types"
 	"sort"
 	"strconv"
 	"strings"
-
-	"github.com/gin-gonic/gin"
 )
 
 type TencentProviderFactory struct{}
 
 // Create a TencentProvider
-func (f TencentProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f TencentProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &TencentProvider{
 		BaseProvider: base.BaseProvider{
-			BaseURL:         "https://hunyuan.cloud.tencent.com",
-			ChatCompletions: "/hyllm/v1/chat/completions",
-			Context:         c,
+			Config:    getConfig(),
+			Channel:   channel,
+			Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
 		},
 	}
 }
@@ -31,6 +34,36 @@ type TencentProvider struct {
 	base.BaseProvider
 }
 
+func getConfig() base.ProviderConfig {
+	return base.ProviderConfig{
+		BaseURL:         "https://hunyuan.cloud.tencent.com",
+		ChatCompletions: "/hyllm/v1/chat/completions",
+	}
+}
+
+// Request error handling
+func requestErrorHandle(resp *http.Response) *types.OpenAIError {
+	// Decode into an initialized struct; decoding into a nil pointer always fails
+	tencentError := &TencentResponseError{}
+	err := json.NewDecoder(resp.Body).Decode(tencentError)
+	if err != nil {
+		return nil
+	}
+
+	return errorHandle(tencentError)
+}
+
+// Error handling
+func errorHandle(tencentError *TencentResponseError) *types.OpenAIError {
+	if tencentError.Error.Code == 0 {
+		return nil
+	}
+	return &types.OpenAIError{
+		Message: tencentError.Error.Message,
+		Type:    "tencent_error",
+		Code:    tencentError.Error.Code,
+	}
+}
+
 // Get request headers
 func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
 	headers = make(map[string]string)
@@ -77,7 +110,7 @@ func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
 	messageStr = strings.TrimSuffix(messageStr, ",")
 	params = append(params, "messages=["+messageStr+"]")
 
-	sort.Sort(sort.StringSlice(params))
+	sort.Strings(params)
 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
 	mac := hmac.New(sha1.New, []byte(secretKey))
 	signURL := url
diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go
index 1965c549..32e2e673 100644
--- a/providers/tencent/chat.go
+++ b/providers/tencent/chat.go
@@ -1,50 +1,127 @@
 package tencent
 
 import (
-	"bufio"
 	"encoding/json"
 	"errors"
-	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/providers/base"
+	"one-api/common/requester"
 	"one-api/types"
 	"strings"
 )
 
-func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
-	if TencentResponse.Error.Code != 0 {
-		return &types.OpenAIErrorWithStatusCode{
-			OpenAIError: types.OpenAIError{
-				Message: TencentResponse.Error.Message,
-				Code:    TencentResponse.Error.Code,
-			},
-			StatusCode: resp.StatusCode,
-		}, nil
+type tencentStreamHandler struct {
+	Usage   *types.Usage
+	Request *types.ChatCompletionRequest
+}
+
+func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	tencentChatResponse := &TencentChatResponse{}
+	// Send the request
+	_, errWithCode = p.Requester.SendRequest(req, tencentChatResponse, false)
+	if errWithCode != nil {
+		return nil, errWithCode
 	}
-	fullTextResponse := types.ChatCompletionResponse{
+	return p.convertToChatOpenai(tencentChatResponse, request)
+}
+
+func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	// Send the request
+	resp, errWithCode := p.Requester.SendRequestRaw(req)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	chatHandler := &tencentStreamHandler{
+		Usage:   p.Usage,
+		Request: request,
+	}
+
+	return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
+}
+
+func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
+	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	tencentRequest := convertFromChatOpenai(request)
+
+	sign := p.getTencentSign(*tencentRequest)
+	if sign == "" {
+		return nil, common.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
+	}
+
+	// Build the request URL
+	fullRequestURL := p.GetFullRequestURL(url, request.Model)
+	if fullRequestURL == "" {
+		return nil, common.ErrorWrapper(nil, "invalid_tencent_config", http.StatusInternalServerError)
+	}
+
+	// Build the request headers
+	headers := p.GetRequestHeaders()
+	headers["Authorization"] = sign
+	if request.Stream {
+		headers["Accept"] = "text/event-stream"
+	}
+
+	// Create the request
+	req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(tencentRequest), p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+
+	return req, nil
+}
+
+func (p *TencentProvider) convertToChatOpenai(response *TencentChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
+	error := errorHandle(&response.TencentResponseError)
+	if error != nil {
+		errWithCode = &types.OpenAIErrorWithStatusCode{
+			OpenAIError: *error,
+			StatusCode:  http.StatusBadRequest,
+		}
+		return
+	}
+
+	openaiResponse = &types.ChatCompletionResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Usage:   TencentResponse.Usage,
-		Model:   TencentResponse.Model,
+		Usage:   response.Usage,
+		Model:   request.Model,
 	}
-	if len(TencentResponse.Choices) > 0 {
+	if len(response.Choices) > 0 {
 		choice := types.ChatCompletionChoice{
 			Index: 0,
 			Message: types.ChatCompletionMessage{
 				Role:    "assistant",
-				Content: TencentResponse.Choices[0].Messages.Content,
+				Content: response.Choices[0].Messages.Content,
 			},
-			FinishReason: TencentResponse.Choices[0].FinishReason,
+			FinishReason: response.Choices[0].FinishReason,
 		}
-		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+		openaiResponse.Choices = append(openaiResponse.Choices, choice)
 	}
-	return fullTextResponse, nil
+
+	*p.Usage = *response.Usage
+
+	return
 }
 
-func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest {
+func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatRequest {
 	messages := make([]TencentMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
@@ -79,143 +156,51 @@
 		}
 	}
 
-func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	requestBody := p.getChatRequestBody(request)
-	sign := p.getTencentSign(*requestBody)
-	if sign == "" {
-		return nil, common.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
+// Convert to the OpenAI chat streaming response format
+func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
+	// If the line does not start with "data:", skip it
+	if !strings.HasPrefix(string(*rawLine), "data:") {
+		*rawLine = nil
+		return nil
 	}
-	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
-	headers := p.GetRequestHeaders()
-	headers["Authorization"] = sign
-	if request.Stream {
-		headers["Accept"] = "text/event-stream"
-	}
+	// Strip the prefix
+	*rawLine = (*rawLine)[5:]
 
-	client := common.NewClient()
-	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
+	var tencentChatResponse TencentChatResponse
+	err := json.Unmarshal(*rawLine, &tencentChatResponse)
 	if err != nil {
-		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return common.ErrorToOpenAIError(err)
 	}
 
-	if request.Stream {
-		var responseText string
-		errWithCode, responseText = p.sendStreamRequest(req, request.Model)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = &types.Usage{
-			PromptTokens:     promptTokens,
-			CompletionTokens: common.CountTokenText(responseText, request.Model),
-		}
-		usage.TotalTokens = promptTokens + usage.CompletionTokens
-
-	} else {
-		tencentResponse := &TencentChatResponse{
-			Model: request.Model,
-		}
-		errWithCode = p.SendRequest(req, tencentResponse, false)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = tencentResponse.Usage
+	error := errorHandle(&tencentChatResponse.TencentResponseError)
+	if error != nil {
+		return error
 	}
-	return
+
+	return h.convertToOpenaiStream(&tencentChatResponse, response)
 }
 
-func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
-	response := types.ChatCompletionStreamResponse{
+func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error {
+	streamResponse := types.ChatCompletionStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   TencentResponse.Model,
+		Model:   h.Request.Model,
 	}
-	if len(TencentResponse.Choices) > 0 {
+	if len(tencentChatResponse.Choices) > 0 {
 		var choice types.ChatCompletionStreamChoice
-		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
-		if TencentResponse.Choices[0].FinishReason == "stop" {
-			choice.FinishReason = &base.StopFinishReason
+		choice.Delta.Content = tencentChatResponse.Choices[0].Delta.Content
+		if tencentChatResponse.Choices[0].FinishReason == "stop" {
+			choice.FinishReason = types.FinishReasonStop
 		}
-		response.Choices = append(response.Choices, choice)
+		streamResponse.Choices = append(streamResponse.Choices, choice)
 	}
-	return &response
-}
-
-func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
-	defer req.Body.Close()
-	// Send the request
-	client := common.GetHttpClient(p.Channel.Proxy)
-	resp, err := client.Do(req)
-	if err != nil {
-		return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
-	}
-	common.PutHttpClient(client)
-
-	if common.IsFailureStatusCode(resp) {
-		return common.HandleErrorResp(resp), ""
-	}
-
-	defer resp.Body.Close()
-
-	var responseText string
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if len(data) < 5 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:5] != "data:" {
-				continue
-			}
-			data = data[5:]
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
-	common.SetEventStreamHeaders(p.Context)
-	p.Context.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			var TencentResponse TencentChatResponse
-			err := json.Unmarshal([]byte(data), &TencentResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			TencentResponse.Model = model
-			response := p.streamResponseTencent2OpenAI(&TencentResponse)
-			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.Content
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
-	return nil, responseText
+
+	*response = append(*response, streamResponse)
+
+	// Guard: chunks may arrive without choices
+	if len(tencentChatResponse.Choices) > 0 {
+		h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
+	}
+	h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
+
+	return nil
 }
diff --git a/providers/tencent/type.go b/providers/tencent/type.go
index 9783b920..e92a1f92 100644
--- a/providers/tencent/type.go
+++ b/providers/tencent/type.go
@@ -50,13 +50,17 @@ type TencentResponseChoices struct {
 	Delta TencentMessage `json:"delta,omitempty"` // Content returned in stream mode; null in sync mode. content supports at most 1024 tokens in total
 }
 
+type TencentResponseError struct {
+	Error TencentError `json:"error,omitempty"`
+}
+
 type TencentChatResponse struct {
 	Choices []TencentResponseChoices `json:"choices,omitempty"` // Results
 	Created string                   `json:"created,omitempty"` // Unix timestamp as a string
 	Id      string                   `json:"id,omitempty"`      // Session id
 	Usage   *types.Usage             `json:"usage,omitempty"`   // Token counts
-	Error   TencentError             `json:"error,omitempty"`   // Error info; note: may be null when no valid value can be returned
 	Note    string                   `json:"note,omitempty"`    // Note
 	ReqID   string                   `json:"req_id,omitempty"`  // Unique request id returned with every request; include it when reporting problems
 	Model   string                   `json:"model,omitempty"`   // Model name
+	TencentResponseError
 }
diff --git a/providers/xunfei/base.go b/providers/xunfei/base.go
index eef99b2a..e4cd525f 100644
--- a/providers/xunfei/base.go
+++ b/providers/xunfei/base.go
@@ -7,31 +7,53 @@ import (
 	"fmt"
 	"net/url"
 	"one-api/common"
+	"one-api/common/requester"
+	"one-api/model"
 	"one-api/providers/base"
+	"one-api/types"
 	"strings"
 	"time"
-
-	"github.com/gin-gonic/gin"
 )
 
 type XunfeiProviderFactory struct{}
 
 // Create a XunfeiProvider
-func (f XunfeiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f XunfeiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &XunfeiProvider{
 		BaseProvider: base.BaseProvider{
-			BaseURL:         "wss://spark-api.xf-yun.com",
-			ChatCompletions: "true",
-			Context:         c,
+			Config:    getConfig(),
+			Channel:   channel,
+			Requester: requester.NewHTTPRequester(channel.Proxy, nil),
 		},
+		wsRequester: requester.NewWSRequester(channel.Proxy),
 	}
 }
 
 // https://www.xfyun.cn/doc/spark/Web.html
 type XunfeiProvider struct {
 	base.BaseProvider
-	domain string
-	apiId  string
+	domain      string
+	apiId       string
+	wsRequester *requester.WSRequester
 }
 
+func getConfig() base.ProviderConfig {
+	return base.ProviderConfig{
+		BaseURL:         "wss://spark-api.xf-yun.com",
+		ChatCompletions: "/",
+	}
+}
+
+// Error handling
+func errorHandle(xunfeiError *XunfeiChatResponse) *types.OpenAIError {
+	if xunfeiError.Header.Code == 0 {
+		return nil
+	}
+	return &types.OpenAIError{
+		Message: xunfeiError.Header.Message,
+		Type:    "xunfei_error",
+		Code:    xunfeiError.Header.Code,
+	}
+}
+
 // Get request headers
@@ -68,7 +90,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
 	if apiVersion != "v1.1" {
 		domain += strings.Split(apiVersion, ".")[0]
 	}
-	authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
+	authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.Config.BaseURL, apiVersion), apiKey, apiSecret)
 	return domain, authUrl
 }
diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go
index 8a63f90e..b7966387 100644
--- a/providers/xunfei/chat.go
+++ b/providers/xunfei/chat.go
@@ -2,140 +2,93 @@ package xunfei
 
 import (
 	"encoding/json"
-	"fmt"
+	"errors"
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/providers/base"
+	"one-api/common/requester"
 	"one-api/types"
-	"time"
+	"strings"
 
 	"github.com/gorilla/websocket"
 )
 
-func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
-
-	dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
-	if err != nil {
-		return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
-	}
-
-	if request.Stream {
-		return p.sendStreamRequest(dataChan, stopChan, request.GetFunctionCate())
-	} else {
-		return p.sendRequest(dataChan, stopChan, request.GetFunctionCate())
-	}
+type xunfeiHandler struct {
+	Usage   *types.Usage
+	Request *types.ChatCompletionRequest
 }
 
-func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	usage = &types.Usage{}
-
-	var content string
-	var xunfeiResponse XunfeiChatResponse
-
-	stop := false
-	for !stop {
-		select {
-		case xunfeiResponse = <-dataChan:
-			if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-				continue
-			}
-			content += xunfeiResponse.Payload.Choices.Text[0].Content
-			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
-			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
-			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
-		case stop = <-stopChan:
-		}
-	}
-
-	if xunfeiResponse.Header.Code != 0 {
-		return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
-	}
-
-	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
-	}
-
-	xunfeiResponse.Payload.Choices.Text[0].Content = content
-
-	response := p.responseXunfei2OpenAI(&xunfeiResponse, functionCate)
-	jsonResponse, err := json.Marshal(response)
-	if err != nil {
-		return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
-	}
-	p.Context.Writer.Header().Set("Content-Type", "application/json")
-	_, _ = p.Context.Writer.Write(jsonResponse)
-	return usage, nil
-}
-
-func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	usage = &types.Usage{}
-
-	// Wait for the first response on dataChan
-	xunfeiResponse, ok := <-dataChan
-	if !ok {
-		return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
-	}
-	if xunfeiResponse.Header.Code != 0 {
-		errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
+func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
+	wsConn, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
 		return nil, errWithCode
 	}
 
-	// If the first response carries no error, set the stream headers and start streaming
-	common.SetEventStreamHeaders(p.Context)
-	p.Context.Stream(func(w io.Writer) bool {
-		// Handle the first response
-		usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
-		usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
-		usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
-		response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
-		jsonResponse, err := json.Marshal(response)
-		if err != nil {
-			common.SysError("error marshalling stream response: " + err.Error())
-			return true
-		}
-		p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+	xunfeiRequest := p.convertFromChatOpenai(request)
+
+	chatHandler := &xunfeiHandler{
+		Usage:   p.Usage,
+		Request: request,
+	}
+
+	stream, errWithCode := requester.SendWSJsonRequest[XunfeiChatResponse](wsConn, xunfeiRequest, chatHandler.handlerNotStream)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	return chatHandler.convertToChatOpenai(stream)
 
-		// Handle subsequent responses
-		for {
-			select {
-			case xunfeiResponse, ok := <-dataChan:
-				if !ok {
-					p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-					return false
-				}
-				usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
-				usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
-				usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
-				response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
-				jsonResponse, err := json.Marshal(response)
-				if err != nil {
-					common.SysError("error marshalling stream response: " + err.Error())
-					return true
-				}
-				p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			case <-stopChan:
-				p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-				return false
-			}
-		}
-	})
-	return usage, nil
 }
 
-func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
+func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
+	wsConn, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	xunfeiRequest := p.convertFromChatOpenai(request)
+
+	chatHandler := &xunfeiHandler{
+		Usage:   p.Usage,
+		Request: request,
+	}
+
+	return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
+}
+
+func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
+	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	authUrl := p.GetFullRequestURL(url, request.Model)
+
+	wsConn, err := p.wsRequester.NewRequest(authUrl, nil)
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
+	}
+
+	return wsConn, nil
+}
+
+func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *XunfeiChatRequest {
 	messages := make([]XunfeiMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
 			messages = append(messages, XunfeiMessage{
-				Role:    "user",
+				Role:    types.ChatMessageRoleUser,
 				Content: message.StringContent(),
 			})
 			messages = append(messages, XunfeiMessage{
-				Role:    "assistant",
+				Role:    types.ChatMessageRoleAssistant,
 				Content: "Okay",
 			})
+		} else if message.Role == types.ChatMessageRoleFunction {
+			messages = append(messages, XunfeiMessage{
+				Role:    types.ChatMessageRoleUser,
+				Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(),
+			})
 		} else {
 			messages = append(messages, XunfeiMessage{
 				Role:    message.Role,
@@ -143,6 +96,7 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
 			})
 		}
 	}
+
 	xunfeiRequest := XunfeiChatRequest{}
 
 	if request.Tools != nil {
@@ -166,35 +120,57 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
 	return &xunfeiRequest
 }
 
-func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse {
-	if len(response.Payload.Choices.Text) == 0 {
-		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
+func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
+	var content string
+	var xunfeiResponse XunfeiChatResponse
+
+	for {
+		response, err := stream.Recv()
+
+		if err != nil && !errors.Is(err, io.EOF) {
+			return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
+		}
+
+		if errors.Is(err, io.EOF) && response == nil {
+			break
+		}
+
+		if len((*response)[0].Payload.Choices.Text) == 0 {
+			continue
+		}
+
+		xunfeiResponse = (*response)[0]
+		content += xunfeiResponse.Payload.Choices.Text[0].Content
 	}
 
+	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
+		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
+	}
+
+	xunfeiResponse.Payload.Choices.Text[0].Content = content
+
 	choice := types.ChatCompletionChoice{
 		Index:        0,
-		FinishReason: base.StopFinishReason,
+		FinishReason: types.FinishReasonStop,
 	}
-	xunfeiText := response.Payload.Choices.Text[0]
+	xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
 
 	if xunfeiText.FunctionCall != nil {
 		choice.Message = types.ChatCompletionMessage{
 			Role: "assistant",
 		}
 
-		if functionCate == "tool" {
+		if h.Request.Tools != nil {
 			choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
 				{
-					Id:       response.Header.Sid,
+					Id:       xunfeiResponse.Header.Sid,
 					Type:     "function",
-					Function: *xunfeiText.FunctionCall,
+					Function: xunfeiText.FunctionCall,
 				},
 			}
-			choice.FinishReason = &base.StopFinishReasonToolFunction
+			choice.FinishReason = types.FinishReasonToolCalls
 		} else {
 			choice.Message.FunctionCall = xunfeiText.FunctionCall
-			choice.FinishReason = &base.StopFinishReasonCallFunction
+			choice.FinishReason = types.FinishReasonFunctionCall
 		}
 
 	} else {
@@ -204,97 +180,128 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, fun
 		}
 	}
 
-	fullTextResponse := types.ChatCompletionResponse{
-		ID:      response.Header.Sid,
+	fullTextResponse := &types.ChatCompletionResponse{
+		ID:      xunfeiResponse.Header.Sid,
 		Object:  "chat.completion",
-		Model:   "SparkDesk",
+		Model:   h.Request.Model,
 		Created: common.GetTimestamp(),
 		Choices: []types.ChatCompletionChoice{choice},
-		Usage:   &response.Payload.Usage.Text,
+		Usage:   &xunfeiResponse.Payload.Usage.Text,
 	}
-	return &fullTextResponse
+
+	return fullTextResponse, nil
 }
 
-func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
-	d := websocket.Dialer{
-		HandshakeTimeout: 5 * time.Second,
+func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
+	// If the line is not a JSON object, skip it
+	if !strings.HasPrefix(string(*rawLine), "{") {
+		*rawLine = nil
+		return nil, nil
 	}
-	conn, resp, err := d.Dial(authUrl, nil)
-	if err != nil || resp.StatusCode != 101 {
-		return nil, nil, err
-	}
-	data := p.requestOpenAI2Xunfei(textRequest)
-	err = conn.WriteJSON(data)
+
+	var xunfeiChatResponse XunfeiChatResponse
+	err := json.Unmarshal(*rawLine, &xunfeiChatResponse)
 	if err != nil {
-		return nil, nil, err
+		return nil, common.ErrorToOpenAIError(err)
 	}
 
-	dataChan := make(chan XunfeiChatResponse)
-	stopChan := make(chan bool)
-	go func() {
-		for {
-			_, msg, err := conn.ReadMessage()
-			if err != nil {
-				common.SysError("error reading stream response: " + err.Error())
-				break
-			}
-			var response XunfeiChatResponse
-			err = json.Unmarshal(msg, &response)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				break
-			}
-			dataChan <- response
-			if response.Payload.Choices.Status == 2 {
-				err := conn.Close()
-				if err != nil {
-					common.SysError("error closing websocket connection: " + err.Error())
-				}
-				break
-			}
-		}
-		stopChan <- true
-	}()
+	error := errorHandle(&xunfeiChatResponse)
+	if error != nil {
+		return nil, error
+	}
 
-	return dataChan, stopChan, nil
+	if xunfeiChatResponse.Payload.Choices.Status == 2 {
+		*isFinished = true
+	}
+
+	h.Usage.PromptTokens = xunfeiChatResponse.Payload.Usage.Text.PromptTokens
+	h.Usage.CompletionTokens = xunfeiChatResponse.Payload.Usage.Text.CompletionTokens
+	h.Usage.TotalTokens = xunfeiChatResponse.Payload.Usage.Text.TotalTokens
+
+	return &xunfeiChatResponse, nil
 }
 
-func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse {
-	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
+func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
+	xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
+	if err != nil {
+		return err
 	}
-	var choice types.ChatCompletionStreamChoice
-	xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
+
+	if *rawLine == nil {
+		return nil
+	}
+
+	*response = append(*response, *xunfeiChatResponse)
+	return nil
+}
+
+func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
+	xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
+	if err != nil {
+		return err
+	}
+
+	if *rawLine == nil {
+		return nil
+	}
+
+	return h.convertToOpenaiStream(xunfeiChatResponse, response)
+}
+
+func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
+	if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
+		xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
+	}
+
+	choice := types.ChatCompletionStreamChoice{
+		Index: 0,
+		Delta: types.ChatCompletionStreamChoiceDelta{
+			Role: types.ChatMessageRoleAssistant,
+		},
+	}
+	xunfeiText := xunfeiChatResponse.Payload.Choices.Text[0]
 
 	if xunfeiText.FunctionCall != nil {
-		if functionCate == "tool" {
+		if h.Request.Tools != nil {
 			choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
 				{
-					Id:       xunfeiResponse.Header.Sid,
+					Id:       xunfeiChatResponse.Header.Sid,
 					Index:    0,
 					Type:     "function",
-					Function: *xunfeiText.FunctionCall,
+					Function: xunfeiText.FunctionCall,
 				},
 			}
-			choice.FinishReason = &base.StopFinishReasonToolFunction
+			choice.FinishReason = types.FinishReasonToolCalls
 		} else {
 			choice.Delta.FunctionCall = xunfeiText.FunctionCall
-			choice.FinishReason = &base.StopFinishReasonCallFunction
+			choice.FinishReason = types.FinishReasonFunctionCall
 		}
 	} else {
-		choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
-		if xunfeiResponse.Payload.Choices.Status == 2 {
-			choice.FinishReason = &base.StopFinishReason
+		choice.Delta.Content = xunfeiChatResponse.Payload.Choices.Text[0].Content
+		if xunfeiChatResponse.Payload.Choices.Status == 2 {
+			choice.FinishReason = types.FinishReasonStop
 		}
 	}
 
-	response := types.ChatCompletionStreamResponse{
-		ID:      xunfeiResponse.Header.Sid,
+	chatCompletion := types.ChatCompletionStreamResponse{
+		ID:      xunfeiChatResponse.Header.Sid,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   "SparkDesk",
-		Choices: []types.ChatCompletionStreamChoice{choice},
+		Model:   h.Request.Model,
 	}
-	return &response
+
+	if xunfeiText.FunctionCall == nil {
+		chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
+		*response = append(*response, chatCompletion)
+	} else {
+		choices := choice.ConvertOpenaiStream()
+		for _, choice := range choices {
+			chatCompletionCopy := chatCompletion
+			chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
+			*response = append(*response, chatCompletionCopy)
+		}
+	}
+
+	return nil
 }
diff --git a/providers/zhipu/base.go b/providers/zhipu/base.go
index 7d1d08ce..62f3d1a5 100644
--- a/providers/zhipu/base.go
+++ b/providers/zhipu/base.go
@@ -1,14 +1,18 @@
 package zhipu
 
 import (
+	"encoding/json"
 	"fmt"
+	"net/http"
 	"one-api/common"
+	"one-api/common/requester"
+	"one-api/model"
 	"one-api/providers/base"
+	"one-api/types"
 	"strings"
 	"sync"
 	"time"
 
-	"github.com/gin-gonic/gin"
 	"github.com/golang-jwt/jwt"
 )
@@ -18,12 +22,12 @@ var expSeconds int64 = 24 * 3600
 
 type ZhipuProviderFactory struct{}
 
 // Create a ZhipuProvider
-func (f ZhipuProviderFactory) Create(c *gin.Context) base.ProviderInterface {
+func (f ZhipuProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
 	return &ZhipuProvider{
 		BaseProvider: base.BaseProvider{
-			BaseURL:         "https://open.bigmodel.cn",
-			ChatCompletions: "/api/paas/v3/model-api",
-			Context:         c,
+			Config:    getConfig(),
+			Channel:   channel,
+			Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
 		},
 	}
 }
@@ -32,6 +36,36 @@ type ZhipuProvider struct {
 	base.BaseProvider
 }
 
+func getConfig() base.ProviderConfig {
+	return base.ProviderConfig{
+		BaseURL:         "https://open.bigmodel.cn",
+		ChatCompletions: "/api/paas/v3/model-api",
+	}
+}
+
+// Request error handling
+func requestErrorHandle(resp *http.Response) *types.OpenAIError {
+	// Decode into an initialized struct; decoding into a nil pointer always fails
+	zhipuError := &ZhipuResponse{}
+	err := json.NewDecoder(resp.Body).Decode(zhipuError)
+	if err != nil {
+		return nil
+	}
+
+	return errorHandle(zhipuError)
+}
+
+// Error handling
+func errorHandle(zhipuError *ZhipuResponse) *types.OpenAIError {
+	if zhipuError.Success {
+		return nil
+	}
+	return &types.OpenAIError{
+		Message: zhipuError.Msg,
+		Type:    "zhipu_error",
+		Code:    zhipuError.Code,
+	}
+}
+
 // Get request headers
 func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
 	headers = make(map[string]string)
diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go
index 7e58c6e5..ab664106 100644
--- a/providers/zhipu/chat.go
+++ b/providers/zhipu/chat.go
@@ -1,38 +1,108 @@
 package zhipu
 
 import (
-	"bufio"
 	"encoding/json"
-	"io"
+	"fmt"
 	"net/http"
 	"one-api/common"
-	"one-api/providers/base"
+	"one-api/common/requester"
 	"one-api/types"
 	"strings"
 )
 
-func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
-	if !zhipuResponse.Success {
-		return &types.OpenAIErrorWithStatusCode{
-			OpenAIError: types.OpenAIError{
-				Message: zhipuResponse.Msg,
-				Type:    "zhipu_error",
-				Param:   "",
-				Code:    zhipuResponse.Code,
-			},
-			StatusCode: resp.StatusCode,
-		}, nil
+type zhipuStreamHandler struct {
+	Usage   *types.Usage
+	Request *types.ChatCompletionRequest
+}
+
+func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	zhipuChatResponse := &ZhipuResponse{}
+	// Send the request
+	_, errWithCode = p.Requester.SendRequest(req, zhipuChatResponse, false)
+	if errWithCode != nil {
+		return nil, errWithCode
 	}
 
-	fullTextResponse := types.ChatCompletionResponse{
-		ID:      zhipuResponse.Data.TaskId,
+	return p.convertToChatOpenai(zhipuChatResponse, request)
+}
+
+func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
+	req, errWithCode := p.getChatRequest(request)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+	defer req.Body.Close()
+
+	// Send the request
+	resp, errWithCode := p.Requester.SendRequestRaw(req)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	chatHandler := &zhipuStreamHandler{
+		Usage:   p.Usage,
+		Request: request,
+	}
+
+	return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
+}
+
+func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
+	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+	if errWithCode != nil {
+		return nil, errWithCode
+	}
+
+	// Build the request URL
+	fullRequestURL := p.GetFullRequestURL(url, request.Model)
+	if fullRequestURL == "" {
+		return nil, common.ErrorWrapper(nil, "invalid_zhipu_config", http.StatusInternalServerError)
+	}
+
+	// Build the request headers
+	headers := p.GetRequestHeaders()
+	if request.Stream {
+		headers["Accept"] = "text/event-stream"
+		fullRequestURL += "/sse-invoke"
+	} else {
+		fullRequestURL += "/invoke"
+	}
+
+	zhipuRequest := convertFromChatOpenai(request)
+	// Create the request
+	req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(zhipuRequest), p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+
+	return req, nil
+}
+
+func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
+	error := errorHandle(response)
+	if error != nil {
+		errWithCode = &types.OpenAIErrorWithStatusCode{
+			OpenAIError: *error,
+			StatusCode:  http.StatusBadRequest,
+		}
+		return
+	}
+
+	openaiResponse = &types.ChatCompletionResponse{
+		ID:      response.Data.TaskId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Model:   zhipuResponse.Model,
-		Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
-		Usage:   &zhipuResponse.Data.Usage,
+		Model:   request.Model,
+		Choices: make([]types.ChatCompletionChoice, 0, len(response.Data.Choices)),
+		Usage:   &response.Data.Usage,
 	}
-	for i, choice := range zhipuResponse.Data.Choices {
+	for i, choice := range response.Data.Choices {
 		openaiChoice := types.ChatCompletionChoice{
 			Index: i,
 			Message: types.ChatCompletionMessage{
@@ -41,16 +111,18 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI
 			},
 			FinishReason: "",
 		}
-		if i == len(zhipuResponse.Data.Choices)-1 {
+		if i == len(response.Data.Choices)-1 {
 			openaiChoice.FinishReason = "stop"
 		}
-		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
+		openaiResponse.Choices = append(openaiResponse.Choices, openaiChoice)
 	}
-	return fullTextResponse, nil
+
+	*p.Usage = response.Data.Usage
+
+	return
 }
 
-func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest {
+func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
 	messages :=
 make([]ZhipuMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -77,158 +149,60 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest)
 		}
 	}
 
-func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
-	requestBody := p.getChatRequestBody(request)
-	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
-	headers := p.GetRequestHeaders()
-	if request.Stream {
-		headers["Accept"] = "text/event-stream"
-		fullRequestURL += "/sse-invoke"
-	} else {
-		fullRequestURL += "/invoke"
+// Convert to the OpenAI chat streaming response format
+func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
+	// If the line starts with neither "data:" nor "meta:", skip it
+	if !strings.HasPrefix(string(*rawLine), "data:") && !strings.HasPrefix(string(*rawLine), "meta:") {
+		*rawLine = nil
+		return nil
 	}
 
-	client := common.NewClient()
-	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
-	if err != nil {
-		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	if strings.HasPrefix(string(*rawLine), "meta:") {
+		*rawLine = (*rawLine)[5:]
+		var zhipuStreamMetaResponse ZhipuStreamMetaResponse
+		err := json.Unmarshal(*rawLine, &zhipuStreamMetaResponse)
+		if err != nil {
+			return common.ErrorToOpenAIError(err)
+		}
+		*isFinished = true
+		return h.handlerMeta(&zhipuStreamMetaResponse, response)
 	}
 
-	if request.Stream {
-		errWithCode, usage = p.sendStreamRequest(req, request.Model)
-		if errWithCode != nil {
-			return
-		}
-
-	} else {
-		zhipuResponse := &ZhipuResponse{
-			Model: request.Model,
-		}
-		errWithCode = p.SendRequest(req, zhipuResponse, false)
-		if errWithCode != nil {
-			return
-		}
-
-		usage = &zhipuResponse.Data.Usage
-	}
-	return
-
+	// Strip the prefix
+	*rawLine = (*rawLine)[5:]
+	return h.convertToOpenaiStream(string(*rawLine), response)
 }
 
-func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse {
+func (h *zhipuStreamHandler) convertToOpenaiStream(content string, response *[]types.ChatCompletionStreamResponse) error {
 	var choice types.ChatCompletionStreamChoice
-	choice.Delta.Content = zhipuResponse
-	response := types.ChatCompletionStreamResponse{
+	choice.Delta.Content = content
+	streamResponse := types.ChatCompletionStreamResponse{
+		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   "chatglm",
+		Model:   h.Request.Model,
 		Choices: []types.ChatCompletionStreamChoice{choice},
 	}
-	return &response
+
+	*response = append(*response, streamResponse)
+
+	return nil
 }
 
-func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
+func (h *zhipuStreamHandler) handlerMeta(zhipuResponse *ZhipuStreamMetaResponse, response *[]types.ChatCompletionStreamResponse) error {
 	var choice types.ChatCompletionStreamChoice
 	choice.Delta.Content = ""
-	choice.FinishReason = &base.StopFinishReason
-	response := types.ChatCompletionStreamResponse{
+	choice.FinishReason = types.FinishReasonStop
+	streamResponse := types.ChatCompletionStreamResponse{
 		ID:      zhipuResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   zhipuResponse.Model,
+		Model:   h.Request.Model,
diff --git a/types/audio.go b/types/audio.go
index ead8d436..d1a1ede9 100644
--- a/types/audio.go
+++ b/types/audio.go
@@ -26,3 +26,8 @@ type AudioResponse struct {
 	Segments any    `json:"segments,omitempty"`
 	Text     string `json:"text"`
 }
+
+type AudioResponseWrapper struct {
+	Headers map[string]string
+	Body    []byte
+}
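
AudioResponseWrapper pairs the upstream body bytes with selected response headers so audio relays can forward both to the client. How it gets populated is not shown in this diff; here is a hedged sketch of one plausible call site, assuming `io`, `net/http`, and the repo's `one-api/types` imports — `buildWrapper` itself is a hypothetical helper.

```go
// buildWrapper captures the upstream Content-Type and the raw body
// into an AudioResponseWrapper. Illustrative only; the actual call
// sites for this type are outside the diff.
func buildWrapper(resp *http.Response) (*types.AudioResponseWrapper, error) {
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	return &types.AudioResponseWrapper{
		Headers: map[string]string{
			"Content-Type": resp.Header.Get("Content-Type"),
		},
		Body: body,
	}, nil
}
```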
= "user" + ChatMessageRoleAssistant = "assistant" + ChatMessageRoleFunction = "function" + ChatMessageRoleTool = "tool" +) + type ChatCompletionToolCallsFunction struct { Name string `json:"name,omitempty"` - Arguments string `json:"arguments,omitempty"` + Arguments string `json:"arguments"` } type ChatCompletionToolCalls struct { - Id string `json:"id"` - Type string `json:"type"` - Function ChatCompletionToolCallsFunction `json:"function"` - Index int `json:"index"` + Id string `json:"id"` + Type string `json:"type"` + Function *ChatCompletionToolCallsFunction `json:"function"` + Index int `json:"index"` } type ChatCompletionMessage struct { @@ -123,6 +140,8 @@ type ChatCompletionRequest struct { Seed *int `json:"seed,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` LogitBias any `json:"logit_bias,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` User string `json:"user,omitempty"` Functions []*ChatCompletionFunction `json:"functions,omitempty"` FunctionCall any `json:"function_call,omitempty"` @@ -151,19 +170,89 @@ type ChatCompletionTool struct { } type ChatCompletionChoice struct { - Index int `json:"index"` - Message ChatCompletionMessage `json:"message"` - FinishReason any `json:"finish_reason,omitempty"` + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + FinishReason any `json:"finish_reason,omitempty"` + ContentFilterResults any `json:"content_filter_results,omitempty"` + FinishDetails any `json:"finish_details,omitempty"` } type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage *Usage `json:"usage,omitempty"` - SystemFingerprint string `json:"system_fingerprint,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + PromptFilterResults any `json:"prompt_filter_results,omitempty"` +} + +func (c ChatCompletionStreamChoice) ConvertOpenaiStream() []ChatCompletionStreamChoice { + var function *ChatCompletionToolCallsFunction + var functions []*ChatCompletionToolCallsFunction + var choices []ChatCompletionStreamChoice + var stopFinish string + if c.Delta.FunctionCall != nil { + function = c.Delta.FunctionCall + stopFinish = FinishReasonFunctionCall + } else { + function = c.Delta.ToolCalls[0].Function + stopFinish = FinishReasonToolCalls + } + + if function.Name == "" { + c.FinishReason = stopFinish + choices = append(choices, c) + return choices + } + + functions = append(functions, &ChatCompletionToolCallsFunction{ + Name: function.Name, + Arguments: "", + }) + + if function.Arguments == "" || function.Arguments == "{}" { + functions = append(functions, &ChatCompletionToolCallsFunction{ + Arguments: "{}", + }) + } else { + functions = append(functions, &ChatCompletionToolCallsFunction{ + Arguments: function.Arguments, + }) + } + + // 循环functions, 生成choices + for _, function := range functions { + choice := ChatCompletionStreamChoice{ + Index: 0, + Delta: ChatCompletionStreamChoiceDelta{ + Role: c.Delta.Role, + }, + } + if stopFinish == FinishReasonFunctionCall { + choice.Delta.FunctionCall = function + } else { + choice.Delta.ToolCalls = []*ChatCompletionToolCalls{ + { + Id: 
diff --git a/types/common.go b/types/common.go
index 5573fe16..9bc02617 100644
--- a/types/common.go
+++ b/types/common.go
@@ -1,5 +1,7 @@
 package types
 
+import "encoding/json"
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -14,6 +16,16 @@ type OpenAIError struct {
 	InnerError any    `json:"innererror,omitempty"`
 }
 
+func (e *OpenAIError) Error() string {
+	response := &OpenAIErrorResponse{
+		Error: *e,
+	}
+
+	// Convert to JSON
+	bytes, _ := json.Marshal(response)
+	return string(bytes)
+}
+
 type OpenAIErrorWithStatusCode struct {
 	OpenAIError
 	StatusCode int `json:"status_code"`
diff --git a/types/embeddings.go b/types/embeddings.go
index eb36108e..83a6aa1b 100644
--- a/types/embeddings.go
+++ b/types/embeddings.go
@@ -8,9 +8,9 @@ type EmbeddingRequest struct {
 }
 
 type Embedding struct {
-	Object    string    `json:"object"`
-	Embedding []float64 `json:"embedding"`
-	Index     int       `json:"index"`
+	Object    string `json:"object"`
+	Embedding any    `json:"embedding"`
+	Index     int    `json:"index"`
 }
 
 type EmbeddingResponse struct {
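
Two behavioral notes on these type changes: the pointer-receiver Error() method makes *OpenAIError satisfy Go's error interface, so upstream failures can travel through ordinary error returns yet still serialize as an OpenAI-style body; and widening Embedding from []float64 to any presumably lets base64-encoded embedding payloads pass through unchanged. A quick illustration of the former — field values are invented, and the exact JSON shape depends on OpenAIErrorResponse's tags:

```go
var err error = &types.OpenAIError{
	Message: "quota exceeded",
	Type:    "upstream_error",
	Code:    "insufficient_quota",
}

// Prints the marshalled OpenAIErrorResponse, e.g.
// {"error":{"message":"quota exceeded","type":"upstream_error","code":"insufficient_quota"}}
fmt.Println(err.Error())
```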