♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* Improved Baidu/iFlytek (百度/讯飞) function calling; functions can now be invoked normally from `lobe-chat`
Buer 2024-01-19 02:47:10 +08:00 committed by GitHub
parent 0bfe1f5779
commit ef041e28a1
96 changed files with 4339 additions and 3276 deletions

.gitignore (2 changes)

@@ -8,5 +8,5 @@ build
logs
data
tmp/
test/
/test/
.env

@@ -1,299 +0,0 @@
package common
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"one-api/types"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/proxy"
)
var clientPool = &sync.Pool{
New: func() interface{} {
return &http.Client{}
},
}
func GetHttpClient(proxyAddr string) *http.Client {
client := clientPool.Get().(*http.Client)
if RelayTimeout > 0 {
client.Timeout = time.Duration(RelayTimeout) * time.Second
}
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
SysError("Error parsing proxy address: " + err.Error())
return client
}
switch proxyURL.Scheme {
case "http", "https":
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
SysError("Error creating SOCKS5 dialer: " + err.Error())
return client
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
}
}
return client
}
func PutHttpClient(c *http.Client) {
clientPool.Put(c)
}
type Client struct {
requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder
}
func NewClient() *Client {
return &Client{
requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
},
}
}
type requestOptions struct {
body any
header http.Header
}
type requestOption func(*requestOptions)
type Stringer interface {
GetString() *string
}
func WithBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}
func WithHeader(header map[string]string) requestOption {
return func(args *requestOptions) {
for k, v := range header {
args.header.Set(k, v)
}
}
}
func WithContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}
type RequestError struct {
HTTPStatusCode int
Err error
}
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
// Default Options
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
if err != nil {
return nil, err
}
return req, nil
}
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// send the request
client := GetHttpClient(proxyAddr)
resp, err := client.Do(req)
if err != nil {
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
PutHttpClient(client)
if !outputResp {
defer resp.Body.Close()
}
// handle the response
if IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp)
}
// decode the response
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// write the response body back into resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = DecodeResponse(resp.Body, response)
}
if err != nil {
return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
if outputResp {
return resp, nil
}
return nil, nil
}
type GeneralErrorResponse struct {
Error types.OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
// handle an error response
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
// var errorResponse types.OpenAIErrorResponse
var errorResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
client := GetHttpClient(proxyAddr)
resp, err := client.Do(req)
PutHttpClient(client)
if err != nil {
return
}
return resp.Body, nil
}
func IsFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}
func DecodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return DecodeString(body, result)
}
if stringer, ok := v.(Stringer); ok {
return DecodeString(body, stringer.GetString())
}
return json.NewDecoder(body).Decode(v)
}
func DecodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

@@ -37,6 +37,14 @@ func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWith
return StringErrorWrapper(err.Error(), code, statusCode)
}
func ErrorToOpenAIError(err error) *types.OpenAIError {
return &types.OpenAIError{
Code: "system error",
Message: err.Error(),
Type: "one_api_error",
}
}
func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
openAIError := types.OpenAIError{
Message: err,

@@ -1,59 +0,0 @@
package common
// type Quota struct {
// ModelName string
// ModelRatio float64
// GroupRatio float64
// Ratio float64
// UserQuota int
// }
// func CreateQuota(modelName string, userQuota int, group string) *Quota {
// modelRatio := GetModelRatio(modelName)
// groupRatio := GetGroupRatio(group)
// return &Quota{
// ModelName: modelName,
// ModelRatio: modelRatio,
// GroupRatio: groupRatio,
// Ratio: modelRatio * groupRatio,
// UserQuota: userQuota,
// }
// }
// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
// if ApproximateTokenEnabled {
// return int(float64(len(text)) * 0.38)
// }
// return len(tokenEncoder.Encode(text, nil, nil))
// }
// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
// tokenEncoder := q.getTokenEncoder(model)
// // Reference:
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// // https://github.com/pkoukk/tiktoken-go/issues/6
// //
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
// var tokensPerMessage int
// var tokensPerName int
// if model == "gpt-3.5-turbo-0301" {
// tokensPerMessage = 4
// tokensPerName = -1 // If there's a name, the role is omitted
// } else {
// tokensPerMessage = 3
// tokensPerName = 1
// }
// tokenNum := 0
// for _, message := range messages {
// tokenNum += tokensPerMessage
// tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
// tokenNum += q.getTokenNum(tokenEncoder, message.Role)
// if message.Name != nil {
// tokenNum += tokensPerName
// tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
// }
// }
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// return tokenNum
// }

@@ -1,4 +1,4 @@
package common
package requester
import (
"fmt"

@@ -0,0 +1,68 @@
package requester
import (
"fmt"
"net/http"
"net/url"
"one-api/common"
"sync"
"time"
"golang.org/x/net/proxy"
)
type HTTPClient struct{}
var clientPool = &sync.Pool{
New: func() interface{} {
return &http.Client{}
},
}
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
client := clientPool.Get().(*http.Client)
if common.RelayTimeout > 0 {
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
}
if proxyAddr != "" {
err := h.setProxy(client, proxyAddr)
if err != nil {
common.SysError(err.Error())
return client
}
}
return client
}
func (h *HTTPClient) returnClientToPool(client *http.Client) {
clientPool.Put(client)
}
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
return nil
}

@@ -0,0 +1,229 @@
package requester
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"github.com/gin-gonic/gin"
)
type HttpErrorHandler func(*http.Response) *types.OpenAIError
type HTTPRequester struct {
HTTPClient HTTPClient
requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder
ErrorHandler HttpErrorHandler
proxyAddr string
}
// NewHTTPRequester creates a new HTTPRequester instance.
// proxyAddr: the address of the proxy server.
// errorHandler: an error-handling function that receives an *http.Response and returns a *types.OpenAIError.
// If errorHandler is nil, a default error handler is used.
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
return &HTTPRequester{
HTTPClient: HTTPClient{},
requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
},
ErrorHandler: errorHandler,
proxyAddr: proxyAddr,
}
}
type requestOptions struct {
body any
header http.Header
}
type requestOption func(*requestOptions)
// build a request
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := r.requestBuilder.Build(method, url, args.body, args.header)
if err != nil {
return nil, err
}
return req, nil
}
// send a request
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if !outputResp {
defer resp.Body.Close()
}
// handle the response
if r.IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp, r.ErrorHandler)
}
// decode the response
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// write the response body back into resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = json.NewDecoder(resp.Body).Decode(response)
}
if err != nil {
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
return resp, nil
}
// send a request and return the raw response
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// send the request
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
// handle the response
if r.IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp, r.ErrorHandler)
}
return resp, nil
}
// obtain a streaming response
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
// a JSON Content-Type on the response indicates an upstream error
if resp.Header.Get("Content-Type") == "application/json" {
return nil, HandleErrorResp(resp, requester.ErrorHandler)
}
return &streamReader[T]{
reader: bufio.NewReader(resp.Body),
response: resp,
handlerPrefix: handlerPrefix,
}, nil
}
// set the request body
func (r *HTTPRequester) WithBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}
// set request headers
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
return func(args *requestOptions) {
for k, v := range header {
args.header.Set(k, v)
}
}
}
// set the Content-Type
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}
// report whether the status code indicates failure
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}
// handle an error response
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
defer resp.Body.Close()
if toOpenAIError != nil {
errorResponse := toOpenAIError(resp)
if errorResponse != nil && errorResponse.Message != "" {
openAIErrorWithStatusCode.OpenAIError = *errorResponse
}
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return openAIErrorWithStatusCode
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
type Stringer interface {
GetString() *string
}
func DecodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return DecodeString(body, result)
}
if stringer, ok := v.(Stringer); ok {
return DecodeString(body, stringer.GetString())
}
return json.NewDecoder(body).Decode(v)
}
func DecodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil
}
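
For orientation, here is a minimal usage sketch of the requester API defined above (hypothetical caller code; the endpoint URL, token, and request body are illustrative, not part of this PR):

// exampleChatRequest is a hypothetical caller of the new requester.
// It assumes an OpenAI-compatible endpoint and decodes into types.ChatCompletionResponse.
func exampleChatRequest() (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
    r := NewHTTPRequester("", nil) // no proxy, default error handling

    req, err := r.NewRequest(
        http.MethodPost,
        "https://api.example.com/v1/chat/completions", // illustrative URL
        r.WithBody(map[string]any{"model": "gpt-3.5-turbo"}),
        r.WithHeader(map[string]string{"Authorization": "Bearer sk-xxx"}), // illustrative key
    )
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }

    var response types.ChatCompletionResponse
    if _, errWithCode := r.SendRequest(req, &response, false); errWithCode != nil {
        return nil, errWithCode
    }
    return &response, nil
}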

@@ -0,0 +1,79 @@
package requester
import (
"bufio"
"bytes"
"io"
"net/http"
)
// Stream handler contract:
// 1. If an error occurs, return it directly.
// 2. If isFinished == true, return io.EOF; if response is also non-empty, return it along with io.EOF.
// 3. If rawLine == nil or response is empty, skip the line.
// 4. Otherwise, return response.
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
type streamable interface {
// types.ChatCompletionStreamResponse | types.CompletionResponse
any
}
type StreamReaderInterface[T streamable] interface {
Recv() (*[]T, error)
Close()
}
type streamReader[T streamable] struct {
isFinished bool
reader *bufio.Reader
response *http.Response
handlerPrefix HandlerPrefix[T]
}
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
}
//nolint:gocognit
func (stream *streamReader[T]) processLines() (*[]T, error) {
for {
rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil {
return nil, readErr
}
noSpaceLine := bytes.TrimSpace(rawLine)
var response []T
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if noSpaceLine == nil || len(response) == 0 {
continue
}
return &response, nil
}
}
func (stream *streamReader[T]) Close() {
stream.response.Body.Close()
}
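
To make the HandlerPrefix contract above concrete, here is a hedged sketch of a handler for an OpenAI-style SSE stream (the "data: " prefix and "[DONE]" sentinel are OpenAI conventions; the real per-provider handlers live in the providers packages, and the sketch additionally assumes encoding/json and one-api/types imports):

// openAIStreamHandler is a hypothetical HandlerPrefix implementation.
func openAIStreamHandler(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
    // rule 3: skip anything that is not a data payload
    if !bytes.HasPrefix(*rawLine, []byte("data: ")) {
        *rawLine = nil
        return nil
    }
    payload := bytes.TrimPrefix(*rawLine, []byte("data: "))

    // rule 2: the [DONE] sentinel ends the stream; the reader then returns io.EOF
    if bytes.Equal(payload, []byte("[DONE]")) {
        *isFinished = true
        return nil
    }

    // rule 1: a malformed chunk surfaces as an error to the caller
    var chunk types.ChatCompletionStreamResponse
    if err := json.Unmarshal(payload, &chunk); err != nil {
        return err
    }

    // rule 4: hand the decoded chunk back to the stream reader
    *response = append(*response, chunk)
    return nil
}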

@@ -1,9 +1,10 @@
package common
package requester
import (
"bytes"
"io"
"net/http"
"one-api/common"
)
type RequestBuilder interface {
@@ -11,12 +12,12 @@ type RequestBuilder interface {
}
type HTTPRequestBuilder struct {
marshaller Marshaller
marshaller common.Marshaller
}
func NewRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &JSONMarshaller{},
marshaller: &common.JSONMarshaller{},
}
}

@@ -0,0 +1,53 @@
package requester
import (
"fmt"
"net"
"net/http"
"net/url"
"one-api/common"
"time"
"github.com/gorilla/websocket"
"golang.org/x/net/proxy"
)
func GetWSClient(proxyAddr string) *websocket.Dialer {
dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
if proxyAddr != "" {
err := setWSProxy(dialer, proxyAddr)
if err != nil {
common.SysError(err.Error())
return dialer
}
}
return dialer
}
func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
dialer.Proxy = http.ProxyURL(proxyURL)
case "socks5":
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
dialer.NetDial = func(network, addr string) (net.Conn, error) {
return socks5Proxy.Dial(network, addr)
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
return nil
}

@@ -0,0 +1,58 @@
package requester
import (
"io"
"github.com/gorilla/websocket"
)
type wsReader[T streamable] struct {
isFinished bool
reader *websocket.Conn
handlerPrefix HandlerPrefix[T]
}
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
}
func (stream *wsReader[T]) processLines() (*[]T, error) {
for {
_, msg, err := stream.reader.ReadMessage()
if err != nil {
return nil, err
}
var response []T
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if msg == nil || len(response) == 0 {
continue
}
return &response, nil
}
}
func (stream *wsReader[T]) Close() {
stream.reader.Close()
}

@@ -0,0 +1,54 @@
package requester
import (
"errors"
"net/http"
"one-api/common"
"one-api/types"
"github.com/gorilla/websocket"
)
type WSRequester struct {
WSClient *websocket.Dialer
}
func NewWSRequester(proxyAddr string) *WSRequester {
return &WSRequester{
WSClient: GetWSClient(proxyAddr),
}
}
func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
conn, resp, err := w.WSClient.Dial(url, header)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, errors.New("ws unexpected status code")
}
return conn, nil
}
func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
err := conn.WriteJSON(data)
if err != nil {
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
}
return &wsReader[T]{
reader: conn,
handlerPrefix: handlerPrefix,
}, nil
}
// build request headers
func (r *WSRequester) WithHeader(headers map[string]string) http.Header {
header := make(http.Header)
for k, v := range headers {
header.Set(k, v)
}
return header
}
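
A minimal sketch of driving the WebSocket requester end to end (hypothetical caller; the URL and payload are illustrative, the handler follows the same HandlerPrefix contract as the HTTP stream reader, and an io import is assumed):

// exampleWSCall is a hypothetical caller of the WebSocket requester.
func exampleWSCall(apiKey string, payload any, handler HandlerPrefix[types.ChatCompletionStreamResponse]) error {
    ws := NewWSRequester("") // no proxy
    header := ws.WithHeader(map[string]string{"Authorization": "Bearer " + apiKey})

    conn, err := ws.NewRequest("wss://ws.example.com/v1/chat", header) // illustrative URL
    if err != nil {
        return err
    }

    stream, errWithCode := SendWSJsonRequest[types.ChatCompletionStreamResponse](conn, payload, handler)
    if errWithCode != nil {
        return errors.New(errWithCode.Message)
    }
    defer stream.Close()

    for {
        chunks, err := stream.Recv()
        if err == io.EOF {
            return nil // stream finished
        }
        if err != nil {
            return err
        }
        _ = chunks // forward the decoded chunks to the client here
    }
}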

common/test/api.go (new file, 55 lines)

@@ -0,0 +1,55 @@
package test
import (
"io"
"net/http"
"net/http/httptest"
"one-api/model"
"github.com/gin-gonic/gin"
)
func RequestJSONConfig() map[string]string {
return map[string]string{
"Content-Type": "application/json",
"Accept": "application/json",
}
}
func GetContext(method, path string, headers map[string]string, body io.Reader) (*gin.Context, *httptest.ResponseRecorder) {
var req *http.Request
req, _ = http.NewRequest(method, path, body)
for k, v := range headers {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
return c, w
}
func GetGinRouter(method, path string, headers map[string]string, body *io.Reader) *httptest.ResponseRecorder {
var req *http.Request
r := gin.Default()
w := httptest.NewRecorder()
req, _ = http.NewRequest(method, path, *body)
for k, v := range headers {
req.Header.Set(k, v)
}
r.ServeHTTP(w, req)
return w
}
func GetChannel(channelType int, baseUrl, other, proxyAddr, modelMapping string) model.Channel {
return model.Channel{
Type: channelType,
BaseURL: &baseUrl,
Other: other,
Proxy: proxyAddr,
ModelMapping: &modelMapping,
Key: GetTestToken(),
}
}

common/test/chat_config.go (new file, 132 lines)

@@ -0,0 +1,132 @@
package test
import (
"encoding/json"
"one-api/types"
"strings"
)
func GetChatCompletionRequest(chatType, modelName, stream string) *types.ChatCompletionRequest {
chatJSON := GetChatRequest(chatType, modelName, stream)
chatCompletionRequest := &types.ChatCompletionRequest{}
json.NewDecoder(chatJSON).Decode(chatCompletionRequest)
return chatCompletionRequest
}
func GetChatRequest(chatType, modelName, stream string) *strings.Reader {
var chatJSON string
switch chatType {
case "image":
chatJSON = `{
"model": "` + modelName + `",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
"max_tokens": 300,
"stream": ` + stream + `
}`
case "default":
chatJSON = `{
"model": "` + modelName + `",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello!"
}
],
"stream": ` + stream + `
}`
case "function":
chatJSON = `{
"model": "` + modelName + `",
"stream": ` + stream + `,
"messages": [
{
"role": "user",
"content": "What is the weather like in Boston?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
],
"tool_choice": "auto"
}`
case "tools":
chatJSON = `{
"model": "` + modelName + `",
"stream": ` + stream + `,
"messages": [
{
"role": "user",
"content": "What is the weather like in Boston?"
}
],
"functions": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
]
}
},
"required": [
"location"
]
}
}
]
}`
}
return strings.NewReader(chatJSON)
}

common/test/check_chat.go (new file, 65 lines)

@@ -0,0 +1,65 @@
package test
import (
"one-api/types"
"testing"
"github.com/stretchr/testify/assert"
)
func CheckChat(t *testing.T, response *types.ChatCompletionResponse, modelName string, usage *types.Usage) {
assert.NotEmpty(t, response.ID)
assert.NotEmpty(t, response.Object)
assert.NotEmpty(t, response.Created)
assert.Equal(t, response.Model, modelName)
assert.IsType(t, []types.ChatCompletionChoice{}, response.Choices)
// check that there is at least one choice
assert.True(t, len(response.Choices) > 0)
for _, choice := range response.Choices {
assert.NotNil(t, choice.Index)
assert.IsType(t, types.ChatCompletionMessage{}, choice.Message)
assert.NotEmpty(t, choice.Message.Role)
assert.NotEmpty(t, choice.FinishReason)
// check message
if choice.Message.Content != nil {
multiContents, ok := choice.Message.Content.([]types.ChatMessagePart)
if ok {
for _, content := range multiContents {
assert.NotEmpty(t, content.Type)
if content.Type == "text" {
assert.NotEmpty(t, content.Text)
} else if content.Type == "image_url" {
assert.IsType(t, types.ChatMessageImageURL{}, content.ImageURL)
}
}
} else {
content, ok := choice.Message.Content.(string)
assert.True(t, ok)
assert.NotEmpty(t, content)
}
} else if choice.Message.FunctionCall != nil {
assert.NotEmpty(t, choice.Message.FunctionCall.Name)
assert.Equal(t, choice.FinishReason, types.FinishReasonFunctionCall)
} else if choice.Message.ToolCalls != nil {
assert.IsType(t, []types.ChatCompletionToolCalls{}, choice.Message.ToolCalls)
assert.NotEmpty(t, choice.Message.ToolCalls[0].Id)
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function)
assert.Equal(t, choice.Message.ToolCalls[0].Type, "function")
assert.IsType(t, types.ChatCompletionToolCallsFunction{}, choice.Message.ToolCalls[0].Function)
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function.Name)
assert.Equal(t, choice.FinishReason, types.FinishReasonToolCalls)
} else {
assert.Fail(t, "message content is nil")
}
}
// check usage
assert.IsType(t, &types.Usage{}, response.Usage)
assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens)
assert.Equal(t, response.Usage.CompletionTokens, usage.CompletionTokens)
assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens)
}

common/test/checks.go (new file, 48 lines)

@@ -0,0 +1,48 @@
package test
import (
"errors"
"testing"
)
func NoError(t *testing.T, err error, message ...string) {
t.Helper()
if err != nil {
t.Error(err, message)
}
}
func HasError(t *testing.T, err error, message ...string) {
t.Helper()
if err == nil {
t.Error(err, message)
}
}
func ErrorIs(t *testing.T, err, target error, msg ...string) {
t.Helper()
if !errors.Is(err, target) {
t.Fatal(msg)
}
}
func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
t.Helper()
if !errors.Is(err, target) {
t.Fatalf(format, msg)
}
}
func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
t.Helper()
if errors.Is(err, target) {
t.Fatal(msg)
}
}
func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
t.Helper()
if errors.Is(err, target) {
t.Fatalf(format, msg)
}
}

common/test/init/init.go (new file, 7 lines)

@@ -0,0 +1,7 @@
package init
import "testing"
func init() {
testing.Init()
}

common/test/server.go (new file, 63 lines)

@@ -0,0 +1,63 @@
package test
import (
"log"
"net/http"
"net/http/httptest"
"regexp"
"strings"
)
const testAPI = "this-is-my-secure-token-do-not-steal!!"
func GetTestToken() string {
return testAPI
}
type ServerTest struct {
handlers map[string]handler
}
type handler func(w http.ResponseWriter, r *http.Request)
func NewTestServer() *ServerTest {
return &ServerTest{handlers: make(map[string]handler)}
}
func OpenAICheck(w http.ResponseWriter, r *http.Request) bool {
if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() {
w.WriteHeader(http.StatusUnauthorized)
return false
}
return true
}
func (ts *ServerTest) RegisterHandler(path string, handler handler) {
// to make the registered paths friendlier to a regex match in the route handler
// in OpenAITestServer
path = strings.ReplaceAll(path, "*", ".*")
ts.handlers[path] = handler
}
// TestServer creates a mocked OpenAI server which can pretend to handle requests during testing.
func (ts *ServerTest) TestServer(headerCheck func(w http.ResponseWriter, r *http.Request) bool) *httptest.Server {
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path)
// check auth
if headerCheck != nil && !headerCheck(w, r) {
return
}
// Handle /path/* routes.
// Note: the * is converted to a .* in register handler for proper regex handling
for route, handler := range ts.handlers {
// Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered
pattern, _ := regexp.Compile("^" + route + "$")
if pattern.MatchString(r.URL.Path) {
handler(w, r)
return
}
}
http.Error(w, "the resource path doesn't exist", http.StatusNotFound)
}))
}
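
A hedged usage sketch for the mock server (the route and response body are illustrative):

// exampleMockServer is a hypothetical test helper: register a route, start the
// server, and use its URL as a channel's BaseURL (e.g. via GetChannel in api.go).
func exampleMockServer() string {
    ts := NewTestServer()
    ts.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "application/json")
        w.Write([]byte(`{"id":"chatcmpl-123"}`)) // illustrative response body
    })

    server := ts.TestServer(OpenAICheck) // enforce the Bearer/api-key check
    server.Start()
    // callers should defer server.Close()
    return server.URL
}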

@@ -67,7 +67,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("provider not implemented")
}
return balanceProvider.Balance(channel)
return balanceProvider.Balance()
}

@@ -1,6 +1,7 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -38,30 +39,36 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
if provider == nil {
return errors.New("channel not implemented"), nil
}
newModelName, err := provider.ModelMappingHandler(request.Model)
if err != nil {
return err, nil
}
request.Model = newModelName
chatProvider, ok := provider.(providers_base.ChatInterface)
if !ok {
return errors.New("channel not implemented"), nil
}
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
return err, nil
}
if modelMap != nil && modelMap[request.Model] != "" {
request.Model = modelMap[request.Model]
}
chatProvider.SetUsage(&types.Usage{})
response, openAIErrorWithStatusCode := chatProvider.CreateChatCompletion(&request)
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
if openAIErrorWithStatusCode != nil {
return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError
}
if Usage.CompletionTokens == 0 {
usage := chatProvider.GetUsage()
if usage.CompletionTokens == 0 {
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
}
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String()))
// 转换为JSON字符串
jsonBytes, _ := json.Marshal(response)
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, string(jsonBytes)))
return nil, nil
}
@@ -74,9 +81,9 @@ func buildTestRequest() *types.ChatCompletionRequest {
Content: "You just need to output 'hi' next.",
},
},
Model: "",
MaxTokens: 1,
Stream: false,
Model: "",
// MaxTokens: 1,
Stream: false,
}
return testRequest
}

controller/quota.go (new file, 164 lines)

@@ -0,0 +1,164 @@
package controller
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
type QuotaInfo struct {
modelName string
promptTokens int
preConsumedTokens int
modelRatio float64
groupRatio float64
ratio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
quotaInfo := &QuotaInfo{
modelName: modelName,
promptTokens: promptTokens,
userId: c.GetInt("id"),
channelId: c.GetInt("channel_id"),
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quotaInfo.initQuotaInfo(c.GetString("group"))
errWithCode := quotaInfo.preQuotaConsumption()
if errWithCode != nil {
return nil, errWithCode
}
return quotaInfo, nil
}
func (q *QuotaInfo) initQuotaInfo(groupName string) {
modelRatio := common.GetModelRatio(q.modelName)
groupRatio := common.GetGroupRatio(groupName)
preConsumedTokens := common.PreConsumedQuota
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
q.preConsumedTokens = preConsumedTokens
q.modelRatio = modelRatio
q.groupRatio = groupRatio
q.ratio = ratio
q.preConsumedQuota = preConsumedQuota
}
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < q.preConsumedQuota {
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*q.preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
q.preConsumedQuota = 0
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if q.preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
q.HandelStatus = true
}
return nil
}
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
completionRatio := common.GetCompletionRatio(q.modelName)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
if q.ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - q.preConsumedQuota
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(q.userId)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
requestTime := 0
requestStartTimeValue := ctx.Value("requestStartTime")
if requestStartTimeValue != nil {
requestStartTime, ok := requestStartTimeValue.(time.Time)
if ok {
requestTime = int(time.Since(requestStartTime).Milliseconds())
}
}
logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f", q.modelRatio, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)
}
return nil
}
func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
tokenId := c.GetInt("token_id")
if q.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
}
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
tokenName := c.GetString("token_name")
// no error occurred, so consume the quota
go func(ctx context.Context) {
err := q.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
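
For reference, a worked example of the settlement math in completedQuotaConsumption (all numbers illustrative):

// With promptTokens = 100, completionTokens = 50, completionRatio = 2.0,
// and ratio = modelRatio * groupRatio = 1.5:
//
//     quota      = ceil((100 + 50*2.0) * 1.5) = 300
//     quotaDelta = quota - preConsumedQuota   // settled by PostConsumeTokenQuota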

@@ -1,11 +1,11 @@
package controller
import (
"context"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/model"
"one-api/common/requester"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,33 +20,18 @@ func RelayChat(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, chatRequest.Model)
if pass {
return
}
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[chatRequest.Model] != "" {
chatRequest.Model = modelMap[chatRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeChatCompletions)
if pass {
provider, modelName, fail := getProvider(c, chatRequest.Model)
if fail {
return
}
chatRequest.Model = modelName
chatProvider, ok := provider.(providersBase.ChatInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -56,39 +41,42 @@ func RelayChat(c *gin.Context) {
// count the input tokens
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
if chatRequest.Stream {
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
} else {
var response *types.ChatCompletionResponse
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
}
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,11 +1,9 @@
package controller
import (
"context"
"math"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,33 +18,18 @@ func RelayCompletions(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, completionRequest.Model)
if pass {
return
}
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[completionRequest.Model] != "" {
completionRequest.Model = modelMap[completionRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeCompletions)
if pass {
provider, modelName, fail := getProvider(c, completionRequest.Model)
if fail {
return
}
completionRequest.Model = modelName
completionProvider, ok := provider.(providersBase.CompletionInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -56,39 +39,38 @@ func RelayCompletions(c *gin.Context) {
// count the input tokens
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
if completionRequest.Stream {
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseStreamClient[types.CompletionResponse](c, response)
} else {
response, errWithCode := completionProvider.CreateCompletion(&completionRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
}
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"strings"
@@ -24,28 +22,13 @@ func RelayEmbeddings(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, embeddingsRequest.Model)
if pass {
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeEmbeddings)
if pass {
provider, modelName, fail := getProvider(c, embeddingsRequest.Model)
if fail {
return
}
embeddingsRequest.Model = modelName
embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -55,39 +38,29 @@ func RelayEmbeddings(c *gin.Context) {
// count the input tokens
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -33,28 +31,13 @@ func RelayImageEdits(c *gin.Context) {
imageEditRequest.Size = "1024x1024"
}
channel, pass := fetchChannel(c, imageEditRequest.Model)
if pass {
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeImagesEdits)
if pass {
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
if fail {
return
}
imageEditRequest.Model = modelName
imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -68,39 +51,29 @@ func RelayImageEdits(c *gin.Context) {
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -36,28 +34,13 @@ func RelayImageGenerations(c *gin.Context) {
imageRequest.Quality = "standard"
}
channel, pass := fetchChannel(c, imageRequest.Model)
if pass {
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations)
if pass {
provider, modelName, fail := getProvider(c, imageRequest.Model)
if fail {
return
}
imageRequest.Model = modelName
imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -71,39 +54,29 @@ func RelayImageGenerations(c *gin.Context) {
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -28,28 +26,13 @@ func RelayImageVariations(c *gin.Context) {
imageEditRequest.Size = "1024x1024"
}
channel, pass := fetchChannel(c, imageEditRequest.Model)
if pass {
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeImagesVariations)
if pass {
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
if fail {
return
}
imageEditRequest.Model = modelName
imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -63,39 +46,29 @@ func RelayImageVariations(c *gin.Context) {
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -24,28 +22,13 @@ func RelayModerations(c *gin.Context) {
moderationRequest.Model = "text-moderation-stable"
}
channel, pass := fetchChannel(c, moderationRequest.Model)
if pass {
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
moderationRequest.Model = modelMap[moderationRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeModerations)
if pass {
provider, modelName, fail := getProvider(c, moderationRequest.Model)
if fail {
return
}
moderationRequest.Model = modelName
moderationProvider, ok := provider.(providersBase.ModerationInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -55,39 +38,29 @@ func RelayModerations(c *gin.Context) {
// count the input tokens
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
response, errWithCode := moderationProvider.CreateModeration(&moderationRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,28 +18,13 @@ func RelaySpeech(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, speechRequest.Model)
if pass {
return
}
// parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[speechRequest.Model] != "" {
speechRequest.Model = modelMap[speechRequest.Model]
isModelMapped = true
}
// get the provider
provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech)
if pass {
provider, modelName, fail := getProvider(c, speechRequest.Model)
if fail {
return
}
speechRequest.Model = modelName
speechProvider, ok := provider.(providersBase.SpeechInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -51,39 +34,29 @@ func RelaySpeech(c *gin.Context) {
// count the input tokens
promptTokens := len(speechRequest.Input)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens)
response, errWithCode := speechProvider.CreateSpeech(&speechRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseMultipart(c, response)
// on error, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// no error, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@ -20,28 +18,13 @@ func RelayTranscriptions(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, audioRequest.Model)
if pass {
return
}
// Parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
// Get the provider
provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription)
if pass {
provider, modelName, fail := getProvider(c, audioRequest.Model)
if fail {
return
}
audioRequest.Model = modelName
transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@ -51,39 +34,29 @@ func RelayTranscriptions(c *gin.Context) {
// Get the input tokens
promptTokens := 0
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens)
response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseCustom(c, response)
// If an error occurred, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// If no error occurred, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@ -20,28 +18,13 @@ func RelayTranslations(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, audioRequest.Model)
if pass {
return
}
// Parse the model mapping
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
// Get the provider
provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation)
if pass {
provider, modelName, fail := getProvider(c, audioRequest.Model)
if fail {
return
}
audioRequest.Model = modelName
translationProvider, ok := provider.(providersBase.TranslationInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@ -51,39 +34,29 @@ func RelayTranslations(c *gin.Context) {
// Get the input tokens
promptTokens := 0
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens)
response, errWithCode := translationProvider.CreateTranslation(&audioRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseCustom(c, response)
// If an error occurred, refund the quota
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// If no error occurred, consume the quota
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@ -1,25 +1,47 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/types"
"reflect"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) {
channel, fail := fetchChannel(c, modeName)
if fail {
return
}
provider = providers.GetProvider(channel, c)
if provider == nil {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
fail = true
return
}
newModelName, err := provider.ModelMappingHandler(modeName)
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
fail = true
return
}
return
}
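ModelMappingHandler absorbs the per-controller parseModelMapping logic that is deleted further down. A self-contained sketch of what that mapping step does, assuming the channel still stores a JSON object such as {"gpt-3.5-turbo":"qwen-turbo"} (the handler's internals are not shown in this diff):
// Hypothetical helper mirroring the removed parseModelMapping.
func mapModel(mapping, model string) (string, error) {
    if mapping == "" || mapping == "{}" {
        return model, nil // no mapping configured
    }
    modelMap := make(map[string]string)
    if err := json.Unmarshal([]byte(mapping), &modelMap); err != nil {
        return "", err
    }
    if mapped := modelMap[model]; mapped != "" {
        return mapped, nil // e.g. "gpt-3.5-turbo" -> "qwen-turbo"
    }
    return model, nil
}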
func GetValidFieldName(err error, obj interface{}) string {
getObj := reflect.TypeOf(obj)
if errs, ok := err.(validator.ValidationErrors); ok {
@ -32,17 +54,17 @@ func GetValidFieldName(err error, obj interface{}) string {
return err.Error()
}
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) {
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) {
channelId, ok := c.Get("channelId")
if ok {
channel, pass = fetchChannelById(c, channelId.(int))
if pass {
channel, fail = fetchChannelById(c, channelId.(int))
if fail {
return
}
}
channel, pass = fetchChannelByModel(c, modelName)
if pass {
channel, fail = fetchChannelByModel(c, modelName)
if fail {
return
}
@ -86,21 +108,6 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool
return channel, false
}
func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) {
provider := providers.GetProvider(channel, c)
if provider == nil {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
return nil, true
}
if !provider.SupportAPI(relayMode) {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API")
return nil, true
}
return provider, false
}
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
@ -130,138 +137,81 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
return true
}
func parseModelMapping(modelMapping string) (map[string]string, error) {
if modelMapping == "" || modelMapping == "{}" {
return nil, nil
}
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
// Marshal data to JSON
responseBody, err := json.Marshal(data)
if err != nil {
return nil, err
}
return modelMap, nil
}
type QuotaInfo struct {
modelName string
promptTokens int
preConsumedTokens int
modelRatio float64
groupRatio float64
ratio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
quotaInfo := &QuotaInfo{
modelName: modelName,
promptTokens: promptTokens,
userId: c.GetInt("id"),
channelId: c.GetInt("channel_id"),
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quotaInfo.initQuotaInfo(c.GetString("group"))
errWithCode := quotaInfo.preQuotaConsumption()
if errWithCode != nil {
return nil, errWithCode
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
return quotaInfo, nil
}
func (q *QuotaInfo) initQuotaInfo(groupName string) {
modelRatio := common.GetModelRatio(q.modelName)
groupRatio := common.GetGroupRatio(groupName)
preConsumedTokens := common.PreConsumedQuota
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
q.preConsumedTokens = preConsumedTokens
q.modelRatio = modelRatio
q.groupRatio = groupRatio
q.ratio = ratio
q.preConsumedQuota = preConsumedQuota
return
}
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(responseBody)
if err != nil {
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < q.preConsumedQuota {
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*q.preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
q.preConsumedQuota = 0
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if q.preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
q.HandelStatus = true
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
}
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
completionRatio := common.GetCompletionRatio(q.modelName)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
if q.ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - q.preConsumedQuota
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(q.userId)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
requestTime := 0
requestStartTimeValue := ctx.Value("requestStartTime")
if requestStartTimeValue != nil {
requestStartTime, ok := requestStartTimeValue.(time.Time)
if ok {
requestTime = int(time.Since(requestStartTime).Milliseconds())
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode {
requester.SetEventStreamHeaders(c)
defer stream.Close()
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
if response != nil && len(*response) > 0 {
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
}
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)
if err != nil {
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
}
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
}
return nil
}
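responseStreamClient re-frames every chunk as a server-sent event and terminates the stream with a DONE marker. For reference, the emitted byte stream looks like this (illustrative payloads, not captured output):
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hi"}}]}

data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{},"finish_reason":"stop"}]}

data: [DONE]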
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
defer resp.Body.Close()
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
}
func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types.OpenAIErrorWithStatusCode {
for k, v := range response.Headers {
c.Writer.Header().Set(k, v)
}
c.Writer.WriteHeader(http.StatusOK)
_, err := c.Writer.Write(response.Body)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@ -2,29 +2,26 @@ package aigc2d
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
func (p *Aigc2dProvider) Balance() (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
p.Channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}

View File

@ -1,20 +1,19 @@
package aigc2d
import (
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Aigc2dProviderFactory struct{}
func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Aigc2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"),
}
}
type Aigc2dProvider struct {
*openai.OpenAIProvider
}
func (f Aigc2dProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &Aigc2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.aigc2d.com"),
}
}
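Factories now take the *model.Channel instead of a gin.Context, so a provider can be constructed outside a request, e.g. for balance polling. A hypothetical construction (providers.GetProvider presumably wraps a registry of these factories; the type assertion is illustrative):
factory := Aigc2dProviderFactory{}
provider := factory.Create(channel) // no gin.Context required any more
balance, err := provider.(*Aigc2dProvider).Balance()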

View File

@ -3,24 +3,21 @@ package aiproxy
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
)
func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
func (p *AIProxyProvider) Balance() (float64, error) {
fullRequestURL := "https://aiproxy.io/api/report/getUserOverview"
headers := make(map[string]string)
headers["Api-Key"] = channel.Key
headers["Api-Key"] = p.Channel.Key
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response AIProxyUserOverviewResponse
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
@ -29,7 +26,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
}
channel.UpdateBalance(response.Data.TotalPoints)
p.Channel.UpdateBalance(response.Data.TotalPoints)
return response.Data.TotalPoints, nil
}

View File

@ -1,20 +1,19 @@
package aiproxy
import (
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type AIProxyProviderFactory struct{}
func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &AIProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"),
}
}
type AIProxyProvider struct {
*openai.OpenAIProvider
}
func (f AIProxyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &AIProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.aiproxy.io"),
}
}

24
providers/ali/ali_test.go Normal file
View File

@ -0,0 +1,24 @@
package ali_test
import (
"net/http"
"one-api/common"
"one-api/common/test"
"one-api/model"
)
func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown func()) {
server = test.NewTestServer()
ts := server.TestServer(func(w http.ResponseWriter, r *http.Request) bool {
return test.OpenAICheck(w, r)
})
ts.Start()
teardown = ts.Close
baseUrl = ts.URL
return
}
func getAliChannel(baseUrl string) model.Channel {
return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "")
}

View File

@ -1,32 +1,66 @@
package ali
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"github.com/gin-gonic/gin"
"one-api/types"
)
// Define the provider factory
type AliProviderFactory struct{}
type AliProvider struct {
base.BaseProvider
}
// Create AliProvider
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f AliProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &AliProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
},
}
}
type AliProvider struct {
base.BaseProvider
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
}
}
// Request error handling
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
aliError := &AliError{}
err := json.NewDecoder(resp.Body).Decode(aliError)
if err != nil {
return nil
}
return errorHandle(aliError)
}
// Error handling
func errorHandle(aliError *AliError) *types.OpenAIError {
if aliError.Code == "" {
return nil
}
return &types.OpenAIError{
Message: aliError.Message,
Type: aliError.Code,
Param: aliError.RequestId,
Code: aliError.Code,
}
}
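errorHandle is now the single place DashScope errors become OpenAI-style errors. A quick illustration using the error body from the tests later in this commit (a sketch; it assumes only the AliError fields visible in this diff):
body := `{"code":"InvalidParameter","message":"Role must be user or assistant","request_id":"4883ee8d"}`
aliErr := &AliError{}
_ = json.Unmarshal([]byte(body), aliErr)
oaiErr := errorHandle(aliErr)
// oaiErr.Type and oaiErr.Code are both "InvalidParameter"; the request id travels in Param.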
func (p *AliProvider) GetFullRequestURL(requestURL string, modelName string) string {

View File

@ -1,51 +1,116 @@
package ali
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
// Aliyun response handler
func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}
return
}
OpenAIResponse = types.ChatCompletionResponse{
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: aliResponse.Model,
Choices: aliResponse.Output.ToChatCompletionChoices(),
Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
},
}
return
type aliStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
lastStreamResponse string
}
const AliEnableSearchModelSuffix = "-internet"
// Get the chat request body
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
aliResponse := &AliChatResponse{}
// Send the request
_, errWithCode = p.Requester.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToChatOpenai(aliResponse, request)
}
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// Send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &aliStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
// Get the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
// Get the request headers
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
headers["X-DashScope-SSE"] = "enable"
}
aliRequest := convertFromChatOpenai(request)
// Create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
// Convert the Aliyun response to an OpenAI chat response
func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(&response.AliError)
if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *error,
StatusCode: http.StatusBadRequest,
}
return
}
openaiResponse = &types.ChatCompletionResponse{
ID: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: request.Model,
Choices: response.Output.ToChatCompletionChoices(),
Usage: &types.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
*p.Usage = *openaiResponse.Usage
return
}
// Build the Aliyun chat request body
func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@ -96,163 +161,68 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
}
}
// Chat
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
headers["X-DashScope-SSE"] = "enable"
// Convert to OpenAI streaming chat responses
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// If rawLine does not start with "data:", skip it
if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil
return nil
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
// Strip the "data:" prefix
*rawLine = (*rawLine)[5:]
var aliResponse AliChatResponse
err := json.Unmarshal(*rawLine, &aliResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
return common.ErrorToOpenAIError(err)
}
if request.Stream {
usage, errWithCode = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
if usage == nil {
usage = &types.Usage{
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
}
} else {
aliResponse := &AliChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
}
error := errorHandle(&aliResponse.AliError)
if error != nil {
return error
}
return
return h.convertToOpenaiStream(&aliResponse, response)
}
// Convert the Aliyun response to an OpenAI response
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
// chatChoice := aliResponse.Output.ToChatCompletionChoices()
// jsonBody, _ := json.MarshalIndent(chatChoice, "", " ")
// fmt.Println("requestBody:", string(jsonBody))
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error {
content := aliResponse.Output.Choices[0].Message.StringContent()
var choice types.ChatCompletionStreamChoice
choice.Index = aliResponse.Output.Choices[0].Index
choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent()
// fmt.Println("choice.Delta.Content:", chatChoice[0].Message)
if aliResponse.Output.Choices[0].FinishReason != "null" {
finishReason := aliResponse.Output.Choices[0].FinishReason
choice.FinishReason = &finishReason
choice.Delta.Content = strings.TrimPrefix(content, h.lastStreamResponse)
if aliResponse.Output.Choices[0].FinishReason != "" {
if aliResponse.Output.Choices[0].FinishReason != "null" {
finishReason := aliResponse.Output.Choices[0].FinishReason
choice.FinishReason = &finishReason
}
}
response := types.ChatCompletionStreamResponse{
if aliResponse.Output.FinishReason != "" {
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
}
h.lastStreamResponse = content
streamResponse := types.ChatCompletionStreamResponse{
ID: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: aliResponse.Model,
Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}
// Send the streaming request
func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
usage = &types.Usage{}
// Send the request
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)
if aliResponse.Usage.OutputTokens != 0 {
h.Usage.PromptTokens = aliResponse.Usage.InputTokens
h.Usage.CompletionTokens = aliResponse.Usage.OutputTokens
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
defer resp.Body.Close()
*response = append(*response, streamResponse)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
lastResponseText := ""
index := 0
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
aliResponse.Model = model
aliResponse.Output.Choices[0].Index = index
index++
response := p.streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Choices[0].Message.StringContent()
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return
return nil
}
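The lastStreamResponse bookkeeping exists because DashScope streams cumulative content: each chunk carries the full text generated so far, and strings.TrimPrefix recovers just the new suffix to emit as the OpenAI delta. A self-contained demonstration:
package main

import (
    "fmt"
    "strings"
)

func main() {
    // Chunks as DashScope emits them: each is the complete text so far.
    chunks := []string{"你好!", "你好!有什么我可以帮助你的吗?"}
    last := ""
    for _, c := range chunks {
        fmt.Printf("delta: %q\n", strings.TrimPrefix(c, last)) // only the new suffix
        last = c
    }
}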

330
providers/ali/chat_test.go Normal file
View File

@ -0,0 +1,330 @@
package ali_test
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/test"
_ "one-api/common/test/init"
"one-api/providers"
providers_base "one-api/providers/base"
"one-api/types"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func getChatProvider(url string, context *gin.Context) providers_base.ChatInterface {
channel := getAliChannel(url)
provider := providers.GetProvider(&channel, context)
chatProvider, _ := provider.(providers_base.ChatInterface)
return chatProvider
}
func TestChatCompletions(t *testing.T) {
url, server, teardown := setupAliTestServer()
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
defer teardown()
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionEndpoint)
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
chatProvider := getChatProvider(url, context)
usage := &types.Usage{}
chatProvider.SetUsage(usage)
response, errWithCode := chatProvider.CreateChatCompletion(chatRequest)
assert.Nil(t, errWithCode)
assert.IsType(t, &types.Usage{}, usage)
assert.Equal(t, 33, usage.TotalTokens)
assert.Equal(t, 14, usage.PromptTokens)
assert.Equal(t, 19, usage.CompletionTokens)
// Marshal to a JSON string
responseBody, err := json.Marshal(response)
if err != nil {
fmt.Println(err)
assert.Fail(t, "json marshal error")
}
fmt.Println(string(responseBody))
test.CheckChat(t, response, "qwen-turbo", usage)
}
func TestChatCompletionsError(t *testing.T) {
url, server, teardown := setupAliTestServer()
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
defer teardown()
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionErrorEndpoint)
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
chatProvider := getChatProvider(url, context)
_, err := chatProvider.CreateChatCompletion(chatRequest)
usage := chatProvider.GetUsage()
assert.NotNil(t, err)
assert.Nil(t, usage)
assert.Equal(t, "InvalidParameter", err.Code)
}
// func TestChatCompletionsStream(t *testing.T) {
// url, server, teardown := setupAliTestServer()
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
// defer teardown()
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamEndpoint)
// channel := getAliChannel(url)
// provider := providers.GetProvider(&channel, context)
// chatProvider, _ := provider.(providers_base.ChatInterface)
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
// usage := &types.Usage{}
// chatProvider.SetUsage(usage)
// response, errWithCode := chatProvider.CreateChatCompletionStream(chatRequest)
// assert.Nil(t, errWithCode)
// assert.IsType(t, &types.Usage{}, usage)
// assert.Equal(t, 16, usage.TotalTokens)
// assert.Equal(t, 8, usage.PromptTokens)
// assert.Equal(t, 8, usage.CompletionTokens)
// streamResponseCheck(t, w.Body.String())
// }
// func TestChatCompletionsStreamError(t *testing.T) {
// url, server, teardown := setupAliTestServer()
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
// defer teardown()
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamErrorEndpoint)
// channel := getAliChannel(url)
// provider := providers.GetProvider(&channel, context)
// chatProvider, _ := provider.(providers_base.ChatInterface)
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
// usage, err := chatProvider.ChatAction(chatRequest, 0)
// // Print what was written to the context
// fmt.Println(w.Body.String())
// assert.NotNil(t, err)
// assert.Nil(t, usage)
// }
// func TestChatImageCompletions(t *testing.T) {
// url, server, teardown := setupAliTestServer()
// context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
// defer teardown()
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionEndpoint)
// channel := getAliChannel(url)
// provider := providers.GetProvider(&channel, context)
// chatProvider, _ := provider.(providers_base.ChatInterface)
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "false")
// usage, err := chatProvider.ChatAction(chatRequest, 0)
// assert.Nil(t, err)
// assert.IsType(t, &types.Usage{}, usage)
// assert.Equal(t, 1306, usage.TotalTokens)
// assert.Equal(t, 1279, usage.PromptTokens)
// assert.Equal(t, 27, usage.CompletionTokens)
// }
// func TestChatImageCompletionsStream(t *testing.T) {
// url, server, teardown := setupAliTestServer()
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
// defer teardown()
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionStreamEndpoint)
// channel := getAliChannel(url)
// provider := providers.GetProvider(&channel, context)
// chatProvider, _ := provider.(providers_base.ChatInterface)
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "true")
// usage, err := chatProvider.ChatAction(chatRequest, 0)
// fmt.Println(w.Body.String())
// assert.Nil(t, err)
// assert.IsType(t, &types.Usage{}, usage)
// assert.Equal(t, 1342, usage.TotalTokens)
// assert.Equal(t, 1279, usage.PromptTokens)
// assert.Equal(t, 63, usage.CompletionTokens)
// streamResponseCheck(t, w.Body.String())
// }
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
response := `{"output":{"choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"您好!我可以帮您查询最近的公园,请问您现在所在的位置是哪里呢?"}}]},"usage":{"total_tokens":33,"output_tokens":19,"input_tokens":14},"request_id":"2479f818-9717-9b0b-9769-0d26e873a3f6"}`
w.Header().Set("Content-Type", "application/json")
fmt.Fprintln(w, response)
}
func handleChatCompletionErrorEndpoint(w http.ResponseWriter, r *http.Request) {
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
response := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"4883ee8d-f095-94ff-a94a-5ce0a94bc81f"}`
w.Header().Set("Content-Type", "application/json")
fmt.Fprintln(w, response)
}
func handleChatCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// Check that the X-DashScope-SSE: enable header is present
if r.Header.Get("X-DashScope-SSE") != "enable" {
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
}
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("id:1\n")...)
dataBytes = append(dataBytes, []byte("event:result\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
//nolint:lll
data := `{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":10,"input_tokens":8,"output_tokens":2},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("id:2\n")...)
dataBytes = append(dataBytes, []byte("event:result\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
//nolint:lll
data = `{"output":{"choices":[{"message":{"content":"有什么我可以帮助你的吗?","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("id:3\n")...)
dataBytes = append(dataBytes, []byte("event:result\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
//nolint:lll
data = `{"output":{"choices":[{"message":{"content":"","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
_, err := w.Write(dataBytes)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleChatCompletionStreamErrorEndpoint(w http.ResponseWriter, r *http.Request) {
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// Check that the X-DashScope-SSE: enable header is present
if r.Header.Get("X-DashScope-SSE") != "enable" {
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
}
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("id:1\n")...)
dataBytes = append(dataBytes, []byte("event:error\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/400\n")...)
//nolint:lll
data := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"6b932ba9-41bd-9ad3-b430-24bc1e125880"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
_, err := w.Write(dataBytes)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleChatImageCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
response := `{"output":{"finish_reason":"stop","choices":[{"message":{"role":"assistant","content":[{"text":"这张照片展示的是一个海滩的场景,但是并没有明确指出具体的位置。可以看到海浪和日落背景下的沙滩景色。"}]}}]},"usage":{"output_tokens":27,"input_tokens":1279,"image_tokens":1247},"request_id":"a360d53b-b993-927f-9a68-bef6b2b2042e"}`
w.Header().Set("Content-Type", "application/json")
fmt.Fprintln(w, response)
}
func handleChatImageCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// Check that the X-DashScope-SSE: enable header is present
if r.Header.Get("X-DashScope-SSE") != "enable" {
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
}
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("id:1\n")...)
dataBytes = append(dataBytes, []byte("event:result\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
//nolint:lll
data := `{"output":{"choices":[{"message":{"content":[{"text":"这张"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":1,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("id:2\n")...)
dataBytes = append(dataBytes, []byte("event:result\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
//nolint:lll
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":2,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("id:3\n")...)
dataBytes = append(dataBytes, []byte("event:result\n")...)
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
//nolint:lll
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片展示的是一个海滩的场景,具体来说是在日落时分。由于没有明显的地标或建筑物等特征可以辨认出具体的地点信息,所以无法确定这是哪个地方的海滩。但是根据图像中的元素和环境特点,我们可以推测这可能是一个位于沿海地区的沙滩海岸线。"}],"role":"assistant"}}],"finish_reason":"stop"},"usage":{"input_tokens":1279,"output_tokens":63,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
_, err := w.Write(dataBytes)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func streamResponseCheck(t *testing.T, response string) {
// Split the response on blank lines ("\n\n")
lines := strings.Split(response, "\n\n")
// If the last element is empty, drop it
if lines[len(lines)-1] == "" {
lines = lines[:len(lines)-1]
}
// Iterate over each line
for _, line := range lines {
if line == "" {
continue
}
// Assert that the line starts with "data: "
assert.True(t, strings.HasPrefix(line, "data: "))
}
// Check that the last line ends with "data: [DONE]"
assert.True(t, strings.HasSuffix(lines[len(lines)-1], "data: [DONE]"))
// Check that the second-to-last line contains `"finish_reason":"stop"`
assert.True(t, strings.Contains(lines[len(lines)-2], `"finish_reason":"stop"`))
}

View File

@ -6,40 +6,37 @@ import (
"one-api/types"
)
// Embedding request handler
func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}
func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
// Get the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
// Get the request headers
headers := p.GetRequestHeaders()
aliRequest := convertFromEmbeddingOpenai(request)
// Create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
defer req.Body.Close()
aliResponse := &AliEmbeddingResponse{}
// Send the request
_, errWithCode = p.Requester.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
openAIEmbeddingResponse := &types.EmbeddingResponse{
Object: "list",
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
}
for _, item := range aliResponse.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return openAIEmbeddingResponse, nil
return p.convertToEmbeddingOpenai(aliResponse, request)
}
// Get the embeddings request body
func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
@ -50,24 +47,36 @@ func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest)
}
}
func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
aliEmbeddingResponse := &AliEmbeddingResponse{}
errWithCode = p.SendRequest(req, aliEmbeddingResponse, false)
if errWithCode != nil {
func (p *AliProvider) convertToEmbeddingOpenai(response *AliEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(&response.AliError)
if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *error,
StatusCode: http.StatusBadRequest,
}
return
}
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
return usage, nil
openaiResponse = &types.EmbeddingResponse{
Object: "list",
Data: make([]types.Embedding, 0, len(response.Output.Embeddings)),
Model: request.Model,
Usage: &types.Usage{
PromptTokens: response.Usage.TotalTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
for _, item := range response.Output.Embeddings {
openaiResponse.Data = append(openaiResponse.Data, types.Embedding{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
*p.Usage = *openaiResponse.Usage
return
}

View File

@ -52,7 +52,8 @@ type AliChoice struct {
}
type AliOutput struct {
Choices []types.ChatCompletionChoice `json:"choices"`
Choices []types.ChatCompletionChoice `json:"choices"`
FinishReason string `json:"finish_reason,omitempty"`
}
func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
@ -70,7 +71,6 @@ func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
Model string `json:"model,omitempty"`
AliError
}

View File

@ -2,7 +2,6 @@ package api2d
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
@ -11,15 +10,14 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -1,18 +1,17 @@
package api2d
import (
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Api2dProviderFactory struct{}
// Create Api2dProvider
func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f Api2dProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &Api2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://oa.api2d.net"),
}
}

View File

@ -2,7 +2,6 @@ package api2gpt
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
@ -11,15 +10,14 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -1,17 +1,16 @@
package api2gpt
import (
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Api2gptProviderFactory struct{}
func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f Api2gptProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &Api2gptProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"),
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.api2gpt.com"),
}
}

View File

@ -1,30 +1,23 @@
package azure
import (
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type AzureProviderFactory struct{}
// Create AzureProvider
func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f AzureProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
config := getAzureConfig()
return &AzureProvider{
OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{
BaseURL: "",
Completions: "/completions",
ChatCompletions: "/chat/completions",
Embeddings: "/embeddings",
AudioTranscriptions: "/audio/transcriptions",
AudioTranslations: "/audio/translations",
ImagesGenerations: "/images/generations",
// ImagesEdit: "/images/edit",
// ImagesVariations: "/images/variations",
Context: c,
// AudioSpeech: "/audio/speech",
Config: config,
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, openai.RequestErrorHandle),
},
IsAzure: true,
BalanceAction: false,
@ -32,6 +25,18 @@ func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface {
}
}
func getAzureConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "",
Completions: "/completions",
ChatCompletions: "/chat/completions",
Embeddings: "/embeddings",
AudioTranscriptions: "/audio/transcriptions",
AudioTranslations: "/audio/translations",
ImagesGenerations: "/images/generations",
}
}
type AzureProvider struct {
openai.OpenAIProvider
}

View File

@ -10,13 +10,60 @@ import (
"time"
)
func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Status == "canceled" || c.Status == "failed" {
func (p *AzureProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
if !openai.IsWithinRange(request.Model, request.N) {
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
}
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
var response *types.ImageResponse
var resp *http.Response
if request.Model == "dall-e-2" {
imageAzureResponse := &ImageAzureResponse{}
resp, errWithCode = p.Requester.SendRequest(req, imageAzureResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
response, errWithCode = p.ResponseAzureImageHandler(resp, imageAzureResponse)
if errWithCode != nil {
return nil, errWithCode
}
} else {
var openaiResponse openai.OpenAIProviderImageResponse
_, errWithCode = p.Requester.SendRequest(req, &openaiResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
// Check for an error
openaiErr := openai.ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
response = &openaiResponse.ImageResponse
}
p.Usage.TotalTokens = p.Usage.PromptTokens
return response, nil
}
func (p *AzureProvider) ResponseAzureImageHandler(resp *http.Response, azure *ImageAzureResponse) (OpenAIResponse *types.ImageResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
if azure.Status == "canceled" || azure.Status == "failed" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: c.Error.Message,
Message: azure.Error.Message,
Type: "one_api_error",
Code: c.Error.Code,
Code: azure.Error.Code,
},
StatusCode: resp.StatusCode,
}
@ -28,8 +75,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
return nil, common.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError)
}
client := common.NewClient()
req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header))
req, err := p.Requester.NewRequest("GET", operation_location, p.Requester.WithHeader(p.GetRequestHeaders()))
if err != nil {
return nil, common.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError)
}
@ -38,7 +84,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
for i := 0; i < 3; i++ {
// Sleep for 2 seconds
time.Sleep(2 * time.Second)
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy)
_, errWithCode = p.Requester.SendRequest(req, &getImageAzureResponse, false)
fmt.Println("getImageAzureResponse", getImageAzureResponse)
if errWithCode != nil {
return
@ -47,57 +93,17 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: c.Error.Message,
Message: getImageAzureResponse.Error.Message,
Type: "get_images_request_failed",
Code: c.Error.Code,
Code: getImageAzureResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}
}
if getImageAzureResponse.Status == "succeeded" {
return getImageAzureResponse.Result, nil
return &getImageAzureResponse.Result, nil
}
}
return nil, common.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError)
}
func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Model == "dall-e-2" {
imageAzureResponse := &ImageAzureResponse{
Header: headers,
Proxy: p.Channel.Proxy,
}
errWithCode = p.SendRequest(req, imageAzureResponse, false)
} else {
openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
}
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
}
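The dall-e-2 path above does not get the image inline; it polls Azure's operation-location URL until the job settles. A generic sketch of that poll-until-done loop (the fetch callback is hypothetical; the real handler decodes ImageAzureResponse and needs the time and errors imports):
// Sketch: retry three times, two seconds apart, as the handler above does.
func pollUntilDone(fetch func() (status string, err error)) error {
    for i := 0; i < 3; i++ {
        time.Sleep(2 * time.Second)
        status, err := fetch()
        if err != nil {
            return err
        }
        switch status {
        case "succeeded":
            return nil
        case "canceled", "failed":
            return errors.New("image generation " + status)
        }
    }
    return errors.New("get image timeout")
}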

View File

@ -1,21 +1,23 @@
package azureSpeech
import (
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
// Define the provider factory
type AzureSpeechProviderFactory struct{}
// Create AzureSpeechProvider
func (f AzureSpeechProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f AzureSpeechProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &AzureSpeechProvider{
BaseProvider: base.BaseProvider{
BaseURL: "",
AudioSpeech: "/cognitiveservices/v1",
Context: c,
Config: base.ProviderConfig{
AudioSpeech: "/cognitiveservices/v1",
},
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, nil),
},
}
}
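The factory change above is the heart of this refactor: Create now takes a *model.Channel instead of a *gin.Context, so a provider can be built, with its per-channel proxy wired into the requester, before any HTTP request is in flight. A minimal sketch of that construction pattern; Channel, ProviderConfig and HTTPRequester here are simplified stand-ins for the real types:

package main

import "fmt"

// Simplified stand-ins for model.Channel, base.ProviderConfig and
// requester.HTTPRequester; the real types carry far more state.
type Channel struct{ Proxy string }

type ProviderConfig struct{ AudioSpeech string }

type HTTPRequester struct{ proxy string }

func NewHTTPRequester(proxy string) *HTTPRequester { return &HTTPRequester{proxy: proxy} }

type Provider struct {
	Config    ProviderConfig
	Channel   *Channel
	Requester *HTTPRequester
}

// Factory mirrors the new Create(channel) shape: everything the provider
// needs is resolved from the channel at construction time, so no gin
// context is required until a request is actually relayed.
type Factory struct{}

func (Factory) Create(channel *Channel) *Provider {
	return &Provider{
		Config:    ProviderConfig{AudioSpeech: "/cognitiveservices/v1"},
		Channel:   channel,
		Requester: NewHTTPRequester(channel.Proxy),
	}
}

func main() {
	p := Factory{}.Create(&Channel{Proxy: "socks5://127.0.0.1:1080"})
	fmt.Println(p.Config.AudioSpeech, p.Requester.proxy)
}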

View File

@ -55,9 +55,12 @@ func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest)
}
func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeAudioSpeech)
if errWithCode != nil {
return nil, errWithCode
}
fullRequestURL := p.GetFullRequestURL(url, request.Model)
headers := p.GetRequestHeaders()
responseFormatr := outputFormatMap[request.ResponseFormat]
if responseFormatr == "" {
@ -67,22 +70,19 @@ func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, is
requestBody := p.getRequestBody(request)
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(requestBody), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
defer req.Body.Close()
errWithCode = p.SendRequestRaw(req)
var resp *http.Response
resp, errWithCode = p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return
return nil, errWithCode
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
p.Usage.TotalTokens = p.Usage.PromptTokens
return
return resp, nil
}

View File

@ -1,10 +1,10 @@
package baichuan
import (
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
// Provider factory
@ -12,19 +12,26 @@ type BaichuanProviderFactory struct{}
// Create BaichuanProvider
// https://platform.baichuan-ai.com/docs/api
func (f BaichuanProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f BaichuanProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &BaichuanProvider{
OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://api.baichuan-ai.com",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, openai.RequestErrorHandle),
},
},
}
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.baichuan-ai.com",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
}
}
type BaichuanProvider struct {
openai.OpenAIProvider
}
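BaichuanProvider gets most of its behavior by embedding openai.OpenAIProvider and overriding only the endpoints that differ. A small sketch of how Go's struct embedding makes that work; the types are toys, not the real providers:

package main

import "fmt"

type OpenAIProvider struct{}

func (OpenAIProvider) CreateEmbeddings() string     { return "openai embeddings" }
func (OpenAIProvider) CreateChatCompletion() string { return "openai chat" }

// BaichuanProvider reuses the OpenAI implementation via embedding and
// overrides only the chat path, mirroring the pattern above.
type BaichuanProvider struct{ OpenAIProvider }

func (BaichuanProvider) CreateChatCompletion() string { return "baichuan chat" }

func main() {
	p := BaichuanProvider{}
	fmt.Println(p.CreateChatCompletion()) // overridden
	fmt.Println(p.CreateEmbeddings())     // promoted from the embedded provider
}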

View File

@ -3,31 +3,61 @@ package baichuan
import (
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/providers/openai"
"one-api/types"
"strings"
)
func (baichuanResponse *BaichuanChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if baichuanResponse.Error.Message != "" {
func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, requestBody)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
response := &BaichuanChatResponse{}
// send the request
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
// check whether the response is an error
openaiErr := openai.ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: baichuanResponse.Error,
StatusCode: resp.StatusCode,
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return
return nil, errWithCode
}
OpenAIResponse = types.ChatCompletionResponse{
ID: baichuanResponse.ID,
Object: baichuanResponse.Object,
Created: baichuanResponse.Created,
Model: baichuanResponse.Model,
Choices: baichuanResponse.Choices,
Usage: baichuanResponse.Usage,
*p.Usage = *response.Usage
return &response.ChatCompletionResponse, nil
}
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
return
chatHandler := openai.OpenAIStreamHandler{
Usage: p.Usage,
ModelName: request.Model,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
}
// build the chat request body
@ -55,46 +85,3 @@ func (p *BaichuanProvider) getChatRequestBody(request *types.ChatCompletionReque
TopK: request.N,
}
}
// 聊天
func (p *BaichuanProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
openAIProviderChatStreamResponse := &openai.OpenAIProviderChatStreamResponse{}
var textResponse string
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(textResponse, request.Model),
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
}
} else {
baichuanResponse := &BaichuanChatResponse{}
errWithCode = p.SendRequest(req, baichuanResponse, false)
if errWithCode != nil {
return
}
usage = baichuanResponse.Usage
}
return
}

View File

@ -4,35 +4,65 @@ import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// Provider factory
type BaiduProviderFactory struct{}
// Create BaiduProvider
var baiduTokenStore sync.Map
func (f BaiduProviderFactory) Create(c *gin.Context) base.ProviderInterface {
// Create BaiduProvider
type BaiduProvider struct {
base.BaseProvider
}
func (f BaiduProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &BaiduProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://aip.baidubce.com",
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
},
}
}
var baiduTokenStore sync.Map
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://aip.baidubce.com",
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
}
}
type BaiduProvider struct {
base.BaseProvider
// request error handling
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
baiduError := &BaiduError{}
err := json.NewDecoder(resp.Body).Decode(baiduError)
if err != nil {
return nil
}
return errorHandle(baiduError)
}
// error handling
func errorHandle(baiduError *BaiduError) *types.OpenAIError {
if baiduError.ErrorMsg == "" {
return nil
}
return &types.OpenAIError{
Message: baiduError.ErrorMsg,
Type: "baidu_error",
Code: baiduError.ErrorCode,
}
}
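The Baidu provider now splits error handling in two: requestErrorHandle decodes the vendor payload off the HTTP response, and errorHandle maps a non-empty vendor error onto the OpenAI error shape, returning nil when there is nothing to report. A self-contained sketch of that normalization, with simplified stand-ins for BaiduError and types.OpenAIError:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
)

// Simplified stand-ins for BaiduError and types.OpenAIError.
type BaiduError struct {
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}

type OpenAIError struct {
	Message string
	Type    string
	Code    any
}

// decodeVendorError follows the requestErrorHandle shape: decode the
// body, and return nil (no recognizable vendor error) on any failure.
func decodeVendorError(body io.Reader) *OpenAIError {
	e := &BaiduError{}
	if err := json.NewDecoder(body).Decode(e); err != nil {
		return nil
	}
	if e.ErrorMsg == "" {
		return nil
	}
	return &OpenAIError{Message: e.ErrorMsg, Type: "baidu_error", Code: e.ErrorCode}
}

func main() {
	body := bytes.NewBufferString(`{"error_code":110,"error_msg":"Access token invalid"}`)
	fmt.Printf("%+v\n", decodeVendorError(body))
}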
// build the full request URL
@ -92,32 +122,21 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo
return nil, errors.New("invalid baidu apikey")
}
client := common.NewClient()
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
url := fmt.Sprintf(p.Config.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
var headers = map[string]string{
"Content-Type": "application/json",
"Accept": "application/json",
}
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
req, err := p.Requester.NewRequest("POST", url, p.Requester.WithHeader(headers))
if err != nil {
return nil, err
}
httpClient := common.GetHttpClient(p.Channel.Proxy)
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
common.PutHttpClient(httpClient)
defer resp.Body.Close()
var accessToken BaiduAccessToken
err = json.NewDecoder(resp.Body).Decode(&accessToken)
if err != nil {
return nil, err
_, errWithCode := p.Requester.SendRequest(req, &accessToken, false)
if errWithCode != nil {
return nil, errors.New(errWithCode.OpenAIError.Message)
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
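getBaiduAccessTokenHelper feeds baiduTokenStore, the sync.Map declared above; the expiry handling is not shown in this hunk, so the sketch below only assumes a token is cached per API key and refreshed once it lapses:

package main

import (
	"fmt"
	"sync"
	"time"
)

// token mirrors a cached access token: a value plus its expiry.
type token struct {
	Value     string
	ExpiresAt time.Time
}

var tokenStore sync.Map // apiKey -> token

// getToken returns a cached token while it is still valid and refreshes
// it otherwise, the same shape as baiduTokenStore in the provider.
func getToken(apiKey string, fetch func() (token, error)) (token, error) {
	if v, ok := tokenStore.Load(apiKey); ok {
		t := v.(token)
		if time.Now().Before(t.ExpiresAt) {
			return t, nil
		}
	}
	t, err := fetch()
	if err != nil {
		return token{}, err
	}
	tokenStore.Store(apiKey, t)
	return t, nil
}

func main() {
	fetches := 0
	fetch := func() (token, error) {
		fetches++
		return token{Value: "tok", ExpiresAt: time.Now().Add(time.Hour)}, nil
	}
	getToken("id|secret", fetch)
	getToken("id|secret", fetch)
	fmt.Println("fetches:", fetches) // 1: the second call hits the cache
}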

View File

@ -1,79 +1,155 @@
package baidu
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if baiduResponse.ErrorMsg != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
type baiduStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getBaiduChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
baiduResponse := &BaiduChatResponse{}
// send the request
_, errWithCode = p.Requester.SendRequest(req, baiduResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToChatOpenai(baiduResponse, request)
}
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getBaiduChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &baiduStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
// build the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
}
// build the request headers
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
baiduRequest := convertFromChatOpenai(request)
// create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(baiduRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func (p *BaiduProvider) convertToChatOpenai(response *BaiduChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
aiError := errorHandle(&response.BaiduError)
if aiError != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *aiError,
StatusCode: http.StatusBadRequest,
}
return
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
// Content: baiduResponse.Result,
},
FinishReason: base.StopFinishReason,
FinishReason: types.FinishReasonStop,
}
if baiduResponse.FunctionCall != nil {
if baiduResponse.FunctionCate == "tool" {
if response.FunctionCall != nil {
if request.Tools != nil {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: baiduResponse.Id,
Id: response.Id,
Type: "function",
Function: *baiduResponse.FunctionCall,
Function: response.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
choice.FinishReason = types.FinishReasonToolCalls
} else {
choice.Message.FunctionCall = baiduResponse.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
choice.Message.FunctionCall = response.FunctionCall
choice.FinishReason = types.FinishReasonFunctionCall
}
} else {
choice.Message.Content = baiduResponse.Result
choice.Message.Content = response.Result
}
OpenAIResponse = types.ChatCompletionResponse{
ID: baiduResponse.Id,
openaiResponse = &types.ChatCompletionResponse{
ID: response.Id,
Object: "chat.completion",
Created: baiduResponse.Created,
Model: request.Model,
Created: response.Created,
Choices: []types.ChatCompletionChoice{choice},
Usage: baiduResponse.Usage,
Usage: response.Usage,
}
*p.Usage = *openaiResponse.Usage
return
}
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
if message.Role == types.ChatMessageRoleSystem {
messages = append(messages, BaiduMessage{
Role: "user",
Role: types.ChatMessageRoleUser,
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Role: types.ChatMessageRoleAssistant,
Content: "Okay",
})
} else if message.Role == types.ChatMessageRoleFunction {
messages = append(messages, BaiduMessage{
Role: types.ChatMessageRoleAssistant,
Content: "Okay",
})
messages = append(messages, BaiduMessage{
Role: types.ChatMessageRoleUser,
Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(),
})
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
@ -101,154 +177,82 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
return baiduChatRequest
}
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
// convert to the OpenAI chat stream format
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// skip the line unless it starts with "data: "
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
}
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
// strip the "data: " prefix
*rawLine = (*rawLine)[6:]
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal(*rawLine, &baiduResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
return common.ErrorToOpenAIError(err)
}
if request.Stream {
usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate())
if errWithCode != nil {
return
}
} else {
baiduChatRequest := &BaiduChatResponse{
Model: request.Model,
FunctionCate: request.GetFunctionCate(),
}
errWithCode = p.SendRequest(req, baiduChatRequest, false)
if errWithCode != nil {
return
}
usage = baiduChatRequest.Usage
if baiduResponse.IsEnd {
*isFinished = true
}
return
return h.convertToOpenaiStream(&baiduResponse, response)
}
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: "assistant",
},
}
if baiduResponse.FunctionCall != nil {
if baiduResponse.FunctionCate == "tool" {
if h.Request.Tools != nil {
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: baiduResponse.Id,
Type: "function",
Function: *baiduResponse.FunctionCall,
Function: baiduResponse.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
choice.FinishReason = types.FinishReasonToolCalls
} else {
choice.Delta.FunctionCall = baiduResponse.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
choice.FinishReason = types.FinishReasonFunctionCall
}
} else {
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &base.StopFinishReason
choice.FinishReason = types.FinishReasonStop
}
}
response := types.ChatCompletionStreamResponse{
chatCompletion := types.ChatCompletionStreamResponse{
ID: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: baiduResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
Model: h.Request.Model,
}
return &response
}
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
usage = &types.Usage{}
// send the request
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
baiduResponse.FunctionCate = functionCate
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
baiduResponse.Model = model
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return usage, nil
if baiduResponse.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion)
} else {
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
}
}
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
return nil
}
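All of the new stream handlers share one contract: the callback receives a pointer to the raw SSE line, may set it to nil to have the line skipped, flips isFinished when the upstream signals completion, and appends zero or more chunks to the response slice. requester.RequestStream itself is not shown in this diff, so the sketch below only exercises the handler signature with toy data:

package main

import "fmt"

// StreamResponse stands in for types.ChatCompletionStreamResponse.
type StreamResponse struct{ Content string }

// handler mirrors the callback contract used above: nil out rawLine to
// skip it, set *isFinished at end-of-stream, append chunks to *response.
func handler(rawLine *[]byte, isFinished *bool, response *[]StreamResponse) error {
	line := string(*rawLine)
	if len(line) < 6 || line[:6] != "data: " {
		*rawLine = nil // not an SSE data line; skip it
		return nil
	}
	payload := line[6:]
	if payload == "[DONE]" {
		*isFinished = true
		return nil
	}
	*response = append(*response, StreamResponse{Content: payload})
	return nil
}

func main() {
	var out []StreamResponse
	finished := false
	for _, l := range [][]byte{[]byte(": ping"), []byte("data: hello"), []byte("data: [DONE]")} {
		line := l
		if err := handler(&line, &finished, &out); err != nil {
			panic(err)
		}
	}
	fmt.Println(out, finished)
}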

View File

@ -6,33 +6,63 @@ import (
"one-api/types"
)
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
// build the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
}
// build the request headers
headers := p.GetRequestHeaders()
baiduRequest := convertFromEmbeddingOpenai(request)
// create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(baiduRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
defer req.Body.Close()
baiduResponse := &BaiduEmbeddingResponse{}
// send the request
_, errWithCode = p.Requester.SendRequest(req, baiduResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToEmbeddingOpenai(baiduResponse, request)
}
func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
}
func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if baiduResponse.ErrorMsg != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
func (p *BaiduProvider) convertToEmbeddingOpenai(response *BaiduEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
aiError := errorHandle(&response.BaiduError)
if aiError != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *aiError,
StatusCode: http.StatusBadRequest,
}
return
}
openAIEmbeddingResponse := &types.EmbeddingResponse{
Object: "list",
Data: make([]types.Embedding, 0, len(baiduResponse.Data)),
Model: "text-embedding-v1",
Usage: &baiduResponse.Usage,
Data: make([]types.Embedding, 0, len(response.Data)),
Model: request.Model,
Usage: &response.Usage,
}
for _, item := range baiduResponse.Data {
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
Object: item.Object,
Index: item.Index,
@ -40,30 +70,7 @@ func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response
})
}
*p.Usage = response.Usage
return openAIEmbeddingResponse, nil
}
func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
}
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
errWithCode = p.SendRequest(req, baiduEmbeddingResponse, false)
if errWithCode != nil {
return
}
usage = &baiduEmbeddingResponse.Usage
return usage, nil
}

View File

@ -3,9 +3,9 @@ package base
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"strings"
@ -13,11 +13,7 @@ import (
"github.com/gin-gonic/gin"
)
var StopFinishReason = "stop"
var StopFinishReasonToolFunction = "tool_calls"
var StopFinishReasonCallFunction = "function_call"
type BaseProvider struct {
type ProviderConfig struct {
BaseURL string
Completions string
ChatCompletions string
@ -29,8 +25,15 @@ type BaseProvider struct {
ImagesGenerations string
ImagesEdit string
ImagesVariations string
Context *gin.Context
Channel *model.Channel
}
type BaseProvider struct {
OriginalModel string
Usage *types.Usage
Config ProviderConfig
Context *gin.Context
Channel *model.Channel
Requester *requester.HTTPRequester
}
// get the base URL
@ -39,11 +42,7 @@ func (p *BaseProvider) GetBaseURL() string {
return p.Channel.GetBaseURL()
}
return p.BaseURL
}
func (p *BaseProvider) SetChannel(channel *model.Channel) {
p.Channel = channel
return p.Config.BaseURL
}
// build the full request URL
@ -62,104 +61,85 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
}
}
// send the request
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
if openAIErrorWithStatusCode != nil {
return
}
defer resp.Body.Close()
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
if rawOutput {
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(p.Context.Writer, resp.Body)
if err != nil {
return common.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
} else {
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
}
return nil
func (p *BaseProvider) GetUsage() *types.Usage {
return p.Usage
}
func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
// send the request
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
defer resp.Body.Close()
// handle the response
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp)
}
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(p.Context.Writer, resp.Body)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
func (p *BaseProvider) SetUsage(usage *types.Usage) {
p.Usage = usage
}
func (p *BaseProvider) SupportAPI(relayMode int) bool {
func (p *BaseProvider) SetContext(c *gin.Context) {
p.Context = c
}
func (p *BaseProvider) SetOriginalModel(ModelName string) {
p.OriginalModel = ModelName
}
func (p *BaseProvider) GetOriginalModel() string {
return p.OriginalModel
}
func (p *BaseProvider) GetChannel() *model.Channel {
return p.Channel
}
func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
p.OriginalModel = modelName
modelMapping := p.Channel.GetModelMapping()
if modelMapping == "" || modelMapping == "{}" {
return modelName, nil
}
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return "", err
}
if modelMap[modelName] != "" {
return modelMap[modelName], nil
}
return modelName, nil
}
func (p *BaseProvider) GetAPIUri(relayMode int) string {
switch relayMode {
case common.RelayModeChatCompletions:
return p.ChatCompletions != ""
return p.Config.ChatCompletions
case common.RelayModeCompletions:
return p.Completions != ""
return p.Config.Completions
case common.RelayModeEmbeddings:
return p.Embeddings != ""
return p.Config.Embeddings
case common.RelayModeAudioSpeech:
return p.AudioSpeech != ""
return p.Config.AudioSpeech
case common.RelayModeAudioTranscription:
return p.AudioTranscriptions != ""
return p.Config.AudioTranscriptions
case common.RelayModeAudioTranslation:
return p.AudioTranslations != ""
return p.Config.AudioTranslations
case common.RelayModeModerations:
return p.Moderation != ""
return p.Config.Moderation
case common.RelayModeImagesGenerations:
return p.ImagesGenerations != ""
return p.Config.ImagesGenerations
case common.RelayModeImagesEdits:
return p.ImagesEdit != ""
return p.Config.ImagesEdit
case common.RelayModeImagesVariations:
return p.ImagesVariations != ""
return p.Config.ImagesVariations
default:
return false
return ""
}
}
func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types.OpenAIErrorWithStatusCode) {
url = p.GetAPIUri(relayMode)
if url == "" {
err = common.StringErrorWrapper("The API interface is not supported", "unsupported_api", http.StatusNotImplemented)
return
}
return
}
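GetAPIUri collapses the old boolean SupportAPI into a lookup: each relay mode maps to the URI configured in ProviderConfig, and an empty string now means unsupported, which GetSupportedAPIUri turns into a 501-style error. A compact sketch of that lookup, with the relay-mode constants and config fields trimmed down:

package main

import (
	"fmt"
	"net/http"
)

// Trimmed-down relay modes and provider config.
const (
	RelayModeChatCompletions = iota
	RelayModeEmbeddings
)

type ProviderConfig struct {
	ChatCompletions string
	Embeddings      string
}

func (c ProviderConfig) apiURI(relayMode int) string {
	switch relayMode {
	case RelayModeChatCompletions:
		return c.ChatCompletions
	case RelayModeEmbeddings:
		return c.Embeddings
	default:
		return ""
	}
}

// supportedAPIURI mirrors GetSupportedAPIUri: an empty URI means the
// provider does not implement that relay mode.
func supportedAPIURI(c ProviderConfig, relayMode int) (string, int, error) {
	uri := c.apiURI(relayMode)
	if uri == "" {
		return "", http.StatusNotImplemented, fmt.Errorf("unsupported_api")
	}
	return uri, http.StatusOK, nil
}

func main() {
	cfg := ProviderConfig{ChatCompletions: "/v1/chat/completions"}
	fmt.Println(supportedAPIURI(cfg, RelayModeChatCompletions))
	fmt.Println(supportedAPIURI(cfg, RelayModeEmbeddings))
}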

View File

@ -0,0 +1,7 @@
package base
import "one-api/types"
type BaseHandler struct {
Usage *types.Usage
}

View File

@ -2,84 +2,108 @@ package base
import (
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Requestable interface {
types.CompletionRequest | types.ChatCompletionRequest | types.EmbeddingRequest | types.ModerationRequest | types.SpeechAudioRequest | types.AudioRequest | types.ImageRequest | types.ImageEditRequest
}
// base interface
type ProviderInterface interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
SupportAPI(relayMode int) bool
SetChannel(channel *model.Channel)
// get the base URL
// GetBaseURL() string
// get the full request URL
// GetFullRequestURL(requestURL string, modelName string) string
// get the request headers
// GetRequestHeaders() (headers map[string]string)
// get usage
GetUsage() *types.Usage
// set usage
SetUsage(usage *types.Usage)
// set the gin context
SetContext(c *gin.Context)
// set the original model
SetOriginalModel(ModelName string)
// get the original model
GetOriginalModel() string
// SupportAPI(relayMode int) bool
GetChannel() *model.Channel
ModelMappingHandler(modelName string) (string, error)
}
// completions interface
type CompletionInterface interface {
ProviderInterface
CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
}
// chat interface
type ChatInterface interface {
ProviderInterface
ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
}
// embeddings interface
type EmbeddingsInterface interface {
ProviderInterface
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode)
}
// moderation interface
type ModerationInterface interface {
ProviderInterface
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode)
}
// text-to-speech interface
type SpeechInterface interface {
ProviderInterface
SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode)
}
// speech-to-text interface
type TranscriptionsInterface interface {
ProviderInterface
TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
}
// audio translation interface
type TranslationInterface interface {
ProviderInterface
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
}
// image generation interface
type ImageGenerationsInterface interface {
ProviderInterface
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
// image editing interface
type ImageEditsInterface interface {
ProviderInterface
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
type ImageVariationsInterface interface {
ProviderInterface
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
// balance interface
type BalanceInterface interface {
Balance(channel *model.Channel) (float64, error)
Balance() (float64, error)
}
type ProviderResponseHandler interface {
// response handler function
ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
}
// type ProviderResponseHandler interface {
// // response handler function
// ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
// }
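With the old Action methods gone, relay code presumably narrows the base ProviderInterface to one of these capability interfaces with a type assertion before dispatching, instead of calling SupportAPI. A toy illustration of that dispatch, with the interfaces trimmed to a single method each:

package main

import "fmt"

// Trimmed-down capability interfaces; the real ones return typed
// responses and OpenAI-style errors.
type ProviderInterface interface{ GetOriginalModel() string }

type ChatInterface interface {
	ProviderInterface
	CreateChatCompletion() string
}

type chatProvider struct{}

func (chatProvider) GetOriginalModel() string     { return "gpt-3.5-turbo" }
func (chatProvider) CreateChatCompletion() string { return "ok" }

// relayChat narrows the base interface to a capability before calling it.
func relayChat(p ProviderInterface) (string, error) {
	chat, ok := p.(ChatInterface)
	if !ok {
		return "", fmt.Errorf("channel does not support chat completions")
	}
	return chat.CreateChatCompletion(), nil
}

func main() {
	out, err := relayChat(chatProvider{})
	fmt.Println(out, err)
}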

View File

@ -1,20 +1,23 @@
package claude
import (
"encoding/json"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"github.com/gin-gonic/gin"
"one-api/types"
)
type ClaudeProviderFactory struct{}
// Create ClaudeProvider
func (f ClaudeProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f ClaudeProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &ClaudeProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://api.anthropic.com",
ChatCompletions: "/v1/complete",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
},
}
}
@ -23,6 +26,36 @@ type ClaudeProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.anthropic.com",
ChatCompletions: "/v1/complete",
}
}
// request error handling
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
claudeError := &ClaudeResponseError{}
err := json.NewDecoder(resp.Body).Decode(claudeError)
if err != nil {
return nil
}
return errorHandle(claudeError)
}
// error handling
func errorHandle(claudeError *ClaudeResponseError) *types.OpenAIError {
if claudeError.Error.Type == "" {
return nil
}
return &types.OpenAIError{
Message: claudeError.Error.Message,
Type: claudeError.Error.Type,
Code: claudeError.Error.Type,
}
}
// get the request headers
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
@ -41,9 +74,9 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
return types.FinishReasonStop
case "max_tokens":
return "length"
return types.FinishReasonLength
default:
return reason
}

View File

@ -1,56 +1,86 @@
package claude
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if claudeResponse.Error.Type != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Model: claudeResponse.Model,
}
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
claudeResponse.Usage.CompletionTokens = completionTokens
claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens
fullTextResponse.Usage = claudeResponse.Usage
return fullTextResponse, nil
type claudeStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) {
func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
claudeResponse := &ClaudeResponse{}
// send the request
_, errWithCode = p.Requester.SendRequest(req, claudeResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToChatOpenai(claudeResponse, request)
}
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &claudeStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
// build the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError)
}
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
claudeRequest := convertFromChatOpenai(request)
// create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(claudeRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: request.Model,
Prompt: "",
@ -80,138 +110,84 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest
return &claudeRequest
}
func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
aiError := errorHandle(&response.ClaudeResponseError)
if aiError != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *aiError,
StatusCode: http.StatusBadRequest,
}
return
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: strings.TrimPrefix(response.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(response.StopReason),
}
openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Model: response.Model,
}
if request.Stream {
var responseText string
errWithCode, responseText = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
completionTokens := common.CountTokenText(response.Completion, response.Model)
response.Usage.CompletionTokens = completionTokens
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(responseText, request.Model),
}
usage.TotalTokens = promptTokens + usage.CompletionTokens
openaiResponse.Usage = response.Usage
} else {
var claudeResponse = &ClaudeResponse{
Usage: &types.Usage{
PromptTokens: promptTokens,
},
}
errWithCode = p.SendRequest(req, claudeResponse, false)
if errWithCode != nil {
return
}
usage = claudeResponse.Usage
}
return
*p.Usage = *response.Usage
return openaiResponse, nil
}
func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse {
// convert to the OpenAI chat stream format
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// skip the line unless it starts with the completion data prefix
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
*rawLine = nil
return nil
}
// strip the "data: " prefix
*rawLine = (*rawLine)[6:]
claudeResponse := &ClaudeResponse{}
err := json.Unmarshal(*rawLine, claudeResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
}
if claudeResponse.StopReason == "stop_sequence" {
*isFinished = true
}
return h.convertToOpenaiStream(claudeResponse, response)
}
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response types.ChatCompletionStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []types.ChatCompletionStreamChoice{choice}
return &response
}
func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close()
// send the request
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
chatCompletion := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
defer resp.Body.Close()
*response = append(*response, chatCompletion)
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := p.streamResponseClaude2OpenAI(&claudeResponse)
response.ID = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
h.Usage.CompletionTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil, responseText
return nil
}
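Claude's completion stream carries no usage block, so the handler above counts tokens client-side as each chunk arrives. A small sketch of that incremental accounting; countTokens stands in for common.CountTokenText, which uses a real model-specific tokenizer rather than whitespace splitting:

package main

import (
	"fmt"
	"strings"
)

// countTokens is a whitespace stand-in for common.CountTokenText.
func countTokens(text, model string) int { return len(strings.Fields(text)) }

type Usage struct{ PromptTokens, CompletionTokens, TotalTokens int }

// accumulate mirrors the per-chunk accounting in the stream handler:
// each completion delta is counted as it arrives, and the total is
// derived from prompt + completion.
func accumulate(u *Usage, delta, model string) {
	u.CompletionTokens += countTokens(delta, model)
	u.TotalTokens = u.PromptTokens + u.CompletionTokens
}

func main() {
	u := &Usage{PromptTokens: 12}
	for _, chunk := range []string{"Hello", " there", " friend"} {
		accumulate(u, chunk, "claude-2")
	}
	fmt.Printf("%+v\n", *u)
}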

View File

@ -23,10 +23,13 @@ type ClaudeRequest struct {
Stream bool `json:"stream,omitempty"`
}
type ClaudeResponseError struct {
Error ClaudeError `json:"error,omitempty"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
ClaudeResponseError
}
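Embedding ClaudeResponseError in ClaudeResponse means a bare error payload and a full completion payload unmarshal into the same struct, so one errorHandle covers both. A quick demonstration of that layout:

package main

import (
	"encoding/json"
	"fmt"
)

type ClaudeError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

type ClaudeResponseError struct {
	Error ClaudeError `json:"error,omitempty"`
}

// ClaudeResponse embeds the error envelope, mirroring the struct above,
// so success and error payloads share one decode path.
type ClaudeResponse struct {
	Completion string `json:"completion"`
	ClaudeResponseError
}

func main() {
	var ok, bad ClaudeResponse
	json.Unmarshal([]byte(`{"completion":"Hi"}`), &ok)
	json.Unmarshal([]byte(`{"error":{"type":"invalid_request_error","message":"bad key"}}`), &bad)
	fmt.Println(ok.Completion, ok.Error.Type == "")
	fmt.Println(bad.Error.Type, bad.Error.Message)
}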

View File

@ -2,7 +2,6 @@ package closeai
import (
"errors"
"one-api/common"
"one-api/model"
)
@ -10,15 +9,14 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response OpenAICreditGrants
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -1,18 +1,17 @@
package closeai
import (
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type CloseaiProviderFactory struct{}
// Create CloseaiProvider
func (f CloseaiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f CloseaiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &CloseaiProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.closeai-proxy.xyz"),
}
}

View File

@ -1,22 +1,25 @@
package gemini
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type GeminiProviderFactory struct{}
// Create GeminiProvider
func (f GeminiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f GeminiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &GeminiProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
},
}
}
@ -25,6 +28,37 @@ type GeminiProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/",
}
}
// request error handling
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
geminiError := &GeminiErrorResponse{}
err := json.NewDecoder(resp.Body).Decode(geminiError)
if err != nil {
return nil
}
return errorHandle(geminiError)
}
// error handling
func errorHandle(geminiError *GeminiErrorResponse) *types.OpenAIError {
if geminiError.Error.Message == "" {
return nil
}
return &types.OpenAIError{
Message: geminiError.Error.Message,
Type: "gemini_error",
Param: geminiError.Error.Status,
Code: geminiError.Error.Code,
}
}
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
version := "v1"

View File

@ -1,14 +1,12 @@
package gemini
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/image"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"strings"
)
@ -17,50 +15,78 @@ const (
GeminiVisionMaxImageNum = 16
)
func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if len(response.Candidates) == 0 {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}
}
fullTextResponse := &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: response.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := types.ChatCompletionChoice{
Index: i,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: "",
},
FinishReason: base.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
response.Usage.CompletionTokens = completionTokens
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
return fullTextResponse, nil
type geminiStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
geminiChatResponse := &GeminiChatResponse{}
// send the request
_, errWithCode = p.Requester.SendRequest(req, geminiChatResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToChatOpenai(geminiChatResponse, request)
}
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &geminiStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url := "generateContent"
if request.Stream {
url = "streamGenerateContent?alt=sse"
}
// build the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
// build the request headers
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
geminiRequest, errWithCode := convertFromChatOpenai(request)
if errWithCode != nil {
return nil, errWithCode
}
// create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(geminiRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatRequest, *types.OpenAIErrorWithStatusCode) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
SafetySettings: []GeminiChatSafetySettings{
@ -160,144 +186,99 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
return &geminiRequest, nil
}
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, errWithCode := p.getChatRequestBody(request)
if errWithCode != nil {
func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
aiError := errorHandle(&response.GeminiErrorResponse)
if aiError != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *aiError,
StatusCode: http.StatusBadRequest,
}
return
}
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: request.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
var responseText string
errWithCode, responseText = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(responseText, request.Model),
}
usage.TotalTokens = promptTokens + usage.CompletionTokens
} else {
var geminiResponse = &GeminiChatResponse{
Model: request.Model,
Usage: &types.Usage{
PromptTokens: promptTokens,
for i, candidate := range response.Candidates {
choice := types.ChatCompletionChoice{
Index: i,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: "",
},
FinishReason: types.FinishReasonStop,
}
errWithCode = p.SendRequest(req, geminiResponse, false)
if errWithCode != nil {
return
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
}
usage = geminiResponse.Usage
openaiResponse.Choices = append(openaiResponse.Choices, choice)
}
completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
p.Usage.CompletionTokens = completionTokens
p.Usage.TotalTokens = p.Usage.PromptTokens + completionTokens
openaiResponse.Usage = p.Usage
return
}
// func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
// 	var choice types.ChatCompletionStreamChoice
// 	choice.Delta.Content = geminiResponse.GetResponseText()
// 	choice.FinishReason = &base.StopFinishReason
// 	var response types.ChatCompletionStreamResponse
// 	response.Object = "chat.completion.chunk"
// 	response.Model = "gemini"
// 	response.Choices = []types.ChatCompletionStreamChoice{choice}
// 	return &response
// }
func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
    defer req.Body.Close()
    // Send the request
    client := common.GetHttpClient(p.Channel.Proxy)
    resp, err := client.Do(req)
    if err != nil {
        return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
    }
    common.PutHttpClient(client)
    if common.IsFailureStatusCode(resp) {
        return common.HandleErrorResp(resp), ""
    }
    defer resp.Body.Close()
    responseText := ""
    dataChan := make(chan string)
    stopChan := make(chan bool)
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
        if atEOF && len(data) == 0 {
            return 0, nil, nil
        }
        if i := strings.Index(string(data), "\n"); i >= 0 {
            return i + 1, data[0:i], nil
        }
        if atEOF {
            return len(data), data, nil
        }
        return 0, nil, nil
    })
    go func() {
        for scanner.Scan() {
            data := scanner.Text()
            data = strings.TrimSpace(data)
            if !strings.HasPrefix(data, "\"text\": \"") {
                continue
            }
            data = strings.TrimPrefix(data, "\"text\": \"")
            data = strings.TrimSuffix(data, "\"")
            dataChan <- data
        }
        stopChan <- true
    }()
    common.SetEventStreamHeaders(p.Context)
    p.Context.Stream(func(w io.Writer) bool {
        select {
        case data := <-dataChan:
            // this is used to prevent annoying \ related format bug
            data = fmt.Sprintf("{\"content\": \"%s\"}", data)
            type dummyStruct struct {
                Content string `json:"content"`
            }
            var dummy dummyStruct
            json.Unmarshal([]byte(data), &dummy)
            responseText += dummy.Content
            var choice types.ChatCompletionStreamChoice
            choice.Delta.Content = dummy.Content
            response := types.ChatCompletionStreamResponse{
                ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
                Object:  "chat.completion.chunk",
                Created: common.GetTimestamp(),
                Model:   model,
                Choices: []types.ChatCompletionStreamChoice{choice},
            }
            jsonResponse, err := json.Marshal(response)
            if err != nil {
                common.SysError("error marshalling stream response: " + err.Error())
                return true
            }
            p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
            return true
        case <-stopChan:
            p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
            return false
        }
    })
    return nil, responseText
}
// Convert to an OpenAI chat streaming response
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
    // If rawLine does not start with "data: ", skip it
    if !strings.HasPrefix(string(*rawLine), "data: ") {
        *rawLine = nil
        return nil
    }
    // Strip the prefix
    *rawLine = (*rawLine)[6:]
    var geminiResponse GeminiChatResponse
    err := json.Unmarshal(*rawLine, &geminiResponse)
    if err != nil {
        return common.ErrorToOpenAIError(err)
    }
    error := errorHandle(&geminiResponse.GeminiErrorResponse)
    if error != nil {
        return error
    }
    return h.convertToOpenaiStream(&geminiResponse, response)
}
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
for i, candidate := range geminiResponse.Candidates {
choice := types.ChatCompletionStreamChoice{
Index: i,
Delta: types.ChatCompletionStreamChoiceDelta{
Content: candidate.Content.Parts[0].Text,
},
FinishReason: types.FinishReasonStop,
}
choices = append(choices, choice)
}
streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: h.Request.Model,
Choices: choices,
}
*response = append(*response, streamResponse)
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
}
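
Both the Gemini handler above and the PaLM handler later in this commit follow the same callback contract that requester.RequestStream drives: the requester feeds each raw SSE line into the handler, which either nils the line out to skip it, flags the end of the stream, or appends converted chunks. A minimal sketch of a conforming handler (the pass-through behavior is hypothetical; only the signature comes from this diff):

type echoStreamHandler struct{}

// handlerStream skips non-data lines, stops on [DONE], and otherwise
// emits one pass-through chunk per SSE line.
func (h *echoStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
    if !strings.HasPrefix(string(*rawLine), "data: ") {
        *rawLine = nil // tell the requester to skip this line
        return nil
    }
    *rawLine = (*rawLine)[6:]
    if string(*rawLine) == "[DONE]" {
        *isFinished = true // tell the requester the stream is complete
        return nil
    }
    var choice types.ChatCompletionStreamChoice
    choice.Delta.Content = string(*rawLine)
    *response = append(*response, types.ChatCompletionStreamResponse{
        Object:  "chat.completion.chunk",
        Choices: []types.ChatCompletionStreamChoice{choice},
    })
    return nil
}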

View File

@ -42,11 +42,22 @@ type GeminiChatGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
}
type GeminiError struct {
Code string `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type GeminiErrorResponse struct {
Error GeminiError `json:"error,omitempty"`
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
Usage *types.Usage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
GeminiErrorResponse
}
type GeminiChatCandidate struct {

View File

@ -3,7 +3,6 @@ package openai
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
"time"
)
@ -16,15 +15,14 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var subscription OpenAISubscriptionResponse
_, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &subscription, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
@ -37,12 +35,15 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
}
fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "")
req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err = p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
usage := OpenAIUsageResponse{}
_, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
_, errWithCode = p.Requester.SendRequest(req, &usage, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)

View File

@ -1,69 +1,100 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"strings"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
type OpenAIProviderFactory struct{}
// Create an OpenAIProvider
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
openAIProvider := CreateOpenAIProvider(c, "")
openAIProvider.BalanceAction = true
return openAIProvider
}
type OpenAIProvider struct {
base.BaseProvider
IsAzure bool
BalanceAction bool
}
// Create an OpenAIProvider
func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
openAIProvider := CreateOpenAIProvider(channel, "https://api.openai.com")
openAIProvider.BalanceAction = true
return openAIProvider
}
// Create an OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
    if baseURL == "" {
        baseURL = "https://api.openai.com"
    }
    return &OpenAIProvider{
        BaseProvider: base.BaseProvider{
            BaseURL:             baseURL,
            Completions:         "/v1/completions",
            ChatCompletions:     "/v1/chat/completions",
            Embeddings:          "/v1/embeddings",
            Moderation:          "/v1/moderations",
            AudioSpeech:         "/v1/audio/speech",
            AudioTranscriptions: "/v1/audio/transcriptions",
            AudioTranslations:   "/v1/audio/translations",
            ImagesGenerations:   "/v1/images/generations",
            ImagesEdit:          "/v1/images/edits",
            ImagesVariations:    "/v1/images/variations",
            Context:             c,
        },
        IsAzure:       false,
        BalanceAction: true,
    }
}
func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
    config := getOpenAIConfig(baseURL)
    return &OpenAIProvider{
        BaseProvider: base.BaseProvider{
            Config:    config,
            Channel:   channel,
            Requester: requester.NewHTTPRequester(channel.Proxy, RequestErrorHandle),
        },
        IsAzure:       false,
        BalanceAction: true,
    }
}
func getOpenAIConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Moderation: "/v1/moderations",
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",
ImagesGenerations: "/v1/images/generations",
ImagesEdit: "/v1/images/edits",
ImagesVariations: "/v1/images/variations",
}
}
// Handle request errors
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
    errorResponse := &types.OpenAIErrorResponse{}
    err := json.NewDecoder(resp.Body).Decode(errorResponse)
    if err != nil {
        return nil
    }
    return ErrorHandle(errorResponse)
}
// Error handling
func ErrorHandle(openaiError *types.OpenAIErrorResponse) *types.OpenAIError {
if openaiError.Error.Message == "" {
return nil
}
return &openaiError.Error
}
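
These two helpers are the OpenAI half of the requester's error pipeline: the factory above hands RequestErrorHandle to requester.NewHTTPRequester, and the requester calls it whenever a failed response body needs to become an OpenAIError. A minimal sketch of the wiring (variable names are illustrative):

// Sketch: each provider binds its own error decoder at construction time.
httpRequester := requester.NewHTTPRequester(channel.Proxy, RequestErrorHandle)
// Later, helpers such as requester.HandleErrorResp(resp, httpRequester.ErrorHandler)
// decode a failed body through RequestErrorHandle / ErrorHandle.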
// Get the full request URL
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
if p.IsAzure {
apiVersion := p.Channel.Other
// Split modelName on '-' and check whether the last segment is a 4-digit number; if it is, drop that segment from modelName
modelNameSlice := strings.Split(modelName, "-")
lastModelNameSlice := modelNameSlice[len(modelNameSlice)-1]
modelNum := common.String2Int(lastModelNameSlice)
if modelNum > 999 && modelNum < 10000 {
modelName = strings.TrimSuffix(modelName, "-"+lastModelNameSlice)
}
// If the model name contains '.', strip the dots
modelName = strings.Replace(modelName, ".", "", -1)
if modelName == "dall-e-2" {
// dall-e-3 requires api-version=2023-12-01-preview, but that version
// no longer includes dall-e-2, so it is hard-coded here for now
@ -72,10 +103,6 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
}
// If the request URL contains '.', strip the dots
if strings.Contains(requestURL, ".") {
requestURL = strings.Replace(requestURL, ".", "", -1)
}
}
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
@ -102,89 +129,21 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
return headers
}
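
As a worked illustration of the Azure deployment-name rules in GetFullRequestURL above (a hypothetical helper; the original uses common.String2Int rather than strconv):

func normalizeAzureDeploymentName(modelName string) string {
    // Drop a trailing "-NNNN" segment when NNNN is a 4-digit number (e.g. date-style suffixes).
    parts := strings.Split(modelName, "-")
    last := parts[len(parts)-1]
    if n, err := strconv.Atoi(last); err == nil && n > 999 && n < 10000 {
        modelName = strings.TrimSuffix(modelName, "-"+last)
    }
    // Azure deployment names cannot contain dots.
    return strings.ReplaceAll(modelName, ".", "")
}

// normalizeAzureDeploymentName("gpt-4-1106")    -> "gpt-4"
// normalizeAzureDeploymentName("gpt-3.5-turbo") -> "gpt-35-turbo"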
// Get the request body
func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
    if isModelMapped {
        jsonStr, err := json.Marshal(request)
        if err != nil {
            return nil, err
        }
        requestBody = bytes.NewBuffer(jsonStr)
    } else {
        requestBody = p.Context.Request.Body
    }
    return
}
func (p *OpenAIProvider) GetRequestTextBody(relayMode int, ModelName string, request any) (*http.Request, *types.OpenAIErrorWithStatusCode) {
    url, errWithCode := p.GetSupportedAPIUri(relayMode)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Get the request URL
    fullRequestURL := p.GetFullRequestURL(url, ModelName)
    // Get the request headers
    headers := p.GetRequestHeaders()
    // Create the request
    req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    return req, nil
}
// Send a streaming request
func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
    defer req.Body.Close()
    client := common.GetHttpClient(p.Channel.Proxy)
    resp, err := client.Do(req)
    if err != nil {
        return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
    }
    common.PutHttpClient(client)
    if common.IsFailureStatusCode(resp) {
        return common.HandleErrorResp(resp), ""
    }
    defer resp.Body.Close()
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
        if atEOF && len(data) == 0 {
            return 0, nil, nil
        }
        if i := strings.Index(string(data), "\n"); i >= 0 {
            return i + 1, data[0:i], nil
        }
        if atEOF {
            return len(data), data, nil
        }
        return 0, nil, nil
    })
    dataChan := make(chan string)
    stopChan := make(chan bool)
    go func() {
        for scanner.Scan() {
            data := scanner.Text()
            if len(data) < 6 { // ignore blank line or wrong format
                continue
            }
            if data[:6] != "data: " && data[:6] != "[DONE]" {
                continue
            }
            dataChan <- data
            data = data[6:]
            if !strings.HasPrefix(data, "[DONE]") {
                err := json.Unmarshal([]byte(data), response)
                if err != nil {
                    common.SysError("error unmarshalling stream response: " + err.Error())
                    continue // just ignore the error
                }
                responseText += response.responseStreamHandler()
            }
        }
        stopChan <- true
    }()
    common.SetEventStreamHeaders(p.Context)
    p.Context.Stream(func(w io.Writer) bool {
        select {
        case data := <-dataChan:
            if strings.HasPrefix(data, "data: [DONE]") {
                data = data[:12]
            }
            // some implementations may add \r at the end of data
            data = strings.TrimSuffix(data, "\r")
            p.Context.Render(-1, common.CustomEvent{Data: data})
            return true
        case <-stopChan:
            return false
        }
    })
    return nil, responseText
}

View File

@ -1,82 +1,101 @@
package openai
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if c.Error.Type != "" {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: c.Error,
            StatusCode:  resp.StatusCode,
        }
        return
    }
    return nil, nil
}
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
    for _, choice := range c.Choices {
        responseText += choice.Delta.Content
    }
    return
}
func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    requestBody, err := p.GetRequestBody(&request, isModelMapped)
    if err != nil {
        return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
    }
    fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
    headers := p.GetRequestHeaders()
    if request.Stream && headers["Accept"] == "" {
        headers["Accept"] = "text/event-stream"
    }
    client := common.NewClient()
    req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    if request.Stream {
        openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
        var textResponse string
        errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
        if errWithCode != nil {
            return
        }
        usage = &types.Usage{
            PromptTokens:     promptTokens,
            CompletionTokens: common.CountTokenText(textResponse, request.Model),
            TotalTokens:      promptTokens + common.CountTokenText(textResponse, request.Model),
        }
    } else {
        openAIProviderChatResponse := &OpenAIProviderChatResponse{}
        errWithCode = p.SendRequest(req, openAIProviderChatResponse, true)
        if errWithCode != nil {
            return
        }
        usage = openAIProviderChatResponse.Usage
        if usage.TotalTokens == 0 {
            completionTokens := 0
            for _, choice := range openAIProviderChatResponse.Choices {
                completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
            }
            usage = &types.Usage{
                PromptTokens:     promptTokens,
                CompletionTokens: completionTokens,
                TotalTokens:      promptTokens + completionTokens,
            }
        }
    }
    return
}
type OpenAIStreamHandler struct {
    Usage     *types.Usage
    ModelName string
}
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderChatResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Check for an API error
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    *p.Usage = *response.Usage
    return &response.ChatCompletionResponse, nil
}
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
        return nil, errWithCode
    }
    chatHandler := OpenAIStreamHandler{
        Usage:     p.Usage,
        ModelName: request.Model,
    }
    return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
}
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
    // If rawLine does not start with "data: ", skip it
    if !strings.HasPrefix(string(*rawLine), "data: ") {
        *rawLine = nil
        return nil
    }
    // Strip the prefix
    *rawLine = (*rawLine)[6:]
    // "[DONE]" marks the end of the stream
    if string(*rawLine) == "[DONE]" {
        *isFinished = true
        return nil
    }
    var openaiResponse OpenAIProviderChatStreamResponse
    err := json.Unmarshal(*rawLine, &openaiResponse)
    if err != nil {
        return common.ErrorToOpenAIError(err)
    }
    error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
    if error != nil {
        return error
    }
    countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
    h.Usage.CompletionTokens += countTokenText
    h.Usage.TotalTokens += countTokenText
    *response = append(*response, openaiResponse.ChatCompletionStreamResponse)
    return nil
}
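
Because HandlerChatStream is a pure function of the raw line, it can be exercised without any network I/O. A small illustrative harness (the JSON payload and token counts are made up; the in-package types above are assumed):

h := OpenAIStreamHandler{Usage: &types.Usage{PromptTokens: 10}, ModelName: "gpt-3.5-turbo"}
raw := []byte(`data: {"choices":[{"delta":{"content":"Hello"}}]}`)
var finished bool
var chunks []types.ChatCompletionStreamResponse
if err := h.HandlerChatStream(&raw, &finished, &chunks); err != nil {
    // handle the conversion error
}
// chunks now holds one pass-through stream chunk, and h.Usage has been
// incremented by the token count of "Hello".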

View File

@ -1,82 +1,96 @@
package openai
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if c.Error.Type != "" {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: c.Error,
            StatusCode:  resp.StatusCode,
        }
        return
    }
    return nil, nil
}
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
    for _, choice := range c.Choices {
        responseText += choice.Text
    }
    return
}
func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    requestBody, err := p.GetRequestBody(&request, isModelMapped)
    if err != nil {
        return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
    }
    fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
    headers := p.GetRequestHeaders()
    if request.Stream && headers["Accept"] == "" {
        headers["Accept"] = "text/event-stream"
    }
    client := common.NewClient()
    req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
    if request.Stream {
        // TODO
        var textResponse string
        errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse)
        if errWithCode != nil {
            return
        }
        usage = &types.Usage{
            PromptTokens:     promptTokens,
            CompletionTokens: common.CountTokenText(textResponse, request.Model),
            TotalTokens:      promptTokens + common.CountTokenText(textResponse, request.Model),
        }
    } else {
        errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true)
        if errWithCode != nil {
            return
        }
        usage = openAIProviderCompletionResponse.Usage
        if usage.TotalTokens == 0 {
            completionTokens := 0
            for _, choice := range openAIProviderCompletionResponse.Choices {
                completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
            }
            usage = &types.Usage{
                PromptTokens:     promptTokens,
                CompletionTokens: completionTokens,
                TotalTokens:      promptTokens + completionTokens,
            }
        }
    }
    return
}
func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderCompletionResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Check for an API error
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    *p.Usage = *response.Usage
    return &response.CompletionResponse, nil
}
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
        return nil, errWithCode
    }
    chatHandler := OpenAIStreamHandler{
        Usage:     p.Usage,
        ModelName: request.Model,
    }
    return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
}
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
    // If rawLine does not start with "data: ", skip it
    if !strings.HasPrefix(string(*rawLine), "data: ") {
        *rawLine = nil
        return nil
    }
    // Strip the prefix
    *rawLine = (*rawLine)[6:]
    // "[DONE]" marks the end of the stream
    if string(*rawLine) == "[DONE]" {
        *isFinished = true
        return nil
    }
    var openaiResponse OpenAIProviderCompletionResponse
    err := json.Unmarshal(*rawLine, &openaiResponse)
    if err != nil {
        return common.ErrorToOpenAIError(err)
    }
    error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
    if error != nil {
        return error
    }
    countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
    h.Usage.CompletionTokens += countTokenText
    h.Usage.TotalTokens += countTokenText
    *response = append(*response, openaiResponse.CompletionResponse)
    return nil
}

View File

@ -6,40 +6,30 @@ import (
"one-api/types"
)
func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if c.Error.Type != "" {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: c.Error,
            StatusCode:  resp.StatusCode,
        }
        return
    }
    return nil, nil
}
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    requestBody, err := p.GetRequestBody(&request, isModelMapped)
    if err != nil {
        return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
    }
    fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
    errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true)
    if errWithCode != nil {
        return
    }
    usage = openAIProviderEmbeddingsResponse.Usage
    return
}
func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderEmbeddingsResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    *p.Usage = *response.Usage
    return &response.EmbeddingResponse, nil
}

View File

@ -5,28 +5,71 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
)
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    var formBody bytes.Buffer
    var req *http.Request
    var err error
    if isModelMapped {
        builder := client.CreateFormBuilder(&formBody)
        if err := imagesEditsMultipartForm(request, builder); err != nil {
            return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
        }
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
        req.ContentLength = int64(formBody.Len())
    } else {
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
        req.ContentLength = p.Context.Request.ContentLength
    }
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
    errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
    if errWithCode != nil {
        return
    }
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: 0,
        TotalTokens:      promptTokens,
    }
    return
}
func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderImageResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    p.Usage.TotalTokens = p.Usage.PromptTokens
    return &response.ImageResponse, nil
}
func (p *OpenAIProvider) getRequestImageBody(relayMode int, ModelName string, request *types.ImageEditRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
    url, errWithCode := p.GetSupportedAPIUri(relayMode)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Get the request URL
    fullRequestURL := p.GetFullRequestURL(url, ModelName)
    // Get the request headers
    headers := p.GetRequestHeaders()
    // Create the request
    var req *http.Request
    var err error
    if p.OriginalModel != request.Model {
        var formBody bytes.Buffer
        builder := p.Requester.CreateFormBuilder(&formBody)
        if err := imagesEditsMultipartForm(request, builder); err != nil {
            return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
        }
        req, err = p.Requester.NewRequest(
            http.MethodPost,
            fullRequestURL,
            p.Requester.WithBody(&formBody),
            p.Requester.WithHeader(headers),
            p.Requester.WithContentType(builder.FormDataContentType()))
        req.ContentLength = int64(formBody.Len())
    } else {
        req, err = p.Requester.NewRequest(
            http.MethodPost,
            fullRequestURL,
            p.Requester.WithBody(p.Context.Request.Body),
            p.Requester.WithHeader(headers),
            p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
        req.ContentLength = p.Context.Request.ContentLength
    }
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    return req, nil
}
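
A minimal sketch of how the multipart helper below (imagesEditsMultipartForm) is consumed, using only the FormBuilder methods that appear in this diff (CreateFormFile and FormDataContentType):

var formBody bytes.Buffer
builder := p.Requester.CreateFormBuilder(&formBody)
if err := imagesEditsMultipartForm(request, builder); err != nil {
    // the builder failed to encode a file or field
}
// The builder's boundary must travel with the body:
contentType := builder.FormDataContentType() // "multipart/form-data; boundary=..."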
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
func imagesEditsMultipartForm(request *types.ImageEditRequest, b requester.FormBuilder) error {
err := b.CreateFormFile("image", request.Image)
if err != nil {
return fmt.Errorf("creating form image: %w", err)

View File

@ -6,53 +6,41 @@ import (
"one-api/types"
)
func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if c.Error.Type != "" {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: c.Error,
            StatusCode:  resp.StatusCode,
        }
        return
    }
    return nil, nil
}
func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    if !isWithinRange(request.Model, request.N) {
        return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
    }
    requestBody, err := p.GetRequestBody(&request, isModelMapped)
    if err != nil {
        return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
    }
    fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
    errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
    if errWithCode != nil {
        return
    }
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: 0,
        TotalTokens:      promptTokens,
    }
    return
}
func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
    if !IsWithinRange(request.Model, request.N) {
        return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
    }
    req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderImageResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Check for an API error
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    p.Usage.TotalTokens = p.Usage.PromptTokens
    return &response.ImageResponse, nil
}
func isWithinRange(element string, value int) bool {
func IsWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}

View File

@ -1,49 +1,35 @@
package openai
import (
"bytes"
"net/http"
"one-api/common"
"one-api/types"
)
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    var formBody bytes.Buffer
    var req *http.Request
    var err error
    if isModelMapped {
        builder := client.CreateFormBuilder(&formBody)
        if err := imagesEditsMultipartForm(request, builder); err != nil {
            return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
        }
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
        req.ContentLength = int64(formBody.Len())
    } else {
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
        req.ContentLength = p.Context.Request.ContentLength
    }
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
    errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
    if errWithCode != nil {
        return
    }
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: 0,
        TotalTokens:      promptTokens,
    }
    return
}
func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderImageResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    p.Usage.TotalTokens = p.Usage.PromptTokens
    return &response.ImageResponse, nil
}

View File

@ -6,44 +6,31 @@ import (
"one-api/types"
)
func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if c.Error.Type != "" {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: c.Error,
            StatusCode:  resp.StatusCode,
        }
        return
    }
    return nil, nil
}
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    requestBody, err := p.GetRequestBody(&request, isModelMapped)
    if err != nil {
        return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
    }
    fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
    errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true)
    if errWithCode != nil {
        return
    }
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: 0,
        TotalTokens:      promptTokens,
    }
    return
}
func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    response := &OpenAIProviderModerationResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, response, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
    if openaiErr != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *openaiErr,
            StatusCode:  http.StatusBadRequest,
        }
        return nil, errWithCode
    }
    p.Usage.TotalTokens = p.Usage.PromptTokens
    return &response.ModerationResponse, nil
}

View File

@ -3,35 +3,29 @@ package openai
import (
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
)
func (p *OpenAIProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    requestBody, err := p.GetRequestBody(&request, isModelMapped)
    if err != nil {
        return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
    }
    fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    errWithCode = p.SendRequestRaw(req)
    if errWithCode != nil {
        return
    }
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: 0,
        TotalTokens:      promptTokens,
    }
    return
}
func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    // Send the request
    var resp *http.Response
    resp, errWithCode = p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
        return nil, errWithCode
    }
    if resp.Header.Get("Content-Type") == "application/json" {
        return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler)
    }
    p.Usage.TotalTokens = p.Usage.PromptTokens
    return resp, nil
}
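
The Content-Type guard above works because a successful speech call streams raw audio (for example audio/mpeg), so a JSON body can only be an upstream error payload. A hedged sketch of the caller's side (the relay step is illustrative):

resp, errWithCode := p.CreateSpeech(request)
if errWithCode == nil {
    defer resp.Body.Close()
    // io.Copy(clientWriter, resp.Body) — illustrative: stream the audio through to the client
}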

View File

@ -4,48 +4,99 @@ import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"regexp"
"strconv"
"strings"
)
func (c *OpenAIProviderTranscriptionsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if c.Error.Type != "" {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: c.Error,
            StatusCode:  resp.StatusCode,
        }
        return
    }
    return nil, nil
}
func (c *OpenAIProviderTranscriptionsTextResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    return nil, nil
}
func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    fullRequestURL := p.GetFullRequestURL(p.AudioTranscriptions, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    var formBody bytes.Buffer
    var req *http.Request
    var err error
    if isModelMapped {
        builder := client.CreateFormBuilder(&formBody)
        if err := audioMultipartForm(request, builder); err != nil {
            return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
        }
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
        req.ContentLength = int64(formBody.Len())
    } else {
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
        req.ContentLength = p.Context.Request.ContentLength
    }
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    var textResponse string
    if hasJSONResponse(request) {
        openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
        errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
        if errWithCode != nil {
            return
        }
        textResponse = openAIProviderTranscriptionsResponse.Text
    } else {
        openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
        errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
        if errWithCode != nil {
            return
        }
        textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
    }
    completionTokens := common.CountTokenText(textResponse, request.Model)
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: completionTokens,
        TotalTokens:      promptTokens + completionTokens,
    }
    return
}
func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    var textResponse string
    var resp *http.Response
    var err error
    audioResponseWrapper := &types.AudioResponseWrapper{}
    if hasJSONResponse(request) {
        openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
        resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
        if errWithCode != nil {
            return nil, errWithCode
        }
        textResponse = openAIProviderTranscriptionsResponse.Text
    } else {
        openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
        resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
        if errWithCode != nil {
            return nil, errWithCode
        }
        textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
    }
    defer resp.Body.Close()
    audioResponseWrapper.Headers = map[string]string{
        "Content-Type": resp.Header.Get("Content-Type"),
    }
    audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
    if err != nil {
        return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
    }
    completionTokens := common.CountTokenText(textResponse, request.Model)
    p.Usage.CompletionTokens = completionTokens
    p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
    return audioResponseWrapper, nil
}
func hasJSONResponse(request *types.AudioRequest) bool {
    return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
}
func (p *OpenAIProvider) getRequestAudioBody(relayMode int, ModelName string, request *types.AudioRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
    url, errWithCode := p.GetSupportedAPIUri(relayMode)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Get the request URL
    fullRequestURL := p.GetFullRequestURL(url, ModelName)
    // Get the request headers
    headers := p.GetRequestHeaders()
    // Create the request
    var req *http.Request
    var err error
    if p.OriginalModel != request.Model {
        var formBody bytes.Buffer
        builder := p.Requester.CreateFormBuilder(&formBody)
        if err := audioMultipartForm(request, builder); err != nil {
            return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
        }
        req, err = p.Requester.NewRequest(
            http.MethodPost,
            fullRequestURL,
            p.Requester.WithBody(&formBody),
            p.Requester.WithHeader(headers),
            p.Requester.WithContentType(builder.FormDataContentType()))
        req.ContentLength = int64(formBody.Len())
    } else {
        req, err = p.Requester.NewRequest(
            http.MethodPost,
            fullRequestURL,
            p.Requester.WithBody(p.Context.Request.Body),
            p.Requester.WithHeader(headers),
            p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
        req.ContentLength = p.Context.Request.ContentLength
    }
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    return req, nil
}
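
For reference, hasJSONResponse partitions the OpenAI audio response_format values: the empty default, json, and verbose_json are parsed as JSON, while text, srt, and vtt fall through to the plain-text path and are normalized by getTextContent (not shown in this diff):

// hasJSONResponse(&types.AudioRequest{ResponseFormat: ""})             // true  (defaults to json)
// hasJSONResponse(&types.AudioRequest{ResponseFormat: "verbose_json"}) // true
// hasJSONResponse(&types.AudioRequest{ResponseFormat: "srt"})          // false (plain-text branch)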
func audioMultipartForm(request *types.AudioRequest, b common.FormBuilder) error {
func audioMultipartForm(request *types.AudioRequest, b requester.FormBuilder) error {
err := b.CreateFormFile("file", request.File)
if err != nil {

View File

@ -1,60 +1,53 @@
package openai
import (
"bytes"
"io"
"net/http"
"one-api/common"
"one-api/types"
)
func (p *OpenAIProvider) TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
    fullRequestURL := p.GetFullRequestURL(p.AudioTranslations, request.Model)
    headers := p.GetRequestHeaders()
    client := common.NewClient()
    var formBody bytes.Buffer
    var req *http.Request
    var err error
    if isModelMapped {
        builder := client.CreateFormBuilder(&formBody)
        if err := audioMultipartForm(request, builder); err != nil {
            return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
        }
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
        req.ContentLength = int64(formBody.Len())
    } else {
        req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
        req.ContentLength = p.Context.Request.ContentLength
    }
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    var textResponse string
    if hasJSONResponse(request) {
        openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
        errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
        if errWithCode != nil {
            return
        }
        textResponse = openAIProviderTranscriptionsResponse.Text
    } else {
        openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
        errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
        if errWithCode != nil {
            return
        }
        textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
    }
    completionTokens := common.CountTokenText(textResponse, request.Model)
    usage = &types.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: completionTokens,
        TotalTokens:      promptTokens + completionTokens,
    }
    return
}
func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    var textResponse string
    var resp *http.Response
    var err error
    audioResponseWrapper := &types.AudioResponseWrapper{}
    if hasJSONResponse(request) {
        openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
        resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
        if errWithCode != nil {
            return nil, errWithCode
        }
        textResponse = openAIProviderTranscriptionsResponse.Text
    } else {
        openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
        resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
        if errWithCode != nil {
            return nil, errWithCode
        }
        textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
    }
    defer resp.Body.Close()
    audioResponseWrapper.Headers = map[string]string{
        "Content-Type": resp.Header.Get("Content-Type"),
    }
    audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
    if err != nil {
        return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
    }
    completionTokens := common.CountTokenText(textResponse, request.Model)
    p.Usage.CompletionTokens = completionTokens
    p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
    return audioResponseWrapper, nil
}

View File

@ -12,11 +12,27 @@ type OpenAIProviderChatStreamResponse struct {
types.OpenAIErrorResponse
}
func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Delta.Content
}
return
}
type OpenAIProviderCompletionResponse struct {
types.CompletionResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderCompletionResponse) getResponseText() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Text
}
return
}
type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
@ -38,7 +54,7 @@ func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string {
return (*string)(a)
}
type OpenAIProviderImageResponseResponse struct {
type OpenAIProviderImageResponse struct {
types.ImageResponse
types.OpenAIErrorResponse
}

View File

@ -3,7 +3,6 @@ package openaisb
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
"strconv"
)
@ -13,15 +12,14 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response OpenAISBUsageResponse
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -1,18 +1,17 @@
package openaisb
import (
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type OpenaiSBProviderFactory struct{}
// Create an OpenaiSBProvider
func (f OpenaiSBProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f OpenaiSBProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &OpenaiSBProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.openai-sb.com"),
}
}

View File

@ -1,22 +1,25 @@
package palm
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type PalmProviderFactory struct{}
// Create a PalmProvider
func (f PalmProviderFactory) Create(c *gin.Context) base.ProviderInterface {
    return &PalmProvider{
        BaseProvider: base.BaseProvider{
            BaseURL:         "https://generativelanguage.googleapis.com",
            ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
            Context:         c,
        },
    }
}
func (f PalmProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
    return &PalmProvider{
        BaseProvider: base.BaseProvider{
            Config:    getConfig(),
            Channel:   channel,
            Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
        },
    }
}
@ -25,6 +28,37 @@ type PalmProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
}
}
// Handle request errors
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
    palmError := &PaLMErrorResponse{}
    err := json.NewDecoder(resp.Body).Decode(palmError)
    if err != nil {
        return nil
    }
    return errorHandle(palmError)
}
// Error handling
func errorHandle(palmError *PaLMErrorResponse) *types.OpenAIError {
if palmError.Error.Code == 0 {
return nil
}
return &types.OpenAIError{
Message: palmError.Error.Message,
Type: "palm_error",
Param: palmError.Error.Status,
Code: palmError.Error.Code,
}
}
// Get the request headers
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)

View File

@ -3,30 +3,98 @@ package palm
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
    if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
        return nil, &types.OpenAIErrorWithStatusCode{
            OpenAIError: types.OpenAIError{
                Message: palmResponse.Error.Message,
                Type:    palmResponse.Error.Status,
                Param:   "",
                Code:    palmResponse.Error.Code,
            },
            StatusCode: resp.StatusCode,
        }
    }
    fullTextResponse := types.ChatCompletionResponse{
        Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)),
    }
    for i, candidate := range palmResponse.Candidates {
        choice := types.ChatCompletionChoice{
            Index: i,
            Message: types.ChatCompletionMessage{
                Role:    "assistant",
                Content: candidate.Content,
            },
            FinishReason: "stop",
        }
        fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
    }
    completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
    palmResponse.Usage.CompletionTokens = completionTokens
    palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
    fullTextResponse.Usage = palmResponse.Usage
    fullTextResponse.Model = palmResponse.Model
    return fullTextResponse, nil
}
type palmStreamHandler struct {
    Usage   *types.Usage
    Request *types.ChatCompletionRequest
}
func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.getChatRequest(request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    palmResponse := &PaLMChatResponse{}
    // Send the request
    _, errWithCode = p.Requester.SendRequest(req, palmResponse, false)
    if errWithCode != nil {
        return nil, errWithCode
    }
    return p.convertToChatOpenai(palmResponse, request)
}
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
    req, errWithCode := p.getChatRequest(request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
        return nil, errWithCode
    }
    chatHandler := &palmStreamHandler{
        Usage:   p.Usage,
        Request: request,
    }
    return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
    url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
    if errWithCode != nil {
        return nil, errWithCode
    }
    // Get the request URL
    fullRequestURL := p.GetFullRequestURL(url, request.Model)
    if fullRequestURL == "" {
        return nil, common.ErrorWrapper(nil, "invalid_palm_config", http.StatusInternalServerError)
    }
    // Get the request headers
    headers := p.GetRequestHeaders()
    if request.Stream {
        headers["Accept"] = "text/event-stream"
    }
    palmRequest := convertFromChatOpenai(request)
    // Create the request
    req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(palmRequest), p.Requester.WithHeader(headers))
    if err != nil {
        return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
    }
    return req, nil
}
func (p *PalmProvider) convertToChatOpenai(response *PaLMChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
    error := errorHandle(&response.PaLMErrorResponse)
    if error != nil {
        errWithCode = &types.OpenAIErrorWithStatusCode{
            OpenAIError: *error,
            StatusCode:  http.StatusBadRequest,
        }
        return
    }
    openaiResponse = &types.ChatCompletionResponse{
        Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
        Model:   request.Model,
    }
    for i, candidate := range response.Candidates {
        choice := types.ChatCompletionChoice{
            Index: i,
            Message: types.ChatCompletionMessage{
                Role:    "assistant",
                Content: candidate.Content,
            },
            FinishReason: "stop",
        }
        openaiResponse.Choices = append(openaiResponse.Choices, choice)
    }
    completionTokens := common.CountTokenText(response.Candidates[0].Content, request.Model)
    response.Usage.CompletionTokens = completionTokens
    response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
    openaiResponse.Usage = response.Usage
    *p.Usage = *response.Usage
    return
}
func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
@ -72,132 +141,51 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest)
return &palmRequest
}
func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
// 转换为OpenAI聊天流式请求体
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
// 去除前缀
*rawLine = (*rawLine)[6:]
var palmChatResponse PaLMChatResponse
err := json.Unmarshal(*rawLine, &palmChatResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
return common.ErrorToOpenAIError(err)
}
if request.Stream {
var responseText string
errWithCode, responseText = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(responseText, request.Model),
}
usage.TotalTokens = promptTokens + usage.CompletionTokens
} else {
var palmChatResponse = &PaLMChatResponse{
Model: request.Model,
Usage: &types.Usage{
PromptTokens: promptTokens,
},
}
errWithCode = p.SendRequest(req, palmChatResponse, false)
if errWithCode != nil {
return
}
usage = palmChatResponse.Usage
error := errorHandle(&palmChatResponse.PaLMErrorResponse)
if error != nil {
return error
}
return
return h.convertToOpenaiStream(&palmChatResponse, response)
}
func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error {
var choice types.ChatCompletionStreamChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
if len(palmChatResponse.Candidates) > 0 {
choice.Delta.Content = palmChatResponse.Candidates[0].Content
}
choice.FinishReason = &base.StopFinishReason
var response types.ChatCompletionStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []types.ChatCompletionStreamChoice{choice}
return &response
}
func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close()
// 发送请求
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
}
defer resp.Body.Close()
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
common.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.ID = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, responseText
choice.FinishReason = types.FinishReasonStop
streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
Created: common.GetTimestamp(),
}
*response = append(*response, streamResponse)
if len(palmChatResponse.Candidates) > 0 {
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
}
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
}
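This handler shape is the core of the refactor: the requester splits the body into raw SSE lines and hands each line to the handler by pointer, so the handler can nil out lines it wants skipped, flip isFinished at end-of-stream, and append converted chunks. A minimal driver sketch under those assumptions (the real requester.RequestStream is outside this diff; bufio and one-api/types imports assumed):
func pumpStream(scanner *bufio.Scanner, handler func(*[]byte, *bool, *[]types.ChatCompletionStreamResponse) error) ([]types.ChatCompletionStreamResponse, error) {
	var chunks []types.ChatCompletionStreamResponse
	isFinished := false
	for !isFinished && scanner.Scan() {
		line := scanner.Bytes()
		if err := handler(&line, &isFinished, &chunks); err != nil {
			return nil, err // the handler has already shaped this as an OpenAI error
		}
		// the handler sets line to nil for frames it chose to skip
	}
	return chunks, scanner.Err()
}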

View File

@@ -30,11 +30,15 @@ type PaLMError struct {
Status string `json:"status"`
}
type PaLMErrorResponse struct {
Error PaLMError `json:"error,omitempty"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []types.ChatCompletionMessage `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
PaLMErrorResponse
}

View File

@@ -28,7 +28,7 @@ import (
// Defines the provider factory interface
type ProviderFactory interface {
Create(c *gin.Context) base.ProviderInterface
Create(Channel *model.Channel) base.ProviderInterface
}
// Build the global provider factory map
@@ -71,11 +71,11 @@ func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface
return nil
}
provider = openai.CreateOpenAIProvider(c, baseURL)
provider = openai.CreateOpenAIProvider(channel, baseURL)
} else {
provider = factory.Create(c)
provider = factory.Create(channel)
}
provider.SetChannel(channel)
provider.SetContext(c)
return provider
}
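Factories now take the channel rather than the gin context, so adding a provider reduces to one map entry. A hypothetical sketch of the registry GetProvider consults (the actual initialization lives outside this hunk, and the common.ChannelType* constant names are assumptions following one-api's common package):
var providerFactories = map[int]ProviderFactory{
	common.ChannelTypeTencent: tencent.TencentProviderFactory{},
	common.ChannelTypeXunfei:  xunfei.XunfeiProviderFactory{},
	common.ChannelTypeZhipu:   zhipu.ZhipuProviderFactory{},
}

// Lookup used by GetProvider; callers fall back to nil for unknown channel types.
func getFactory(channelType int) (ProviderFactory, bool) {
	factory, ok := providerFactories[channelType]
	return factory, ok
}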

View File

@@ -4,25 +4,28 @@ import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"sort"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type TencentProviderFactory struct{}
// Create a TencentProvider
func (f TencentProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f TencentProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &TencentProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://hunyuan.cloud.tencent.com",
ChatCompletions: "/hyllm/v1/chat/completions",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
},
}
}
@@ -31,6 +34,36 @@ type TencentProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://hunyuan.cloud.tencent.com",
ChatCompletions: "/hyllm/v1/chat/completions",
}
}
// Request error handling
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
tencentError := &TencentResponseError{}
err := json.NewDecoder(resp.Body).Decode(tencentError)
if err != nil {
return nil
}
return errorHandle(tencentError)
}
// Error handling
func errorHandle(tencentError *TencentResponseError) *types.OpenAIError {
if tencentError.Error.Code == 0 {
return nil
}
return &types.OpenAIError{
Message: tencentError.Error.Message,
Type: "tencent_error",
Code: tencentError.Error.Code,
}
}
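requestErrorHandle covers non-2xx responses, while errorHandle also has to run on 200 responses because Hunyuan reports failures in-band with a non-zero error code. A sketch of how a caller folds that into the status-coded error type used by the relay layer (assumed wiring, mirroring convertToChatOpenai elsewhere in this diff):
if apiErr := errorHandle(&tencentChatResponse.TencentResponseError); apiErr != nil {
	return nil, &types.OpenAIErrorWithStatusCode{
		OpenAIError: *apiErr,
		StatusCode:  http.StatusBadRequest,
	}
}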
// Get the request headers
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
@@ -77,7 +110,7 @@ func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
sort.Strings(params)
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
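The hunk cuts off before the digest; Hunyuan's v1 signing is an HMAC-SHA1 over the sorted query string, base64-encoded. A sketch of the closing lines (outside this hunk), assuming the mac and signURL already in scope above:
mac.Write([]byte(signURL))
// base64 of the raw HMAC-SHA1 digest is what goes into the Authorization header
return base64.StdEncoding.EncodeToString(mac.Sum(nil))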

View File

@@ -1,50 +1,127 @@
package tencent
import (
"bufio"
"encoding/json"
"errors"
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if TencentResponse.Error.Code != 0 {
return &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
type tencentStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
tencentChatResponse := &TencentChatResponse{}
// Send the request
_, errWithCode = p.Requester.SendRequest(req, tencentChatResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
fullTextResponse := types.ChatCompletionResponse{
return p.convertToChatOpenai(tencentChatResponse, request)
}
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// Send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &tencentStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
tencentRequest := convertFromChatOpenai(request)
sign := p.getTencentSign(*tencentRequest)
if sign == "" {
return nil, common.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
}
// Get the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_tencent_config", http.StatusInternalServerError)
}
// Get the request headers
headers := p.GetRequestHeaders()
headers["Authorization"] = sign
if request.Stream {
headers["Accept"] = "text/event-stream"
}
// Create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(tencentRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func (p *TencentProvider) convertToChatOpenai(response *TencentChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(&response.TencentResponseError)
if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *error,
StatusCode: http.StatusBadRequest,
}
return
}
openaiResponse = &types.ChatCompletionResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: TencentResponse.Usage,
Model: TencentResponse.Model,
Usage: response.Usage,
Model: request.Model,
}
if len(TencentResponse.Choices) > 0 {
if len(response.Choices) > 0 {
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: TencentResponse.Choices[0].Messages.Content,
Content: response.Choices[0].Messages.Content,
},
FinishReason: TencentResponse.Choices[0].FinishReason,
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
openaiResponse.Choices = append(openaiResponse.Choices, choice)
}
return fullTextResponse, nil
*p.Usage = *response.Usage
return
}
func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest {
func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -79,143 +156,51 @@ func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionReques
}
}
func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
sign := p.getTencentSign(*requestBody)
if sign == "" {
return nil, common.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
// Convert the provider stream chunk into OpenAI chat streaming responses
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// If rawLine does not start with "data:", skip it and return
if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil
return nil
}
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
headers["Authorization"] = sign
if request.Stream {
headers["Accept"] = "text/event-stream"
}
// Strip the "data:" prefix
*rawLine = (*rawLine)[5:]
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
var tencentChatResponse TencentChatResponse
err := json.Unmarshal(*rawLine, &tencentChatResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
return common.ErrorToOpenAIError(err)
}
if request.Stream {
var responseText string
errWithCode, responseText = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(responseText, request.Model),
}
usage.TotalTokens = promptTokens + usage.CompletionTokens
} else {
tencentResponse := &TencentChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, tencentResponse, false)
if errWithCode != nil {
return
}
usage = tencentResponse.Usage
error := errorHandle(&tencentChatResponse.TencentResponseError)
if error != nil {
return error
}
return
return h.convertToOpenaiStream(&tencentChatResponse, response)
}
func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
response := types.ChatCompletionStreamResponse{
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error {
streamResponse := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: TencentResponse.Model,
Model: h.Request.Model,
}
if len(TencentResponse.Choices) > 0 {
if len(tencentChatResponse.Choices) > 0 {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &base.StopFinishReason
choice.Delta.Content = tencentChatResponse.Choices[0].Delta.Content
if tencentChatResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = types.FinishReasonStop
}
response.Choices = append(response.Choices, choice)
streamResponse.Choices = append(streamResponse.Choices, choice)
}
return &response
}
func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close()
// 发送请求
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
}
defer resp.Body.Close()
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
TencentResponse.Model = model
response := p.streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, responseText
*response = append(*response, streamResponse)
if len(tencentChatResponse.Choices) > 0 {
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
}
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
}

View File

@@ -50,13 +50,17 @@ type TencentResponseChoices struct {
Delta TencentMessage `json:"delta,omitempty"` // Content; stream mode returns it here, sync mode returns null. Output content is capped at 1024 tokens in total.
}
type TencentResponseError struct {
Error TencentError `json:"error,omitempty"`
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // Result list
Created string `json:"created,omitempty"` // Unix timestamp as a string
Id string `json:"id,omitempty"` // Conversation id
Usage *types.Usage `json:"usage,omitempty"` // Token counts
Error TencentError `json:"error,omitempty"` // Error info; note: this field may be null when no valid value is available
Note string `json:"note,omitempty"` // Note
ReqID string `json:"req_id,omitempty"` // Unique request id, returned with every request; useful when reporting issues with request parameters
Model string `json:"model,omitempty"` // Model name
TencentResponseError
}

View File

@@ -7,31 +7,53 @@ import (
"fmt"
"net/url"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type XunfeiProviderFactory struct{}
// Create an XunfeiProvider
func (f XunfeiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f XunfeiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &XunfeiProvider{
BaseProvider: base.BaseProvider{
BaseURL: "wss://spark-api.xf-yun.com",
ChatCompletions: "true",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, nil),
},
wsRequester: requester.NewWSRequester(channel.Proxy),
}
}
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiProvider struct {
base.BaseProvider
domain string
apiId string
domain string
apiId string
wsRequester *requester.WSRequester
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "wss://spark-api.xf-yun.com",
ChatCompletions: "/",
}
}
// Error handling
func errorHandle(xunfeiError *XunfeiChatResponse) *types.OpenAIError {
if xunfeiError.Header.Code == 0 {
return nil
}
return &types.OpenAIError{
Message: xunfeiError.Header.Message,
Type: "xunfei_error",
Code: xunfeiError.Header.Code,
}
}
// Get the request headers
@@ -68,7 +90,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
}
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.Config.BaseURL, apiVersion), apiKey, apiSecret)
return domain, authUrl
}
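buildXunfeiAuthUrl itself is not in this diff. Per the Spark Web API docs it signs the host, an RFC1123 date, and the request line with HMAC-SHA256 using the API secret, then packs the base64 results into query parameters. A sketch under those assumptions (imports crypto/hmac, crypto/sha256, encoding/base64 and net/http assumed):
func buildXunfeiAuthUrl(hostUrl string, apiKey string, apiSecret string) string {
	u, err := url.Parse(hostUrl)
	if err != nil {
		return ""
	}
	date := time.Now().UTC().Format(http.TimeFormat) // RFC1123 with GMT, as the docs require
	signOrigin := fmt.Sprintf("host: %s\ndate: %s\nGET %s HTTP/1.1", u.Host, date, u.Path)
	mac := hmac.New(sha256.New, []byte(apiSecret))
	mac.Write([]byte(signOrigin))
	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))
	authOrigin := fmt.Sprintf(`api_key="%s", algorithm="hmac-sha256", headers="host date request-line", signature="%s"`, apiKey, signature)
	authorization := base64.StdEncoding.EncodeToString([]byte(authOrigin))
	query := url.Values{}
	query.Add("host", u.Host)
	query.Add("date", date)
	query.Add("authorization", authorization)
	return hostUrl + "?" + query.Encode()
}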

View File

@@ -2,140 +2,93 @@ package xunfei
import (
"encoding/json"
"fmt"
"errors"
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"time"
"strings"
"github.com/gorilla/websocket"
)
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
}
if request.Stream {
return p.sendStreamRequest(dataChan, stopChan, request.GetFunctionCate())
} else {
return p.sendRequest(dataChan, stopChan, request.GetFunctionCate())
}
type xunfeiHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
}
if xunfeiResponse.Header.Code != 0 {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := p.responseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
_, _ = p.Context.Writer.Write(jsonResponse)
return usage, nil
}
func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
// Wait for the first response from dataChan
xunfeiResponse, ok := <-dataChan
if !ok {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
}
if xunfeiResponse.Header.Code != 0 {
errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
wsConn, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
// If the first response carries no error, set the stream headers and start streaming
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
// Handle the first response
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
xunfeiRequest := p.convertFromChatOpenai(request)
chatHandler := &xunfeiHandler{
Usage: p.Usage,
Request: request,
}
stream, errWithCode := requester.SendWSJsonRequest[XunfeiChatResponse](wsConn, xunfeiRequest, chatHandler.handlerNotStream)
if errWithCode != nil {
return nil, errWithCode
}
return chatHandler.convertToChatOpenai(stream)
// Handle subsequent responses
for {
select {
case xunfeiResponse, ok := <-dataChan:
if !ok {
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
}
})
return usage, nil
}
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
wsConn, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
xunfeiRequest := p.convertFromChatOpenai(request)
chatHandler := &xunfeiHandler{
Usage: p.Usage,
Request: request,
}
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
}
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
authUrl := p.GetFullRequestURL(url, request.Model)
wsConn, err := p.wsRequester.NewRequest(authUrl, nil)
if err != nil {
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
}
return wsConn, nil
}
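SendWSJsonRequest is defined in the requester package, outside this diff. Conceptually it writes the converted request as one JSON frame, then reads frames until the handler flips isFinished (for Spark, once Payload.Choices.Status == 2). A minimal sketch of that read loop over a gorilla/websocket connection:
func readWSStream(conn *websocket.Conn, handler func(*[]byte, *bool, *[]XunfeiChatResponse) error) ([]XunfeiChatResponse, error) {
	var out []XunfeiChatResponse
	isFinished := false
	for !isFinished {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			return nil, err
		}
		if err := handler(&msg, &isFinished, &out); err != nil {
			return nil, err
		}
	}
	return out, conn.Close()
}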
func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Role: types.ChatMessageRoleUser,
Content: message.StringContent(),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Role: types.ChatMessageRoleAssistant,
Content: "Okay",
})
} else if message.Role == types.ChatMessageRoleFunction {
messages = append(messages, XunfeiMessage{
Role: types.ChatMessageRoleUser,
Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(),
})
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
@@ -143,6 +96,7 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
})
}
}
xunfeiRequest := XunfeiChatRequest{}
if request.Tools != nil {
@@ -166,35 +120,57 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
return &xunfeiRequest
}
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
var content string
var xunfeiResponse XunfeiChatResponse
for {
response, err := stream.Recv()
if err != nil && !errors.Is(err, io.EOF) {
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
}
if errors.Is(err, io.EOF) && response == nil {
break
}
if len((*response)[0].Payload.Choices.Text) == 0 {
continue
}
xunfeiResponse = (*response)[0]
content += xunfeiResponse.Payload.Choices.Text[0].Content
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
choice := types.ChatCompletionChoice{
Index: 0,
FinishReason: base.StopFinishReason,
FinishReason: types.FinishReasonStop,
}
xunfeiText := response.Payload.Choices.Text[0]
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
if xunfeiText.FunctionCall != nil {
choice.Message = types.ChatCompletionMessage{
Role: "assistant",
}
if functionCate == "tool" {
if h.Request.Tools != nil {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: response.Header.Sid,
Id: xunfeiResponse.Header.Sid,
Type: "function",
Function: *xunfeiText.FunctionCall,
Function: xunfeiText.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
choice.FinishReason = types.FinishReasonToolCalls
} else {
choice.Message.FunctionCall = xunfeiText.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
choice.FinishReason = types.FinishReasonFunctionCall
}
} else {
@@ -204,97 +180,128 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, fun
}
}
fullTextResponse := types.ChatCompletionResponse{
ID: response.Header.Sid,
fullTextResponse := &types.ChatCompletionResponse{
ID: xunfeiResponse.Header.Sid,
Object: "chat.completion",
Model: "SparkDesk",
Model: h.Request.Model,
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Usage: &response.Payload.Usage.Text,
Usage: &xunfeiResponse.Payload.Usage.Text,
}
return &fullTextResponse
return fullTextResponse, nil
}
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
// If rawLine does not start with "{", it is not a JSON frame; skip it and return
if !strings.HasPrefix(string(*rawLine), "{") {
*rawLine = nil
return nil, nil
}
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return nil, nil, err
}
data := p.requestOpenAI2Xunfei(textRequest)
err = conn.WriteJSON(data)
var xunfeiChatResponse XunfeiChatResponse
err := json.Unmarshal(*rawLine, &xunfeiChatResponse)
if err != nil {
return nil, nil, err
return nil, common.ErrorToOpenAIError(err)
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
}
break
}
}
stopChan <- true
}()
error := errorHandle(&xunfeiChatResponse)
if error != nil {
return nil, error
}
return dataChan, stopChan, nil
if xunfeiChatResponse.Payload.Choices.Status == 2 {
*isFinished = true
}
h.Usage.PromptTokens = xunfeiChatResponse.Payload.Usage.Text.PromptTokens
h.Usage.CompletionTokens = xunfeiChatResponse.Payload.Usage.Text.CompletionTokens
h.Usage.TotalTokens = xunfeiChatResponse.Payload.Usage.Text.TotalTokens
return &xunfeiChatResponse, nil
}
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
if err != nil {
return err
}
var choice types.ChatCompletionStreamChoice
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
if *rawLine == nil {
return nil
}
*response = append(*response, *xunfeiChatResponse)
return nil
}
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
if err != nil {
return err
}
if *rawLine == nil {
return nil
}
return h.convertToOpenaiStream(xunfeiChatResponse, response)
}
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: types.ChatMessageRoleAssistant,
},
}
xunfeiText := xunfeiChatResponse.Payload.Choices.Text[0]
if xunfeiText.FunctionCall != nil {
if functionCate == "tool" {
if h.Request.Tools != nil {
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: xunfeiResponse.Header.Sid,
Id: xunfeiChatResponse.Header.Sid,
Index: 0,
Type: "function",
Function: *xunfeiText.FunctionCall,
Function: xunfeiText.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
choice.FinishReason = types.FinishReasonToolCalls
} else {
choice.Delta.FunctionCall = xunfeiText.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
choice.FinishReason = types.FinishReasonFunctionCall
}
} else {
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &base.StopFinishReason
choice.Delta.Content = xunfeiChatResponse.Payload.Choices.Text[0].Content
if xunfeiChatResponse.Payload.Choices.Status == 2 {
choice.FinishReason = types.FinishReasonStop
}
}
response := types.ChatCompletionStreamResponse{
ID: xunfeiResponse.Header.Sid,
chatCompletion := types.ChatCompletionStreamResponse{
ID: xunfeiChatResponse.Header.Sid,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []types.ChatCompletionStreamChoice{choice},
Model: h.Request.Model,
}
return &response
if xunfeiText.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion)
} else {
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
}
}
return nil
}

View File

@@ -1,14 +1,18 @@
package zhipu
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
@@ -18,12 +22,12 @@ var expSeconds int64 = 24 * 3600
type ZhipuProviderFactory struct{}
// Create a ZhipuProvider
func (f ZhipuProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f ZhipuProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &ZhipuProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://open.bigmodel.cn",
ChatCompletions: "/api/paas/v3/model-api",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
},
}
}
@@ -32,6 +36,36 @@ type ZhipuProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://open.bigmodel.cn",
ChatCompletions: "/api/paas/v3/model-api",
}
}
// Request error handling
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
zhipuError := &ZhipuResponse{}
err := json.NewDecoder(resp.Body).Decode(zhipuError)
if err != nil {
return nil
}
return errorHandle(zhipuError)
}
// Error handling
func errorHandle(zhipuError *ZhipuResponse) *types.OpenAIError {
if zhipuError.Success {
return nil
}
return &types.OpenAIError{
Message: zhipuError.Msg,
Type: "zhipu_error",
Code: zhipuError.Code,
}
}
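The golang-jwt import and the expSeconds variable above point at Zhipu's token scheme: the channel key has the form id.secret and each request carries a short-lived HS256 JWT. The generator itself sits outside these hunks; a sketch following Zhipu's published scheme (the claim names and sign_type header come from their docs, not from this diff, and the exp unit may differ in practice):
func getZhipuToken(apikey string) (string, error) {
	parts := strings.Split(apikey, ".")
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid zhipu api key")
	}
	id, secret := parts[0], parts[1]
	now := time.Now().Unix()
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"api_key":   id,
		"exp":       now + expSeconds,
		"timestamp": now,
	})
	token.Header["sign_type"] = "SIGN"
	return token.SignedString([]byte(secret))
}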
// Get the request headers
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)

View File

@@ -1,38 +1,108 @@
package zhipu
import (
"bufio"
"encoding/json"
"io"
"fmt"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if !zhipuResponse.Success {
return &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
Code: zhipuResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
type zhipuStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
zhipuChatResponse := &ZhipuResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, zhipuChatResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
fullTextResponse := types.ChatCompletionResponse{
ID: zhipuResponse.Data.TaskId,
return p.convertToChatOpenai(zhipuChatResponse, request)
}
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &zhipuStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
}
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
// Get the request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_zhipu_config", http.StatusInternalServerError)
}
// Get the request headers
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
fullRequestURL += "/sse-invoke"
} else {
fullRequestURL += "/invoke"
}
zhipuRequest := convertFromChatOpenai(request)
// Create the request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(zhipuRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(response)
if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *error,
StatusCode: http.StatusBadRequest,
}
return
}
openaiResponse = &types.ChatCompletionResponse{
ID: response.Data.TaskId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: zhipuResponse.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
Usage: &zhipuResponse.Data.Usage,
Model: request.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(response.Data.Choices)),
Usage: &response.Data.Usage,
}
for i, choice := range zhipuResponse.Data.Choices {
for i, choice := range response.Data.Choices {
openaiChoice := types.ChatCompletionChoice{
Index: i,
Message: types.ChatCompletionMessage{
@@ -41,16 +111,18 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI
},
FinishReason: "",
}
if i == len(zhipuResponse.Data.Choices)-1 {
if i == len(response.Data.Choices)-1 {
openaiChoice.FinishReason = "stop"
}
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
openaiResponse.Choices = append(openaiResponse.Choices, openaiChoice)
}
return fullTextResponse, nil
*p.Usage = response.Data.Usage
return
}
func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest {
func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -77,158 +149,60 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest)
}
}
func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
fullRequestURL += "/sse-invoke"
} else {
fullRequestURL += "/invoke"
// Convert the provider stream chunk into OpenAI chat streaming responses
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// If rawLine starts with neither "data:" nor "meta:", skip it and return
if !strings.HasPrefix(string(*rawLine), "data:") && !strings.HasPrefix(string(*rawLine), "meta:") {
*rawLine = nil
return nil
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
if strings.HasPrefix(string(*rawLine), "meta:") {
*rawLine = (*rawLine)[5:]
var zhipuStreamMetaResponse ZhipuStreamMetaResponse
err := json.Unmarshal(*rawLine, &zhipuStreamMetaResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
}
*isFinished = true
return h.handlerMeta(&zhipuStreamMetaResponse, response)
}
if request.Stream {
errWithCode, usage = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
} else {
zhipuResponse := &ZhipuResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, zhipuResponse, false)
if errWithCode != nil {
return
}
usage = &zhipuResponse.Data.Usage
}
return
*rawLine = (*rawLine)[5:]
return h.convertToOpenaiStream(string(*rawLine), response)
}
func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse {
func (h *zhipuStreamHandler) convertToOpenaiStream(content string, response *[]types.ChatCompletionStreamResponse) error {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = zhipuResponse
response := types.ChatCompletionStreamResponse{
choice.Delta.Content = content
streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
*response = append(*response, streamResponse)
return nil
}
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
func (h *zhipuStreamHandler) handlerMeta(zhipuResponse *ZhipuStreamMetaResponse, response *[]types.ChatCompletionStreamResponse) error {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = ""
choice.FinishReason = &base.StopFinishReason
response := types.ChatCompletionStreamResponse{
choice.FinishReason = types.FinishReasonStop
streamResponse := types.ChatCompletionStreamResponse{
ID: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: zhipuResponse.Model,
Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response, &zhipuResponse.Usage
}
func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
defer req.Body.Close()
// 发送请求
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), nil
}
defer resp.Body.Close()
var usage *types.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
lines := strings.Split(data, "\n")
for i, line := range lines {
if len(line) < 5 {
continue
}
if line[:5] == "data:" {
dataChan <- line[5:]
if i != len(lines)-1 {
dataChan <- "\n"
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
}
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
response := p.streamResponseZhipu2OpenAI(data)
response.Model = model
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case data := <-metaChan:
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
zhipuResponse.Model = model
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, usage
*response = append(*response, streamResponse)
*h.Usage = zhipuResponse.Usage
return nil
}
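The two handlerStream branches map directly onto Zhipu's SSE wire format: content arrives in data: events and the usage summary arrives in a single trailing meta: event. An illustrative exchange (made up for this note, not captured from a real response):
data: Hello
data:  world
meta: {"request_id":"example-request","task_id":"example-task","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}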

View File

@@ -26,3 +26,8 @@ type AudioResponse struct {
Segments any `json:"segments,omitempty"`
Text string `json:"text"`
}
type AudioResponseWrapper struct {
Headers map[string]string
Body []byte
}

View File

@@ -5,16 +5,33 @@ const (
ContentTypeImageURL = "image_url"
)
const (
FinishReasonStop = "stop"
FinishReasonLength = "length"
FinishReasonFunctionCall = "function_call"
FinishReasonToolCalls = "tool_calls"
FinishReasonContentFilter = "content_filter"
FinishReasonNull = "null"
)
const (
ChatMessageRoleSystem = "system"
ChatMessageRoleUser = "user"
ChatMessageRoleAssistant = "assistant"
ChatMessageRoleFunction = "function"
ChatMessageRoleTool = "tool"
)
type ChatCompletionToolCallsFunction struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
Arguments string `json:"arguments"`
}
type ChatCompletionToolCalls struct {
Id string `json:"id"`
Type string `json:"type"`
Function ChatCompletionToolCallsFunction `json:"function"`
Index int `json:"index"`
Id string `json:"id"`
Type string `json:"type"`
Function *ChatCompletionToolCallsFunction `json:"function"`
Index int `json:"index"`
}
type ChatCompletionMessage struct {
@@ -123,6 +140,8 @@ type ChatCompletionRequest struct {
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
User string `json:"user,omitempty"`
Functions []*ChatCompletionFunction `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
@@ -151,19 +170,89 @@ type ChatCompletionTool struct {
}
type ChatCompletionChoice struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
FinishReason any `json:"finish_reason,omitempty"`
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
FinishReason any `json:"finish_reason,omitempty"`
ContentFilterResults any `json:"content_filter_results,omitempty"`
FinishDetails any `json:"finish_details,omitempty"`
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
PromptFilterResults any `json:"prompt_filter_results,omitempty"`
}
func (c ChatCompletionStreamChoice) ConvertOpenaiStream() []ChatCompletionStreamChoice {
var function *ChatCompletionToolCallsFunction
var functions []*ChatCompletionToolCallsFunction
var choices []ChatCompletionStreamChoice
var stopFinish string
if c.Delta.FunctionCall != nil {
function = c.Delta.FunctionCall
stopFinish = FinishReasonFunctionCall
} else {
function = c.Delta.ToolCalls[0].Function
stopFinish = FinishReasonToolCalls
}
if function.Name == "" {
c.FinishReason = stopFinish
choices = append(choices, c)
return choices
}
functions = append(functions, &ChatCompletionToolCallsFunction{
Name: function.Name,
Arguments: "",
})
if function.Arguments == "" || function.Arguments == "{}" {
functions = append(functions, &ChatCompletionToolCallsFunction{
Arguments: "{}",
})
} else {
functions = append(functions, &ChatCompletionToolCallsFunction{
Arguments: function.Arguments,
})
}
// Loop over functions and build the choices
for _, function := range functions {
choice := ChatCompletionStreamChoice{
Index: 0,
Delta: ChatCompletionStreamChoiceDelta{
Role: c.Delta.Role,
},
}
if stopFinish == FinishReasonFunctionCall {
choice.Delta.FunctionCall = function
} else {
choice.Delta.ToolCalls = []*ChatCompletionToolCalls{
{
Id: c.Delta.ToolCalls[0].Id,
Index: 0,
Type: "function",
Function: function,
},
}
}
choices = append(choices, choice)
}
choices = append(choices, ChatCompletionStreamChoice{
Index: c.Index,
Delta: ChatCompletionStreamChoiceDelta{},
FinishReason: stopFinish,
})
return choices
}
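ConvertOpenaiStream expands one provider chunk that already carries a complete function call into the multi-chunk shape OpenAI clients expect: first a delta naming the function, then one carrying its arguments, then an empty delta bearing the finish reason. A hypothetical input/output pair:
in := ChatCompletionStreamChoice{
	Delta: ChatCompletionStreamChoiceDelta{
		Role: ChatMessageRoleAssistant,
		ToolCalls: []*ChatCompletionToolCalls{{
			Id:       "call_example",
			Type:     "function",
			Function: &ChatCompletionToolCallsFunction{Name: "get_weather", Arguments: `{"city":"Beijing"}`},
		}},
	},
}
out := in.ConvertOpenaiStream()
// out[0]: delta with tool call "get_weather" and empty arguments
// out[1]: delta carrying only the arguments {"city":"Beijing"}
// out[2]: empty delta with FinishReason "tool_calls"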
type ChatCompletionStreamChoiceDelta struct {

View File

@@ -1,5 +1,7 @@
package types
import "encoding/json"
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -14,6 +16,16 @@ type OpenAIError struct {
InnerError any `json:"innererror,omitempty"`
}
func (e *OpenAIError) Error() string {
response := &OpenAIErrorResponse{
Error: *e,
}
// Marshal to JSON
bytes, _ := json.Marshal(response)
return string(bytes)
}
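Implementing the error interface by marshalling the wrapped response means provider errors can travel through plain error values without losing their structure. Roughly, from a caller's point of view (illustrative values; exact field order depends on the struct definition):
err := &types.OpenAIError{Message: "quota exceeded", Type: "tencent_error", Code: 4002}
fmt.Println(err.Error())
// prints something like: {"error":{"message":"quota exceeded","type":"tencent_error","code":4002}}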
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`

View File

@@ -8,9 +8,9 @@ type EmbeddingRequest struct {
}
type Embedding struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object"`
Embedding any `json:"embedding"`
Index int `json:"index"`
}
type EmbeddingResponse struct {