From 902c2faa2c813ccc1ce334498ebda2853ffd9636 Mon Sep 17 00:00:00 2001
From: Martial BE
Date: Tue, 28 Nov 2023 18:32:26 +0800
Subject: [PATCH] ♻️ refactor: split relay
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 common/client.go                |  126 ++++
 common/form_builder.go          |   65 ++
 common/marshaller.go            |   15 +
 common/quota.go                 |   59 ++
 common/request_builder.go       |   50 ++
 common/token.go                 |  109 +++
 controller/channel-billing.go   |    3 +-
 controller/channel-test.go      |  114 ++-
 controller/relay-aiproxy.go     |  220 ------
 controller/relay-ali.go         |  329 ---------
 controller/relay-audio.go       |  183 -----
 controller/relay-baidu.go       |  359 ----------
 controller/relay-chat.go        |  127 ++++
 controller/relay-claude.go      |  220 ------
 controller/relay-completion.go  |  113 +++
 controller/relay-embeddings.go  |  117 ++++
 controller/relay-image.go       |  206 ------
 controller/relay-openai.go      |  144 ----
 controller/relay-palm.go        |  205 ------
 controller/relay-text.go        |  649 ------------------
 controller/relay-utils.go       |  282 +++-----
 controller/relay-zhipu.go       |  301 --------
 controller/relay.go             |   68 +-
 main.go                         |    9 +-
 middleware/distributor.go       |    3 +-
 model/channel.go                |    3 +-
 providers/ali_base.go           |   50 ++
 providers/ali_chat.go           |  256 +++++++
 providers/ali_embeddings.go     |   94 +++
 providers/api2d_base.go         |   14 +
 providers/azure_base.go         |   41 ++
 providers/baidu_base.go         |  136 ++++
 providers/baidu_chat.go         |  228 ++++++
 providers/baidu_embeddings.go   |   88 +++
 providers/base.go               |  150 ++++
 providers/claude_base.go        |   55 ++
 providers/claude_chat.go        |  232 +++++++
 providers/closeai_proxy_base.go |   50 ++
 providers/openai_base.go        |  215 ++++++
 providers/openai_chat.go        |   92 +++
 providers/openai_completion.go  |   87 +++
 providers/openai_embeddings.go  |   50 ++
 providers/openaisb_base.go      |   58 ++
 providers/palm_base.go          |   43 ++
 providers/palm_chat.go          |  232 +++++++
 providers/tencent_base.go       |   94 +++
 .../tencent_chat.go             |  214 +++---
 providers/xunfei_base.go        |   96 +++
 .../xunfei_chat.go              |  289 ++++----
 providers/zhipu_base.go         |  104 +++
 providers/zhipu_chat.go         |  260 +++++++
 types/assistant.go              |   53 ++
 types/audio.go                  |    9 +
 types/chat.go                   |  109 +++
 types/common.go                 |   40 ++
 types/completion.go             |   36 +
 types/embeddings.go             |   40 ++
 types/image.go                  |   23 +
 58 files changed, 4248 insertions(+), 3369 deletions(-)
 create mode 100644 common/client.go
 create mode 100644 common/form_builder.go
 create mode 100644 common/marshaller.go
 create mode 100644 common/quota.go
 create mode 100644 common/request_builder.go
 create mode 100644 common/token.go
 delete mode 100644 controller/relay-aiproxy.go
 delete mode 100644 controller/relay-ali.go
 delete mode 100644 controller/relay-audio.go
 delete mode 100644 controller/relay-baidu.go
 create mode 100644 controller/relay-chat.go
 delete mode 100644 controller/relay-claude.go
 create mode 100644 controller/relay-completion.go
 create mode 100644 controller/relay-embeddings.go
 delete mode 100644 controller/relay-image.go
 delete mode 100644 controller/relay-openai.go
 delete mode 100644 controller/relay-palm.go
 delete mode 100644 controller/relay-text.go
 delete mode 100644 controller/relay-zhipu.go
 create mode 100644 providers/ali_base.go
 create mode 100644 providers/ali_chat.go
 create mode 100644 providers/ali_embeddings.go
 create mode 100644 providers/api2d_base.go
 create mode 100644 providers/azure_base.go
 create mode 100644 providers/baidu_base.go
 create mode 100644 providers/baidu_chat.go
 create mode 100644 providers/baidu_embeddings.go
 create mode 100644 providers/base.go
 create mode 100644 providers/claude_base.go
 create mode 100644 providers/claude_chat.go
 create mode 100644 providers/closeai_proxy_base.go
 create mode 100644 providers/openai_base.go
 create mode 100644 providers/openai_chat.go
 create mode 100644 providers/openai_completion.go
 create mode 100644 providers/openai_embeddings.go
 create mode 100644 providers/openaisb_base.go
 create mode 100644 providers/palm_base.go
 create mode 100644 providers/palm_chat.go
 create mode 100644 providers/tencent_base.go
 rename controller/relay-tencent.go => providers/tencent_chat.go (59%)
 create mode 100644 providers/xunfei_base.go
 rename controller/relay-xunfei.go => providers/xunfei_chat.go (56%)
 create mode 100644 providers/zhipu_base.go
 create mode 100644 providers/zhipu_chat.go
 create mode 100644 types/assistant.go
 create mode 100644 types/audio.go
 create mode 100644 types/chat.go
 create mode 100644 types/common.go
 create mode 100644 types/completion.go
 create mode 100644 types/embeddings.go
 create mode 100644 types/image.go

diff --git a/common/client.go b/common/client.go
new file mode 100644
index 00000000..5fb596b0
--- /dev/null
+++ b/common/client.go
@@ -0,0 +1,126 @@
+package common
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+)
+
+var HttpClient *http.Client
+
+func init() {
+	if RelayTimeout == 0 {
+		HttpClient = &http.Client{}
+	} else {
+		HttpClient = &http.Client{
+			Timeout: time.Duration(RelayTimeout) * time.Second,
+		}
+	}
+}
+
+type Client struct {
+	requestBuilder    RequestBuilder
+	createFormBuilder func(io.Writer) FormBuilder
+}
+
+func NewClient() *Client {
+	return &Client{
+		requestBuilder: NewRequestBuilder(),
+		createFormBuilder: func(body io.Writer) FormBuilder {
+			return NewFormBuilder(body)
+		},
+	}
+}
+
+type requestOptions struct {
+	body   any
+	header http.Header
+}
+
+type requestOption func(*requestOptions)
+
+func WithBody(body any) requestOption {
+	return func(args *requestOptions) {
+		args.body = body
+	}
+}
+
+func WithHeader(header map[string]string) requestOption {
+	return func(args *requestOptions) {
+		for k, v := range header {
+			args.header.Set(k, v)
+		}
+	}
+}
+
+type RequestError struct {
+	HTTPStatusCode int
+	Err            error
+}
+
+func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
+	// Default Options
+	args := &requestOptions{
+		body:   nil,
+		header: make(http.Header),
+	}
+	for _, setter := range setters {
+		setter(args)
+	}
+	req, err := c.requestBuilder.Build(method, url, args.body, args.header)
+	if err != nil {
+		return nil, err
+	}
+
+	return req, nil
+}
+
+func (c *Client) SendRequest(req *http.Request, response any) error {
+
+	// send the request
+	resp, err := HttpClient.Do(req)
+	if err != nil {
+		return err
+	}
+
+	defer resp.Body.Close()
+
+	// handle the response status
+	if IsFailureStatusCode(resp) {
+		return fmt.Errorf("status code: %d", resp.StatusCode)
+	}
+
+	// parse the response body
+	err = DecodeResponse(resp.Body, response)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func IsFailureStatusCode(resp *http.Response) bool {
+	return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
+}
+
+func DecodeResponse(body io.Reader, v any) error {
+	if v == nil {
+		return nil
+	}
+
+	if result, ok := v.(*string); ok {
+		return DecodeString(body, result)
+	}
+	return json.NewDecoder(body).Decode(v)
+}
+
+func DecodeString(body io.Reader, output *string) error {
+	b, err := io.ReadAll(body)
+	if err != nil {
+		return err
+	}
+	*output = string(b)
+	return nil
+}
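For orientation, a minimal usage sketch of the client above (not part of the patch; the endpoint, header value, and response shape are illustrative assumptions — only NewClient, WithBody, WithHeader, NewRequest and SendRequest come from common/client.go):

package main

import (
	"fmt"

	"one-api/common"
)

func main() {
	client := common.NewClient()
	req, err := client.NewRequest(
		"POST", "https://api.example.com/v1/chat/completions", // hypothetical endpoint
		common.WithBody(map[string]any{"model": "gpt-3.5-turbo"}), // marshaled to JSON by the request builder
		common.WithHeader(map[string]string{"Authorization": "Bearer <key>"}),
	)
	if err != nil {
		panic(err)
	}
	var out map[string]any
	// Any non-success status code is surfaced as an error by SendRequest.
	if err := client.SendRequest(req, &out); err != nil {
		panic(err)
	}
	fmt.Println(out)
}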
diff --git a/common/form_builder.go b/common/form_builder.go
new file mode 100644
index 00000000..a30e18ff
--- /dev/null
+++ b/common/form_builder.go
@@ -0,0 +1,65 @@
+package common
+
+import (
+	"fmt"
+	"io"
+	"mime/multipart"
+	"os"
+	"path"
+)
+
+type FormBuilder interface {
+	CreateFormFile(fieldname string, file *os.File) error
+	CreateFormFileReader(fieldname string, r io.Reader, filename string) error
+	WriteField(fieldname, value string) error
+	Close() error
+	FormDataContentType() string
+}
+
+type DefaultFormBuilder struct {
+	writer *multipart.Writer
+}
+
+func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
+	return &DefaultFormBuilder{
+		writer: multipart.NewWriter(body),
+	}
+}
+
+func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
+	return fb.createFormFile(fieldname, file, file.Name())
+}
+
+func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
+	return fb.createFormFile(fieldname, r, path.Base(filename))
+}
+
+func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
+	if filename == "" {
+		return fmt.Errorf("filename cannot be empty")
+	}
+
+	fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
+	if err != nil {
+		return err
+	}
+
+	_, err = io.Copy(fieldWriter, r)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
+	return fb.writer.WriteField(fieldname, value)
+}
+
+func (fb *DefaultFormBuilder) Close() error {
+	return fb.writer.Close()
+}
+
+func (fb *DefaultFormBuilder) FormDataContentType() string {
+	return fb.writer.FormDataContentType()
+}
diff --git a/common/marshaller.go b/common/marshaller.go
new file mode 100644
index 00000000..0ef9d5da
--- /dev/null
+++ b/common/marshaller.go
@@ -0,0 +1,15 @@
+package common
+
+import (
+	"encoding/json"
+)
+
+type Marshaller interface {
+	Marshal(value any) ([]byte, error)
+}
+
+type JSONMarshaller struct{}
+
+func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
+	return json.Marshal(value)
+}
diff --git a/common/quota.go b/common/quota.go
new file mode 100644
index 00000000..b8c772d9
--- /dev/null
+++ b/common/quota.go
@@ -0,0 +1,59 @@
+package common
+
+// type Quota struct {
+// 	ModelName  string
+// 	ModelRatio float64
+// 	GroupRatio float64
+// 	Ratio      float64
+// 	UserQuota  int
+// }
+
+// func CreateQuota(modelName string, userQuota int, group string) *Quota {
+// 	modelRatio := GetModelRatio(modelName)
+// 	groupRatio := GetGroupRatio(group)
+
+// 	return &Quota{
+// 		ModelName:  modelName,
+// 		ModelRatio: modelRatio,
+// 		GroupRatio: groupRatio,
+// 		Ratio:      modelRatio * groupRatio,
+// 		UserQuota:  userQuota,
+// 	}
+// }
+
+// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
+// 	if ApproximateTokenEnabled {
+// 		return int(float64(len(text)) * 0.38)
+// 	}
+// 	return len(tokenEncoder.Encode(text, nil, nil))
+// }
+
+// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
+// 	tokenEncoder := q.getTokenEncoder(model)
+// 	// Reference:
+// 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+// 	// https://github.com/pkoukk/tiktoken-go/issues/6
+// 	//
+// 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+// 	var tokensPerMessage int
+// 	var tokensPerName int
+// 	if model == "gpt-3.5-turbo-0301" {
+// 		tokensPerMessage = 4
+// 		tokensPerName = -1 // If there's a name, the role is omitted
+// 	} else {
+// 		tokensPerMessage = 3
+// 		tokensPerName = 1
+// 	}
+// 	tokenNum := 0
+// 	for _, message := range messages {
+// 		tokenNum += tokensPerMessage
+// 		tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
+// 		tokenNum += q.getTokenNum(tokenEncoder, message.Role)
+// 		if message.Name != nil {
+// 			tokenNum += tokensPerName
+// 			tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
+// 		}
+// 	}
+// 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+// 	return tokenNum
+// }
diff --git a/common/request_builder.go b/common/request_builder.go
new file mode 100644
index 00000000..6a97b425
--- /dev/null
+++ b/common/request_builder.go
@@ -0,0 +1,50 @@
+package common
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+)
+
+type RequestBuilder interface {
+	Build(method, url string, body any, header http.Header) (*http.Request, error)
+}
+
+type HTTPRequestBuilder struct {
+	marshaller Marshaller
+}
+
+func NewRequestBuilder() *HTTPRequestBuilder {
+	return &HTTPRequestBuilder{
+		marshaller: &JSONMarshaller{},
+	}
+}
+
+func (b *HTTPRequestBuilder) Build(
+	method string,
+	url string,
+	body any,
+	header http.Header,
+) (req *http.Request, err error) {
+	var bodyReader io.Reader
+	if body != nil {
+		if v, ok := body.(io.Reader); ok {
+			bodyReader = v
+		} else {
+			var reqBytes []byte
+			reqBytes, err = b.marshaller.Marshal(body)
+			if err != nil {
+				return
+			}
+			bodyReader = bytes.NewBuffer(reqBytes)
+		}
+	}
+	req, err = http.NewRequest(method, url, bodyReader)
+	if err != nil {
+		return
+	}
+	if header != nil {
+		req.Header = header
+	}
+	return
+}
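A small sketch of the two body paths Build accepts, per request_builder.go above (the URLs are placeholders, not part of the patch):

package main

import (
	"net/http"
	"strings"

	"one-api/common"
)

func main() {
	b := common.NewRequestBuilder()

	// Non-reader bodies are marshaled to JSON by the JSONMarshaller.
	jsonReq, _ := b.Build("POST", "https://api.example.com/echo",
		map[string]string{"hello": "world"}, nil)

	// io.Reader bodies are passed through untouched, and a caller-supplied
	// header replaces the request's header wholesale.
	rawReq, _ := b.Build("POST", "https://api.example.com/echo",
		strings.NewReader("raw bytes"),
		http.Header{"Content-Type": {"text/plain"}})

	_, _ = jsonReq, rawReq
}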
diff --git a/common/token.go b/common/token.go
new file mode 100644
index 00000000..5cac6f20
--- /dev/null
+++ b/common/token.go
@@ -0,0 +1,109 @@
+package common
+
+import (
+	"fmt"
+	"strings"
+
+	"one-api/types"
+
+	"github.com/pkoukk/tiktoken-go"
+)
+
+var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+var defaultTokenEncoder *tiktoken.Tiktoken
+
+func InitTokenEncoders() {
+	SysLog("initializing token encoders")
+	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	if err != nil {
+		FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+	}
+	defaultTokenEncoder = gpt35TokenEncoder
+	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
+	if err != nil {
+		FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
+	}
+	for model := range ModelRatio {
+		if strings.HasPrefix(model, "gpt-3.5") {
+			tokenEncoderMap[model] = gpt35TokenEncoder
+		} else if strings.HasPrefix(model, "gpt-4") {
+			tokenEncoderMap[model] = gpt4TokenEncoder
+		} else {
+			tokenEncoderMap[model] = nil
+		}
+	}
+	SysLog("token encoders initialized")
+}
+
+func getTokenEncoder(model string) *tiktoken.Tiktoken {
+	tokenEncoder, ok := tokenEncoderMap[model]
+	if ok && tokenEncoder != nil {
+		return tokenEncoder
+	}
+	if ok {
+		tokenEncoder, err := tiktoken.EncodingForModel(model)
+		if err != nil {
+			SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+			tokenEncoder = defaultTokenEncoder
+		}
+		tokenEncoderMap[model] = tokenEncoder
+		return tokenEncoder
+	}
+	return defaultTokenEncoder
+}
+
+func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
+	if ApproximateTokenEnabled {
+		return int(float64(len(text)) * 0.38)
+	}
+	return len(tokenEncoder.Encode(text, nil, nil))
+}
+
+func CountTokenMessages(messages []types.ChatCompletionMessage, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	// Reference:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	//
+	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+	var tokensPerMessage int
+	var tokensPerName int
+	if model == "gpt-3.5-turbo-0301" {
+		tokensPerMessage = 4
+		tokensPerName = -1 // If there's a name, the role is omitted
+	} else {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	}
+	tokenNum := 0
+	for _, message := range messages {
+		tokenNum += tokensPerMessage
+		tokenNum += getTokenNum(tokenEncoder, message.StringContent())
+		tokenNum += getTokenNum(tokenEncoder, message.Role)
+		if message.Name != nil {
+			tokenNum += tokensPerName
+			tokenNum += getTokenNum(tokenEncoder, *message.Name)
+		}
+	}
+	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+	return tokenNum
+}
+
+func CountTokenInput(input any, model string) int {
+	switch input.(type) {
+	case string:
+		return CountTokenText(input.(string), model)
+	case []string:
+		text := ""
+		for _, s := range input.([]string) {
+			text += s
+		}
+		return CountTokenText(text, model)
+	}
+	return 0
+}
+
+func CountTokenText(text string, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	return getTokenNum(tokenEncoder, text)
+}
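For illustration, the counters above can be exercised like this (message field shapes follow their use elsewhere in this patch; the actual counts depend on the tiktoken encoding, or on the 0.38-chars-per-token approximation when ApproximateTokenEnabled is set):

package main

import (
	"fmt"

	"one-api/common"
	"one-api/types"
)

func main() {
	name := "bot_a" // optional per-message name
	messages := []types.ChatCompletionMessage{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "hi", Name: &name},
	}
	// Per-message overhead plus encoded role/content/name, plus the
	// 3-token reply priming, as implemented in common/token.go above.
	fmt.Println(common.CountTokenMessages(messages, "gpt-3.5-turbo"))

	// CountTokenInput accepts either a string or a []string.
	fmt.Println(common.CountTokenInput([]string{"a", "b"}, "text-embedding-ada-002"))
}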
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 6ddad7ea..8f388e6f 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -92,7 +92,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 	for k := range headers {
 		req.Header.Add(k, headers.Get(k))
 	}
-	res, err := httpClient.Do(req)
+	res, err := common.HttpClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -204,6 +204,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 	if channel.GetBaseURL() == "" {
 		channel.BaseURL = &baseURL
 	}
+
 	switch channel.Type {
 	case common.ChannelTypeOpenAI:
 		if channel.GetBaseURL() != "" {
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 1b0b745a..29e7360a 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -1,14 +1,13 @@
 package controller
 
 import (
-	"bytes"
-	"encoding/json"
 	"errors"
 	"fmt"
-	"io"
 	"net/http"
+	"net/http/httptest"
 	"one-api/common"
 	"one-api/model"
+	"one-api/types"
 	"strconv"
 	"sync"
 	"time"
@@ -16,86 +15,81 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
+func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) {
+	// create an http.Request
+	req, err := http.NewRequest("POST", "/v1/chat/completions", nil)
+	if err != nil {
+		return err, nil
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = req
+	c.Set("channel", channel.Type)
+	c.Set("channel_id", channel.Id)
+	c.Set("channel_name", channel.Name)
+	c.Set("model_mapping", channel.GetModelMapping())
+	c.Set("api_key", channel.Key)
+	c.Set("base_url", channel.GetBaseURL())
+
 	switch channel.Type {
 	case common.ChannelTypePaLM:
-		fallthrough
+		request.Model = "PaLM-2"
 	case common.ChannelTypeAnthropic:
-		fallthrough
+		request.Model = "claude-2"
 	case common.ChannelTypeBaidu:
-		fallthrough
+		request.Model = "ERNIE-Bot"
 	case common.ChannelTypeZhipu:
-		fallthrough
+		request.Model = "chatglm_lite"
 	case common.ChannelTypeAli:
-		fallthrough
+		request.Model = "qwen-turbo"
 	case common.ChannelType360:
-		fallthrough
+		request.Model = "360GPT_S2_V9"
 	case common.ChannelTypeXunfei:
-		return errors.New("this channel type cannot be tested in the current version, please test it manually"), nil
+		request.Model = "SparkDesk"
+		c.Set("api_version", channel.Other)
+	case common.ChannelTypeTencent:
+		request.Model = "hunyuan"
 	case common.ChannelTypeAzure:
-		request.Model = "gpt-35-turbo"
-		defer func() {
-			if err != nil {
-				err = errors.New("please make sure the gpt-35-turbo model has been created on Azure and that apiVersion is filled in correctly!")
-			}
-		}()
+		request.Model = "gpt-3.5-turbo"
+		c.Set("api_version", channel.Other)
 	default:
 		request.Model = "gpt-3.5-turbo"
 	}
-	requestURL := common.ChannelBaseURLs[channel.Type]
-	if channel.Type == common.ChannelTypeAzure {
-		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
-	} else {
-		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
-			requestURL = baseURL
-		}
-		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
-	}
-	jsonData, err := json.Marshal(request)
+	chatProvider := GetChatProvider(channel.Type, c)
+	isModelMapped := false
+	modelMap, err := parseModelMapping(c.GetString("model_mapping"))
 	if err != nil {
 		return err, nil
 	}
-	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
-	if err != nil {
-		return err, nil
+	if modelMap != nil && modelMap[request.Model] != "" {
+		request.Model = modelMap[request.Model]
+		isModelMapped = true
 	}
-	if channel.Type == common.ChannelTypeAzure {
-		req.Header.Set("api-key", channel.Key)
-	} else {
-		req.Header.Set("Authorization", "Bearer "+channel.Key)
-	}
-	req.Header.Set("Content-Type", "application/json")
-	resp, err := httpClient.Do(req)
-	if err != nil {
-		return err, nil
-	}
-	defer resp.Body.Close()
-	var response TextResponse
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return err, nil
-	}
-	err = json.Unmarshal(body, &response)
-	if err != nil {
-		return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
-	}
-	if response.Usage.CompletionTokens == 0 {
-		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
+
+	promptTokens := common.CountTokenMessages(request.Messages, request.Model)
+	_, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&request, isModelMapped, promptTokens)
+	if openAIErrorWithStatusCode != nil {
+		return nil, &openAIErrorWithStatusCode.OpenAIError
 	}
+
 	return nil, nil
 }
 
-func buildTestRequest() *ChatRequest {
-	testRequest := &ChatRequest{
-		Model: "", // this will be set later
+func buildTestRequest() *types.ChatCompletionRequest {
+	testRequest := &types.ChatCompletionRequest{
+		Messages: []types.ChatCompletionMessage{
+			{
+				Role:    "user",
+				Content: "You just need to output 'hi' next.",
+			},
+		},
+		Model:     "",
 		MaxTokens: 1,
+		Stream:    false,
 	}
-	testMessage := Message{
-		Role:    "user",
-		Content: "hi",
-	}
-	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
 }
diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go
deleted file mode 100644
index 543954f7..00000000
--- a/controller/relay-aiproxy.go
+++ /dev/null
@@ -1,220 +0,0 @@
-package controller
-
-import (
-	"bufio"
-	"encoding/json"
-	"fmt"
-	"github.com/gin-gonic/gin"
-	"io"
-	"net/http"
-	"one-api/common"
-	"strconv"
-	"strings"
-)
-
-// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-
-type AIProxyLibraryRequest struct {
-	Model     string `json:"model"`
-	Query     string
`json:"query"` - LibraryId string `json:"libraryId"` - Stream bool `json:"stream"` -} - -type AIProxyLibraryError struct { - ErrCode int `json:"errCode"` - Message string `json:"message"` -} - -type AIProxyLibraryDocument struct { - Title string `json:"title"` - URL string `json:"url"` -} - -type AIProxyLibraryResponse struct { - Success bool `json:"success"` - Answer string `json:"answer"` - Documents []AIProxyLibraryDocument `json:"documents"` - AIProxyLibraryError -} - -type AIProxyLibraryStreamResponse struct { - Content string `json:"content"` - Finish bool `json:"finish"` - Model string `json:"model"` - Documents []AIProxyLibraryDocument `json:"documents"` -} - -func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { - query := "" - if len(request.Messages) != 0 { - query = request.Messages[len(request.Messages)-1].StringContent() - } - return &AIProxyLibraryRequest{ - Model: request.Model, - Stream: request.Stream, - Query: query, - } -} - -func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { - if len(documents) == 0 { - return "" - } - content := "\n\n参考文档:\n" - for i, document := range documents { - content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) - } - return content -} - -func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { - content := response.Answer + aiProxyDocuments2Markdown(response.Documents) - choice := OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: content, - }, - FinishReason: "stop", - } - fullTextResponse := OpenAITextResponse{ - Id: common.GetUUID(), - Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - } - return &fullTextResponse -} - -func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = aiProxyDocuments2Markdown(documents) - choice.FinishReason = &stopFinishReason - return &ChatCompletionsStreamResponse{ - Id: common.GetUUID(), - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } -} - -func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = response.Content - return &ChatCompletionsStreamResponse{ - Id: common.GetUUID(), - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: response.Model, - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } -} - -func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() - setEventStreamHeaders(c) - var documents []AIProxyLibraryDocument 
- c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var AIProxyLibraryResponse AIProxyLibraryStreamResponse - err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if len(AIProxyLibraryResponse.Documents) != 0 { - documents = AIProxyLibraryResponse.Documents - } - response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - response := documentsAIProxyLibrary(documents) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, &usage -} - -func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var AIProxyLibraryResponse AIProxyLibraryResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if AIProxyLibraryResponse.ErrCode != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: AIProxyLibraryResponse.Message, - Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), - Code: AIProxyLibraryResponse.ErrCode, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} diff --git a/controller/relay-ali.go b/controller/relay-ali.go deleted file mode 100644 index b41ca327..00000000 --- a/controller/relay-ali.go +++ /dev/null @@ -1,329 +0,0 @@ -package controller - -import ( - "bufio" - "encoding/json" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/common" - "strings" -) - -// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r - -type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` -} - -type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` -} - -type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` -} - -type AliChatRequest struct { - Model string `json:"model"` - Input AliInput `json:"input"` - Parameters 
AliParameters `json:"parameters,omitempty"` -} - -type AliEmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -type AliEmbedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type AliEmbeddingResponse struct { - Output struct { - Embeddings []AliEmbedding `json:"embeddings"` - } `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -type AliError struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type AliUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type AliChatResponse struct { - Output AliOutput `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { - messages := make([]AliMessage, 0, len(request.Messages)) - prompt := "" - for i := 0; i < len(request.Messages); i++ { - message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: "Okay", - }) - continue - } else { - if i == len(request.Messages)-1 { - prompt = message.StringContent() - break - } - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: request.Messages[i+1].StringContent(), - }) - i++ - } - } - return &AliChatRequest{ - Model: request.Model, - Input: AliInput{ - Prompt: prompt, - History: messages, - }, - //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's - // TopP: request.TopP, - // TopK: 50, - // //Seed: 0, - // //EnableSearch: false, - //}, - } -} - -func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { - return &AliEmbeddingRequest{ - Model: "text-embedding-v1", - Input: struct { - Texts []string `json:"texts"` - }{ - Texts: request.ParseInput(), - }, - } -} - -func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var aliResponse AliEmbeddingResponse - err := json.NewDecoder(resp.Body).Decode(&aliResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - - if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - - fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} - -func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ - Object: "list", - Data: 
make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), - Model: "text-embedding-v1", - Usage: Usage{TotalTokens: response.Usage.TotalTokens}, - } - - for _, item := range response.Output.Embeddings { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ - Object: `embedding`, - Index: item.TextIndex, - Embedding: item.Embedding, - }) - } - return &openAIEmbeddingResponse -} - -func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: response.Output.Text, - }, - FinishReason: response.Output.FinishReason, - } - fullTextResponse := OpenAITextResponse{ - Id: response.RequestId, - Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - Usage: Usage{ - PromptTokens: response.Usage.InputTokens, - CompletionTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, - }, - } - return &fullTextResponse -} - -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = aliResponse.Output.Text - if aliResponse.Output.FinishReason != "null" { - finishReason := aliResponse.Output.FinishReason - choice.FinishReason = &finishReason - } - response := ChatCompletionsStreamResponse{ - Id: aliResponse.RequestId, - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "ernie-bot", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } - return &response -} - -func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() - setEventStreamHeaders(c) - lastResponseText := "" - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var aliResponse AliChatResponse - err := json.Unmarshal([]byte(data), &aliResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if aliResponse.Usage.OutputTokens != 0 { - usage.PromptTokens = aliResponse.Usage.InputTokens - usage.CompletionTokens = aliResponse.Usage.OutputTokens - usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens - } - response := streamResponseAli2OpenAI(&aliResponse) - response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) - lastResponseText = aliResponse.Output.Text - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - 
} - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, &usage -} - -func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var aliResponse AliChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &aliResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseAli2OpenAI(&aliResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} diff --git a/controller/relay-audio.go b/controller/relay-audio.go deleted file mode 100644 index 89a311a0..00000000 --- a/controller/relay-audio.go +++ /dev/null @@ -1,183 +0,0 @@ -package controller - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/common" - "one-api/model" - "strings" -) - -func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - audioModel := "whisper-1" - - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - tokenName := c.GetString("token_name") - - var ttsRequest TextToSpeechRequest - if relayMode == RelayModeAudioSpeech { - // Read JSON - err := common.UnmarshalBodyReusable(c, &ttsRequest) - // Check if JSON is valid - if err != nil { - return errorWrapper(err, "invalid_json", http.StatusBadRequest) - } - audioModel = ttsRequest.Model - // Check if text is too long 4096 - if len(ttsRequest.Input) > 4096 { - return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) - } - } - - preConsumedTokens := common.PreConsumedQuota - modelRatio := common.GetModelRatio(audioModel) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) - if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - - quota := 0 - // Check if user quota is enough - if relayMode == RelayModeAudioSpeech { - quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) - if quota > userQuota { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - } else { - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, 
preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } - } - } - - // map model name - modelMapping := c.GetString("model_mapping") - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[audioModel] != "" { - audioModel = modelMap[audioModel] - } - } - - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - baseURL = c.GetString("base_url") - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) - } - - requestBody := c.Request.Body - - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - req.Header.Set("api-key", apiKey) - req.ContentLength = c.Request.ContentLength - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := httpClient.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) - } - - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - - if relayMode == RelayModeAudioSpeech { - defer func(ctx context.Context) { - go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) - } else { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", 
http.StatusInternalServerError) - } - defer func(ctx context.Context) { - quota := countTokenText(whisperResponse.Text, audioModel) - quotaDelta := quota - preConsumedQuota - go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - return nil -} diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go deleted file mode 100644 index c75ec09a..00000000 --- a/controller/relay-baidu.go +++ /dev/null @@ -1,359 +0,0 @@ -package controller - -import ( - "bufio" - "encoding/json" - "errors" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/common" - "strings" - "sync" - "time" -) - -// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 - -type BaiduTokenResponse struct { - ExpiresIn int `json:"expires_in"` - AccessToken string `json:"access_token"` -} - -type BaiduMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type BaiduChatRequest struct { - Messages []BaiduMessage `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` -} - -type BaiduError struct { - ErrorCode int `json:"error_code"` - ErrorMsg string `json:"error_msg"` -} - -type BaiduChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduChatStreamResponse struct { - BaiduChatResponse - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` -} - -type BaiduEmbeddingRequest struct { - Input []string `json:"input"` -} - -type BaiduEmbeddingData struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -type BaiduEmbeddingResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Data []BaiduEmbeddingData `json:"data"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduAccessToken struct { - AccessToken string `json:"access_token"` - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - ExpiresAt time.Time `json:"-"` -} - -var baiduTokenStore sync.Map - -func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { - messages := make([]BaiduMessage, 0, len(request.Messages)) - for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, BaiduMessage{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, BaiduMessage{ - Role: "assistant", - Content: "Okay", - }) - } else { - messages = append(messages, BaiduMessage{ - Role: message.Role, - Content: message.StringContent(), - }) - } - } - return &BaiduChatRequest{ - Messages: messages, - Stream: request.Stream, - } -} - -func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { - choice := 
OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: response.Result, - }, - FinishReason: "stop", - } - fullTextResponse := OpenAITextResponse{ - Id: response.Id, - Object: "chat.completion", - Created: response.Created, - Choices: []OpenAITextResponseChoice{choice}, - Usage: response.Usage, - } - return &fullTextResponse -} - -func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = baiduResponse.Result - if baiduResponse.IsEnd { - choice.FinishReason = &stopFinishReason - } - response := ChatCompletionsStreamResponse{ - Id: baiduResponse.Id, - Object: "chat.completion.chunk", - Created: baiduResponse.Created, - Model: "ernie-bot", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } - return &response -} - -func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - return &BaiduEmbeddingRequest{ - Input: request.ParseInput(), - } -} - -func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ - Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), - Model: "baidu-embedding", - Usage: response.Usage, - } - for _, item := range response.Data { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ - Object: item.Object, - Index: item.Index, - Embedding: item.Embedding, - }) - } - return &openAIEmbeddingResponse -} - -func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - data = data[6:] - dataChan <- data - } - stopChan <- true - }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var baiduResponse BaiduChatStreamResponse - err := json.Unmarshal([]byte(data), &baiduResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if baiduResponse.Usage.TotalTokens != 0 { - usage.TotalTokens = baiduResponse.Usage.TotalTokens - usage.PromptTokens = baiduResponse.Usage.PromptTokens - usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens - } - response := streamResponseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, &usage -} - -func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - 
var baiduResponse BaiduChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &baiduResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: baiduResponse.ErrorMsg, - Type: "baidu_error", - Param: "", - Code: baiduResponse.ErrorCode, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} - -func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var baiduResponse BaiduEmbeddingResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &baiduResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: baiduResponse.ErrorMsg, - Type: "baidu_error", - Param: "", - Code: baiduResponse.ErrorCode, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} - -func getBaiduAccessToken(apiKey string) (string, error) { - if val, ok := baiduTokenStore.Load(apiKey); ok { - var accessToken BaiduAccessToken - if accessToken, ok = val.(BaiduAccessToken); ok { - // soon this will expire - if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { - go func() { - _, _ = getBaiduAccessTokenHelper(apiKey) - }() - } - return accessToken.AccessToken, nil - } - } - accessToken, err := getBaiduAccessTokenHelper(apiKey) - if err != nil { - return "", err - } - if accessToken == nil { - return "", errors.New("getBaiduAccessToken return a nil token") - } - return (*accessToken).AccessToken, nil -} - -func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { - parts := strings.Split(apiKey, "|") - if len(parts) != 2 { - return nil, errors.New("invalid baidu apikey") - } - req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", - parts[0], parts[1]), nil) - if err != nil { - return nil, err - } - 
req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - res, err := impatientHTTPClient.Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - - var accessToken BaiduAccessToken - err = json.NewDecoder(res.Body).Decode(&accessToken) - if err != nil { - return nil, err - } - if accessToken.Error != "" { - return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) - } - if accessToken.AccessToken == "" { - return nil, errors.New("getBaiduAccessTokenHelper get empty access token") - } - accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) - baiduTokenStore.Store(apiKey, accessToken) - return &accessToken, nil -} diff --git a/controller/relay-chat.go b/controller/relay-chat.go new file mode 100644 index 00000000..746947e0 --- /dev/null +++ b/controller/relay-chat.go @@ -0,0 +1,127 @@ +package controller + +import ( + "context" + "errors" + "net/http" + "one-api/common" + "one-api/model" + "one-api/providers" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +func relayChatHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode { + + // 获取请求参数 + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + // consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + + // 获取 Provider + chatProvider := GetChatProvider(channelType, c) + if chatProvider == nil { + return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented) + } + + // 获取请求体 + var chatRequest types.ChatCompletionRequest + err := common.UnmarshalBodyReusable(c, &chatRequest) + if err != nil { + return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + // 检查模型映射 + isModelMapped := false + modelMap, err := parseModelMapping(c.GetString("model_mapping")) + if err != nil { + return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap != nil && modelMap[chatRequest.Model] != "" { + chatRequest.Model = modelMap[chatRequest.Model] + isModelMapped = true + } + + // 开始计算Tokens + var promptTokens int + promptTokens = common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) + + // 计算预付费配额 + quotaInfo := &QuotaInfo{ + modelName: chatRequest.Model, + promptTokens: promptTokens, + userId: userId, + channelId: channelId, + tokenId: tokenId, + } + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return quota_err + } + + usage, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&chatRequest, isModelMapped, promptTokens) + + if openAIErrorWithStatusCode != nil { + if quotaInfo.preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) + if err != nil { + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(c.Request.Context()) + } + return openAIErrorWithStatusCode + } + + tokenName := c.GetString("token_name") + defer func(ctx context.Context) { + go func() { + err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) + if err != nil { + common.LogError(ctx, err.Error()) + } + }() + }(c.Request.Context()) + + return nil +} + +func GetChatProvider(channelType int, c *gin.Context) providers.ChatProviderAction { + switch channelType { + case common.ChannelTypeOpenAI: + return 
diff --git a/controller/relay-claude.go b/controller/relay-claude.go
deleted file mode 100644
index 1f4a3e7b..00000000
--- a/controller/relay-claude.go
+++ /dev/null
@@ -1,220 +0,0 @@
-package controller
-
-import (
-	"bufio"
-	"encoding/json"
-	"fmt"
-	"github.com/gin-gonic/gin"
-	"io"
-	"net/http"
-	"one-api/common"
-	"strings"
-)
-
-type ClaudeMetadata struct {
-	UserId string `json:"user_id"`
-}
-
-type ClaudeRequest struct {
-	Model             string   `json:"model"`
-	Prompt            string   `json:"prompt"`
-	MaxTokensToSample int      `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences,omitempty"`
-	Temperature       float64  `json:"temperature,omitempty"`
-	TopP              float64  `json:"top_p,omitempty"`
-	TopK              int      `json:"top_k,omitempty"`
-	//ClaudeMetadata    `json:"metadata,omitempty"`
-	Stream bool `json:"stream,omitempty"`
-}
-
-type ClaudeError struct {
-	Type    string `json:"type"`
-	Message string `json:"message"`
-}
-
-type ClaudeResponse struct {
-	Completion string      `json:"completion"`
-	StopReason string      `json:"stop_reason"`
-	Model      string      `json:"model"`
-	Error      ClaudeError `json:"error"`
-}
-
-func stopReasonClaude2OpenAI(reason string) string {
-	switch reason {
-	case "stop_sequence":
-		return "stop"
-	case "max_tokens":
-		return "length"
-	default:
-		return reason
-	}
-}
-
-func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
-	claudeRequest := ClaudeRequest{
-		Model:             textRequest.Model,
-		Prompt:            "",
-		MaxTokensToSample: textRequest.MaxTokens,
-		StopSequences:     nil,
-		Temperature:       textRequest.Temperature,
-		TopP:              textRequest.TopP,
-		Stream:            textRequest.Stream,
-	}
-	if claudeRequest.MaxTokensToSample == 0 {
-		claudeRequest.MaxTokensToSample = 1000000
-	}
-	prompt := ""
-	for _, message := range textRequest.Messages {
-		if message.Role == "user" {
-			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
-		} else if message.Role == "assistant" {
-			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
-		} else if message.Role == "system" {
-			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
-		}
-	}
-	prompt += "\n\nAssistant:"
-	claudeRequest.Prompt = prompt
-	return &claudeRequest
-}
-
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = claudeResponse.Completion
-	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
-	if finishReason != "null" {
-		choice.FinishReason = &finishReason
-	}
-	var response ChatCompletionsStreamResponse
-	response.Object = "chat.completion.chunk"
-	response.Model = claudeResponse.Model
-	response.Choices =
[]ChatCompletionsStreamResponseChoice{choice} - return &response -} - -func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: strings.TrimPrefix(claudeResponse.Completion, " "), - Name: nil, - }, - FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), - } - fullTextResponse := OpenAITextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), - Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - } - return &fullTextResponse -} - -func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { - responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { - return i + 4, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if !strings.HasPrefix(data, "event: completion") { - continue - } - data = strings.TrimPrefix(data, "event: completion\r\ndata: ") - dataChan <- data - } - stopChan <- true - }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var claudeResponse ClaudeResponse - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - responseText += claudeResponse.Completion - response := streamResponseClaude2OpenAI(&claudeResponse) - response.Id = responseId - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } - return nil, responseText -} - -func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var claudeResponse ClaudeResponse - err = json.Unmarshal(responseBody, &claudeResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if claudeResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: claudeResponse.Error.Message, - Type: claudeResponse.Error.Type, - Param: "", - Code: claudeResponse.Error.Type, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := 
responseClaude2OpenAI(&claudeResponse) - completionTokens := countTokenText(claudeResponse.Completion, model) - usage := Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage -} diff --git a/controller/relay-completion.go b/controller/relay-completion.go new file mode 100644 index 00000000..6087adfa --- /dev/null +++ b/controller/relay-completion.go @@ -0,0 +1,113 @@ +package controller + +import ( + "context" + "errors" + "net/http" + "one-api/common" + "one-api/model" + "one-api/providers" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +func relayCompletionHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode { + + // 获取请求参数 + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + // consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + + // 获取 Provider + completionProvider := GetCompletionProvider(channelType, c) + if completionProvider == nil { + return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented) + } + + // 获取请求体 + var completionRequest types.CompletionRequest + err := common.UnmarshalBodyReusable(c, &completionRequest) + if err != nil { + return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + // 检查模型映射 + isModelMapped := false + modelMap, err := parseModelMapping(c.GetString("model_mapping")) + if err != nil { + return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap != nil && modelMap[completionRequest.Model] != "" { + completionRequest.Model = modelMap[completionRequest.Model] + isModelMapped = true + } + + // 开始计算Tokens + var promptTokens int + promptTokens = common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) + + // 计算预付费配额 + quotaInfo := &QuotaInfo{ + modelName: completionRequest.Model, + promptTokens: promptTokens, + userId: userId, + channelId: channelId, + tokenId: tokenId, + } + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return quota_err + } + + usage, openAIErrorWithStatusCode := completionProvider.CompleteResponse(&completionRequest, isModelMapped, promptTokens) + + if openAIErrorWithStatusCode != nil { + if quotaInfo.preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) + if err != nil { + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(c.Request.Context()) + } + return openAIErrorWithStatusCode + } + + tokenName := c.GetString("token_name") + defer func(ctx context.Context) { + go func() { + err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) + if err != nil { + common.LogError(ctx, err.Error()) + } + }() + }(c.Request.Context()) + + return nil +} + +func GetCompletionProvider(channelType int, c *gin.Context) providers.CompletionProviderAction { + switch channelType { + case common.ChannelTypeOpenAI: + return providers.CreateOpenAIProvider(c, "") + 
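+	// Only the OpenAI and Azure channels implement the text-completion
+	// action here; every other channel type falls through to the
+	// OpenAI-compatible base-URL path below, or yields nil, which the
+	// caller reports as api_not_implemented.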
case common.ChannelTypeAzure: + return providers.CreateAzureProvider(c) + } + + baseURL := common.ChannelBaseURLs[channelType] + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + + if baseURL != "" { + return providers.CreateOpenAIProvider(c, baseURL) + } + + return nil +} diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go new file mode 100644 index 00000000..86189ba7 --- /dev/null +++ b/controller/relay-embeddings.go @@ -0,0 +1,117 @@ +package controller + +import ( + "context" + "errors" + "net/http" + "one-api/common" + "one-api/model" + "one-api/providers" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +func relayEmbeddingsHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode { + + // 获取请求参数 + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + // consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + + // 获取 Provider + embeddingsProvider := GetEmbeddingsProvider(channelType, c) + if embeddingsProvider == nil { + return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented) + } + + // 获取请求体 + var embeddingsRequest types.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingsRequest) + if err != nil { + return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + // 检查模型映射 + isModelMapped := false + modelMap, err := parseModelMapping(c.GetString("model_mapping")) + if err != nil { + return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { + embeddingsRequest.Model = modelMap[embeddingsRequest.Model] + isModelMapped = true + } + + // 开始计算Tokens + var promptTokens int + promptTokens = common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) + + // 计算预付费配额 + quotaInfo := &QuotaInfo{ + modelName: embeddingsRequest.Model, + promptTokens: promptTokens, + userId: userId, + channelId: channelId, + tokenId: tokenId, + } + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return quota_err + } + + usage, openAIErrorWithStatusCode := embeddingsProvider.EmbeddingsResponse(&embeddingsRequest, isModelMapped, promptTokens) + + if openAIErrorWithStatusCode != nil { + if quotaInfo.preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) + if err != nil { + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(c.Request.Context()) + } + return openAIErrorWithStatusCode + } + + tokenName := c.GetString("token_name") + defer func(ctx context.Context) { + go func() { + err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) + if err != nil { + common.LogError(ctx, err.Error()) + } + }() + }(c.Request.Context()) + + return nil +} + +func GetEmbeddingsProvider(channelType int, c *gin.Context) providers.EmbeddingsProviderAction { + switch channelType { + case common.ChannelTypeOpenAI: + return providers.CreateOpenAIProvider(c, "") + case common.ChannelTypeAzure: + return providers.CreateAzureProvider(c) + case common.ChannelTypeAli: + return providers.CreateAliAIProvider(c) + case common.ChannelTypeBaidu: + return providers.CreateBaiduProvider(c) + } + + baseURL := common.ChannelBaseURLs[channelType] + if c.GetString("base_url") != "" { + 
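+		// An explicit base_url on the context overrides the default URL
+		// for this channel type, so self-hosted OpenAI-compatible
+		// endpoints are still routed correctly.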
baseURL = c.GetString("base_url") + } + + if baseURL != "" { + return providers.CreateOpenAIProvider(c, baseURL) + } + + return nil +} diff --git a/controller/relay-image.go b/controller/relay-image.go deleted file mode 100644 index 1d1b71ba..00000000 --- a/controller/relay-image.go +++ /dev/null @@ -1,206 +0,0 @@ -package controller - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "one-api/common" - "one-api/model" - - "github.com/gin-gonic/gin" -) - -func isWithinRange(element string, value int) bool { - if _, ok := common.DalleGenerationImageAmounts[element]; !ok { - return false - } - - min := common.DalleGenerationImageAmounts[element][0] - max := common.DalleGenerationImageAmounts[element][1] - - return value >= min && value <= max -} - -func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - imageModel := "dall-e-2" - imageSize := "1024x1024" - - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") - group := c.GetString("group") - - var imageRequest ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - } - - // Size validation - if imageRequest.Size != "" { - imageSize = imageRequest.Size - } - - // Model validation - if imageRequest.Model != "" { - imageModel = imageRequest.Model - } - - imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] - - // Check if model is supported - if hasValidSize { - if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { - if imageSize == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - } else { - return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - - // Prompt validation - if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - - // Check prompt length - if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { - return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - - // Number of generated images validation - if isWithinRange(imageModel, imageRequest.N) == false { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } - - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(imageRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - - modelRatio := 
common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(userId) - - quota := int(ratio*imageCostRatio*1000) * imageRequest.N - - if consumeQuota && userQuota-quota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := httpClient.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) - } - - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - var textResponse ImageResponse - - defer func(ctx context.Context) { - if consumeQuota { - err := model.PostConsumeTokenQuota(tokenId, quota) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - } - }(c.Request.Context()) - - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - return nil -} diff --git a/controller/relay-openai.go b/controller/relay-openai.go deleted file mode 100644 index dcd20115..00000000 --- a/controller/relay-openai.go +++ /dev/null @@ -1,144 +0,0 @@ -package controller - -import ( - "bufio" - "bytes" - "encoding/json" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/common" - "strings" -) - -func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { - responseText := "" - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && 
len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - if data[:6] != "data: " && data[:6] != "[DONE]" { - continue - } - dataChan <- data - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue // just ignore the error - } - for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } - } - } - } - stopChan <- true - }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } - return nil, responseText -} - -func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { - var textResponse TextResponse - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. 
- for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err := io.Copy(c.Writer, resp.Body) - if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - - if textResponse.Usage.TotalTokens == 0 { - completionTokens := 0 - for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.StringContent(), model) - } - textResponse.Usage = Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - } - return nil, &textResponse.Usage -} diff --git a/controller/relay-palm.go b/controller/relay-palm.go deleted file mode 100644 index 2bd0bcd8..00000000 --- a/controller/relay-palm.go +++ /dev/null @@ -1,205 +0,0 @@ -package controller - -import ( - "encoding/json" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/common" -) - -// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body -// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body - -type PaLMChatMessage struct { - Author string `json:"author"` - Content string `json:"content"` -} - -type PaLMFilter struct { - Reason string `json:"reason"` - Message string `json:"message"` -} - -type PaLMPrompt struct { - Messages []PaLMChatMessage `json:"messages"` -} - -type PaLMChatRequest struct { - Prompt PaLMPrompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` -} - -type PaLMError struct { - Code int `json:"code"` - Message string `json:"message"` - Status string `json:"status"` -} - -type PaLMChatResponse struct { - Candidates []PaLMChatMessage `json:"candidates"` - Messages []Message `json:"messages"` - Filters []PaLMFilter `json:"filters"` - Error PaLMError `json:"error"` -} - -func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { - palmRequest := PaLMChatRequest{ - Prompt: PaLMPrompt{ - Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), - }, - Temperature: textRequest.Temperature, - CandidateCount: textRequest.N, - TopP: textRequest.TopP, - TopK: textRequest.MaxTokens, - } - for _, message := range textRequest.Messages { - palmMessage := PaLMChatMessage{ - Content: message.StringContent(), - } - if message.Role == "user" { - palmMessage.Author = "0" - } else { - palmMessage.Author = "1" - } - palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) - } - return &palmRequest -} - -func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ - Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), - } - for i, candidate := range response.Candidates { - choice := OpenAITextResponseChoice{ - Index: i, - Message: Message{ - Role: "assistant", - Content: candidate.Content, - }, - FinishReason: "stop", - } - fullTextResponse.Choices = append(fullTextResponse.Choices, choice) - } - return &fullTextResponse -} - -func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - if len(palmResponse.Candidates) > 0 { - 
choice.Delta.Content = palmResponse.Candidates[0].Content - } - choice.FinishReason = &stopFinishReason - var response ChatCompletionsStreamResponse - response.Object = "chat.completion.chunk" - response.Model = "palm2" - response.Choices = []ChatCompletionsStreamResponseChoice{choice} - return &response -} - -func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { - responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return - } - err = resp.Body.Close() - if err != nil { - common.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var palmResponse PaLMChatResponse - err = json.Unmarshal(responseBody, &palmResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(palmResponse.Candidates) > 0 { - responseText = palmResponse.Candidates[0].Content - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) - stopChan <- true - }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } - return nil, responseText -} - -func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var palmResponse PaLMChatResponse - err = json.Unmarshal(responseBody, &palmResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: palmResponse.Error.Message, - Type: palmResponse.Error.Status, - Param: "", - Code: palmResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) - usage := Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - 
c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage -} diff --git a/controller/relay-text.go b/controller/relay-text.go deleted file mode 100644 index 018c8d8a..00000000 --- a/controller/relay-text.go +++ /dev/null @@ -1,649 +0,0 @@ -package controller - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "math" - "net/http" - "one-api/common" - "one-api/model" - "strings" - "time" - - "github.com/gin-gonic/gin" -) - -const ( - APITypeOpenAI = iota - APITypeClaude - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent -) - -var httpClient *http.Client -var impatientHTTPClient *http.Client - -func init() { - if common.RelayTimeout == 0 { - httpClient = &http.Client{} - } else { - httpClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, - } - } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } -} - -func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") - group := c.GetString("group") - var textRequest GeneralOpenAIRequest - if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - } - if relayMode == RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayMode == RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } - // request validation - if textRequest.Model == "" { - return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } - switch relayMode { - case RelayModeCompletions: - if textRequest.Prompt == "" { - return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeChatCompletions: - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEmbeddings: - case RelayModeModerations: - if textRequest.Input == "" { - return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEdits: - if textRequest.Instruction == "" { - return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) - } - } - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - textRequest.Model = modelMap[textRequest.Model] - isModelMapped = true - } - } - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeClaude - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case 
common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - requestURL := strings.Split(requestURL, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) - baseURL = c.GetString("base_url") - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) - } - case APITypeClaude: - fullRequestURL = "https://api.anthropic.com/v1/complete" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) - } - case APITypeBaidu: - switch textRequest.Model { - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - var err error - if apiKey, err = getBaiduAccessToken(apiKey); err != nil { - return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) - } - fullRequestURL += "?access_token=" + apiKey - case APITypePaLM: - fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey - case APITypeZhipu: - method := "invoke" - if textRequest.Stream { - method = "sse-invoke" - } - fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) - case APITypeAli: - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - if relayMode == RelayModeEmbeddings { - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" - } - case APITypeTencent: - fullRequestURL = 
"https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" - case APITypeAIProxyLibrary: - fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) - } - var promptTokens int - var completionTokens int - switch relayMode { - case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) - case RelayModeCompletions: - promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModerations: - promptTokens = countTokenInput(textRequest.Input, textRequest.Model) - } - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens - } - modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) - if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) - } - if consumeQuota && preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } - } - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - switch apiType { - case APITypeClaude: - claudeRequest := requestOpenAI2Claude(textRequest) - jsonStr, err := json.Marshal(claudeRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeBaidu: - var jsonData []byte - var err error - switch relayMode { - case RelayModeEmbeddings: - baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) - jsonData, err = json.Marshal(baiduEmbeddingRequest) - default: - baiduRequest := requestOpenAI2Baidu(textRequest) - jsonData, err = json.Marshal(baiduRequest) - } - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonData) - case APITypePaLM: - palmRequest := requestOpenAI2PaLM(textRequest) - jsonStr, err := json.Marshal(palmRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeZhipu: - zhipuRequest := requestOpenAI2Zhipu(textRequest) - jsonStr, err := json.Marshal(zhipuRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAli: - var jsonStr []byte - var err error - 
switch relayMode { - case RelayModeEmbeddings: - aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) - jsonStr, err = json.Marshal(aliEmbeddingRequest) - default: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err = json.Marshal(aliRequest) - } - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeTencent: - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - appId, secretId, secretKey, err := parseTencentConfig(apiKey) - if err != nil { - return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) - } - tencentRequest := requestOpenAI2Tencent(textRequest) - tencentRequest.AppId = appId - tencentRequest.SecretId = secretId - jsonStr, err := json.Marshal(tencentRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - sign := getTencentSign(*tencentRequest, secretKey) - c.Request.Header.Set("Authorization", sign) - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAIProxyLibrary: - aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") - jsonStr, err := json.Marshal(aiProxyLibraryRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } - - var req *http.Request - var resp *http.Response - isStream := textRequest.Stream - - if apiType != APITypeXunfei { // cause xunfei use websocket - req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - req.Header.Set("api-key", apiKey) - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - if channelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - } - case APITypeClaude: - req.Header.Set("x-api-key", apiKey) - anthropicVersion := c.Request.Header.Get("anthropic-version") - if anthropicVersion == "" { - anthropicVersion = "2023-06-01" - } - req.Header.Set("anthropic-version", anthropicVersion) - case APITypeZhipu: - token := getZhipuToken(apiKey) - req.Header.Set("Authorization", token) - case APITypeAli: - req.Header.Set("Authorization", "Bearer "+apiKey) - if textRequest.Stream { - req.Header.Set("X-DashScope-SSE", "enable") - } - case APITypeTencent: - req.Header.Set("Authorization", apiKey) - case APITypePaLM: - // do not set Authorization header - default: - req.Header.Set("Authorization", "Bearer "+apiKey) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if isStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") - } - //req.Header.Set("Connection", c.Request.Header.Get("Connection")) - resp, err = httpClient.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) - } - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", 
http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - - if resp.StatusCode != http.StatusOK { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return relayErrorHandler(resp) - } - } - - var textResponse TextResponse - tokenName := c.GetString("token_name") - - defer func(ctx context.Context) { - // c.Writer.Flush() - go func() { - if consumeQuota { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } - } - }() - }(c.Request.Context()) - switch apiType { - case APITypeOpenAI: - if isStream { - err, responseText := openaiStreamHandler(c, resp, relayMode) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeClaude: - if isStream { - err, responseText := claudeStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeBaidu: - if isStream { - err, usage := baiduStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage - switch relayMode { - case RelayModeEmbeddings: - err, usage = baiduEmbeddingHandler(c, resp) - default: - err, usage = baiduHandler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypePaLM: 
- if textRequest.Stream { // PaLM2 API does not support stream - err, responseText := palmStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeZhipu: - if isStream { - err, usage := zhipuStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } else { - err, usage := zhipuHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } - case APITypeAli: - if isStream { - err, usage := aliStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage - switch relayMode { - case RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - default: - err, usage = aliHandler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeXunfei: - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - var err *OpenAIErrorWithStatusCode - var usage *Usage - if isStream { - err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - } else { - err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - case APITypeAIProxyLibrary: - if isStream { - err, usage := aiProxyLibraryStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - err, usage := aiProxyLibraryHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeTencent: - if isStream { - err, responseText := tencentStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := tencentHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - default: - return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) - } -} diff --git a/controller/relay-utils.go b/controller/relay-utils.go index c7cd4766..cecd9944 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,133 +3,16 @@ package controller import ( "context" "encoding/json" + "errors" "fmt" - "io" + "math" "net/http" "one-api/common" "one-api/model" - "strconv" - "strings" - - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" + "one-api/types" ) 
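+// Token counting (common.CountTokenMessages, common.CountTokenInput) and
+// error wrapping (types.ErrorWrapper) are now consumed from the common and
+// types packages by the new relay helpers; the streaming, URL and token
+// encoder helpers deleted below presumably moved into those packages and
+// the providers package.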
-var stopFinishReason = "stop" - -// tokenEncoderMap won't grow after initialization -var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} -var defaultTokenEncoder *tiktoken.Tiktoken - -func InitTokenEncoders() { - common.SysLog("initializing token encoders") - gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") - if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) - } - defaultTokenEncoder = gpt35TokenEncoder - gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") - if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) - } - for model, _ := range common.ModelRatio { - if strings.HasPrefix(model, "gpt-3.5") { - tokenEncoderMap[model] = gpt35TokenEncoder - } else if strings.HasPrefix(model, "gpt-4") { - tokenEncoderMap[model] = gpt4TokenEncoder - } else { - tokenEncoderMap[model] = nil - } - } - common.SysLog("token encoders initialized") -} - -func getTokenEncoder(model string) *tiktoken.Tiktoken { - tokenEncoder, ok := tokenEncoderMap[model] - if ok && tokenEncoder != nil { - return tokenEncoder - } - if ok { - tokenEncoder, err := tiktoken.EncodingForModel(model) - if err != nil { - common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) - tokenEncoder = defaultTokenEncoder - } - tokenEncoderMap[model] = tokenEncoder - return tokenEncoder - } - return defaultTokenEncoder -} - -func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { - if common.ApproximateTokenEnabled { - return int(float64(len(text)) * 0.38) - } - return len(tokenEncoder.Encode(text, nil, nil)) -} - -func countTokenMessages(messages []Message, model string) int { - tokenEncoder := getTokenEncoder(model) - // Reference: - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - // https://github.com/pkoukk/tiktoken-go/issues/6 - // - // Every message follows <|start|>{role/name}\n{content}<|end|>\n - var tokensPerMessage int - var tokensPerName int - if model == "gpt-3.5-turbo-0301" { - tokensPerMessage = 4 - tokensPerName = -1 // If there's a name, the role is omitted - } else { - tokensPerMessage = 3 - tokensPerName = 1 - } - tokenNum := 0 - for _, message := range messages { - tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.StringContent()) - tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - } - tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum -} - -func countTokenInput(input any, model string) int { - switch input.(type) { - case string: - return countTokenText(input.(string), model) - case []string: - text := "" - for _, s := range input.([]string) { - text += s - } - return countTokenText(text, model) - } - return 0 -} - -func countTokenText(text string, model string) int { - tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text) -} - -func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { - openAIError := OpenAIError{ - Message: err.Error(), - Type: "one_api_error", - Code: code, - } - return &OpenAIErrorWithStatusCode{ - OpenAIError: openAIError, - StatusCode: statusCode, - } -} - -func shouldDisableChannel(err *OpenAIError, statusCode int) bool { +func shouldDisableChannel(err *types.OpenAIError, statusCode int) 
bool { if !common.AutomaticDisableChannelEnabled { return false } @@ -145,56 +28,6 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { return false } -func setEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") -} - -func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { - openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - OpenAIError: OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var textResponse TextResponse - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return - } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error - return -} - -func getFullRequestURL(baseURL string, requestURL string, channelType int) string { - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - switch channelType { - case common.ChannelTypeOpenAI: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) - } - } - - return fullRequestURL -} - func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { @@ -211,3 +44,110 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c model.UpdateChannelUsedQuota(channelId, quota) } } + +func parseModelMapping(modelMapping string) (map[string]string, error) { + if modelMapping == "" || modelMapping == "{}" { + return nil, nil + } + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return nil, err + } + return modelMap, nil +} + +type QuotaInfo struct { + modelName string + promptTokens int + preConsumedTokens int + modelRatio float64 + groupRatio float64 + ratio float64 + preConsumedQuota int + userId int + channelId int + tokenId int +} + +func (q *QuotaInfo) initQuotaInfo(groupName string) { + modelRatio := common.GetModelRatio(q.modelName) + groupRatio := common.GetGroupRatio(groupName) + preConsumedTokens := common.PreConsumedQuota + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio) + + q.preConsumedTokens = preConsumedTokens + q.modelRatio = modelRatio + q.groupRatio = groupRatio + q.ratio = ratio + q.preConsumedQuota = preConsumedQuota + + return +} + +func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { + userQuota, err := model.CacheGetUserQuota(q.userId) + if err != nil { + return types.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + + if userQuota < q.preConsumedQuota { + return types.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", 
http.StatusForbidden) + } + + err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota) + if err != nil { + return types.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + + if userQuota > 100*q.preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + q.preConsumedQuota = 0 + // common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) + } + + if q.preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota) + if err != nil { + return types.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + + return nil +} + +func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { + quota := 0 + completionRatio := common.GetCompletionRatio(q.modelName) + promptTokens := usage.PromptTokens + completionTokens := usage.CompletionTokens + quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio)) + if q.ratio != 0 && quota <= 0 { + quota = 1 + } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - q.preConsumedQuota + err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta) + if err != nil { + return errors.New("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(q.userId) + if err != nil { + return errors.New("error consuming token remain quota: " + err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio) + model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) + model.UpdateChannelUsedQuota(q.channelId, quota) + } + + return nil +} diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go deleted file mode 100644 index 2e345ab5..00000000 --- a/controller/relay-zhipu.go +++ /dev/null @@ -1,301 +0,0 @@ -package controller - -import ( - "bufio" - "encoding/json" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt" - "io" - "net/http" - "one-api/common" - "strings" - "sync" - "time" -) - -// https://open.bigmodel.cn/doc/api#chatglm_std -// chatglm_std, chatglm_lite -// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke -// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke - -type ZhipuMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ZhipuRequest struct { - Prompt []ZhipuMessage `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - RequestId string `json:"request_id,omitempty"` - Incremental bool `json:"incremental,omitempty"` -} - -type ZhipuResponseData struct { - TaskId string `json:"task_id"` - RequestId string `json:"request_id"` - TaskStatus string `json:"task_status"` - Choices []ZhipuMessage `json:"choices"` - Usage `json:"usage"` -} - -type ZhipuResponse struct { - Code int `json:"code"` - Msg string `json:"msg"` - Success bool `json:"success"` - Data ZhipuResponseData `json:"data"` -} - -type ZhipuStreamMetaResponse struct { - RequestId string `json:"request_id"` - TaskId string `json:"task_id"` - TaskStatus string 
`json:"task_status"` - Usage `json:"usage"` -} - -type zhipuTokenData struct { - Token string - ExpiryTime time.Time -} - -var zhipuTokens sync.Map -var expSeconds int64 = 24 * 3600 - -func getZhipuToken(apikey string) string { - data, ok := zhipuTokens.Load(apikey) - if ok { - tokenData := data.(zhipuTokenData) - if time.Now().Before(tokenData.ExpiryTime) { - return tokenData.Token - } - } - - split := strings.Split(apikey, ".") - if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) - return "" - } - - id := split[0] - secret := split[1] - - expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 - expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) - - timestamp := time.Now().UnixNano() / 1e6 - - payload := jwt.MapClaims{ - "api_key": id, - "exp": expMillis, - "timestamp": timestamp, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) - - token.Header["alg"] = "HS256" - token.Header["sign_type"] = "SIGN" - - tokenString, err := token.SignedString([]byte(secret)) - if err != nil { - return "" - } - - zhipuTokens.Store(apikey, zhipuTokenData{ - Token: tokenString, - ExpiryTime: expiryTime, - }) - - return tokenString -} - -func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { - messages := make([]ZhipuMessage, 0, len(request.Messages)) - for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, ZhipuMessage{ - Role: "system", - Content: message.StringContent(), - }) - messages = append(messages, ZhipuMessage{ - Role: "user", - Content: "Okay", - }) - } else { - messages = append(messages, ZhipuMessage{ - Role: message.Role, - Content: message.StringContent(), - }) - } - } - return &ZhipuRequest{ - Prompt: messages, - Temperature: request.Temperature, - TopP: request.TopP, - Incremental: false, - } -} - -func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ - Id: response.Data.TaskId, - Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), - Usage: response.Data.Usage, - } - for i, choice := range response.Data.Choices { - openaiChoice := OpenAITextResponseChoice{ - Index: i, - Message: Message{ - Role: choice.Role, - Content: strings.Trim(choice.Content, "\""), - }, - FinishReason: "", - } - if i == len(response.Data.Choices)-1 { - openaiChoice.FinishReason = "stop" - } - fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) - } - return &fullTextResponse -} - -func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = zhipuResponse - response := ChatCompletionsStreamResponse{ - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } - return &response -} - -func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = "" - choice.FinishReason = &stopFinishReason - response := ChatCompletionsStreamResponse{ - Id: zhipuResponse.RequestId, - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } - return &response, &zhipuResponse.Usage -} - -func zhipuStreamHandler(c 
*gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage *Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { - return i + 2, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - metaChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - lines := strings.Split(data, "\n") - for i, line := range lines { - if len(line) < 5 { - continue - } - if line[:5] == "data:" { - dataChan <- line[5:] - if i != len(lines)-1 { - dataChan <- "\n" - } - } else if line[:5] == "meta:" { - metaChan <- line[5:] - } - } - } - stopChan <- true - }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - response := streamResponseZhipu2OpenAI(data) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case data := <-metaChan: - var zhipuResponse ZhipuStreamMetaResponse - err := json.Unmarshal([]byte(data), &zhipuResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - usage = zhipuUsage - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, usage -} - -func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var zhipuResponse ZhipuResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &zhipuResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if !zhipuResponse.Success { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: zhipuResponse.Msg, - Type: "zhipu_error", - Param: "", - Code: zhipuResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} diff --git a/controller/relay.go b/controller/relay.go index f91ba6da..01519269 100644 --- a/controller/relay.go +++ b/controller/relay.go 
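The relay.go hunk that follows rewires Relay() to per-endpoint helpers (relayChatHelper, relayCompletionHelper, relayEmbeddingsHelper) that drive the new providers package. As a reading aid, here is a minimal sketch of how such a helper is expected to plug a ChatProviderAction together; the function name, the provider choice, and the prompt-token approximation via common.CountTokenText are assumptions for illustration, not code from this patch:

package controller

import (
	"net/http"
	"one-api/common"
	"one-api/providers"
	"one-api/types"

	"github.com/gin-gonic/gin"
)

// Illustrative sketch only: relayChatSketch is hypothetical; the shipped
// helpers live in controller/relay-chat.go and friends.
func relayChatSketch(c *gin.Context) *types.OpenAIErrorWithStatusCode {
	var request types.ChatCompletionRequest
	if err := c.ShouldBindJSON(&request); err != nil {
		return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
	}

	// Assumption: the real helper picks the provider from the channel type set
	// by the distributor middleware; OpenAI is shown as the simplest case.
	var provider providers.ChatProviderAction = providers.CreateOpenAIProvider(c, "")

	// Rough prompt-token count so quota can be pre-consumed before the call.
	promptTokens := 0
	for _, message := range request.Messages {
		promptTokens += common.CountTokenText(message.StringContent(), request.Model)
	}

	usage, openAIErr := provider.ChatCompleteResponse(&request, false, promptTokens)
	if openAIErr != nil {
		return openAIErr
	}
	_ = usage // in the real helper this feeds QuotaInfo.completedQuotaConsumption
	return nil
}

The usage value is the only thing each provider has to report; billing then happens once, in QuotaInfo.completedQuotaConsumption.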
@@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/types" "strconv" "strings" @@ -234,41 +235,46 @@ type CompletionsStreamResponse struct { } func Relay(c *gin.Context) { - relayMode := RelayModeUnknown + var err *types.OpenAIErrorWithStatusCode + + // relayMode := RelayModeUnknown if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions + err = relayChatHelper(c) + // relayMode = RelayModeChatCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - relayMode = RelayModeCompletions + err = relayCompletionHelper(c) + // relayMode = RelayModeCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - var err *OpenAIErrorWithStatusCode - switch relayMode { - case RelayModeImagesGenerations: - err = relayImageHelper(c, relayMode) - case RelayModeAudioSpeech: - fallthrough - case RelayModeAudioTranslation: - fallthrough - case RelayModeAudioTranscription: - err = relayAudioHelper(c, relayMode) - default: - err = relayTextHelper(c, relayMode) + err = relayEmbeddingsHelper(c) } + // relayMode = RelayModeEmbeddings + // } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + // relayMode = RelayModeEmbeddings + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + // relayMode = RelayModeModerations + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + // relayMode = RelayModeImagesGenerations + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { + // relayMode = RelayModeEdits + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + // relayMode = RelayModeAudioSpeech + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + // relayMode = RelayModeAudioTranscription + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + // relayMode = RelayModeAudioTranslation + // } + // switch relayMode { + // case RelayModeImagesGenerations: + // err = relayImageHelper(c, relayMode) + // case RelayModeAudioSpeech: + // fallthrough + // case RelayModeAudioTranslation: + // fallthrough + // case RelayModeAudioTranscription: + // err = relayAudioHelper(c, relayMode) + // default: + // err = relayTextHelper(c, relayMode) + // } if err != nil { requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") diff --git a/main.go b/main.go index 88938516..4c897d51 100644 --- a/main.go +++ b/main.go @@ -3,9 +3,6 @@ package main import ( "embed" "fmt" - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/cookie" - "github.com/gin-gonic/gin" "one-api/common" "one-api/controller" "one-api/middleware" @@ -13,6 +10,10 @@ import ( 
"one-api/router" "os" "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" ) //go:embed web/build @@ -82,7 +83,7 @@ func main() { common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") model.InitBatchUpdater() } - controller.InitTokenEncoders() + common.InitTokenEncoders() // Initialize HTTP server server := gin.New() diff --git a/middleware/distributor.go b/middleware/distributor.go index c4ddc3a0..50ac1c29 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -80,7 +80,8 @@ func Distribute() func(c *gin.Context) { c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("model_mapping", channel.GetModelMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("api_key", channel.Key) + // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) switch channel.Type { case common.ChannelTypeAzure: diff --git a/model/channel.go b/model/channel.go index 7e7b42e6..8bcb8a96 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,9 @@ package model import ( - "gorm.io/gorm" "one-api/common" + + "gorm.io/gorm" ) type Channel struct { diff --git a/providers/ali_base.go b/providers/ali_base.go new file mode 100644 index 00000000..df5e4812 --- /dev/null +++ b/providers/ali_base.go @@ -0,0 +1,50 @@ +package providers + +import ( + "fmt" + + "github.com/gin-gonic/gin" +) + +type AliAIProvider struct { + ProviderConfig +} + +type AliError struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type AliUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// 创建 AliAIProvider +// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation +func CreateAliAIProvider(c *gin.Context) *AliAIProvider { + return &AliAIProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "https://dashscope.aliyuncs.com", + ChatCompletions: "/api/v1/services/aigc/text-generation/generation", + Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding", + Context: c, + }, + } +} + +// 获取请求头 +func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) + + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } + + return headers +} diff --git a/providers/ali_chat.go b/providers/ali_chat.go new file mode 100644 index 00000000..12a66313 --- /dev/null +++ b/providers/ali_chat.go @@ -0,0 +1,256 @@ +package providers + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/types" + "strings" +) + +type AliMessage struct { + User string `json:"user"` + Bot string `json:"bot"` +} + +type AliInput struct { + Prompt string `json:"prompt"` + History []AliMessage `json:"history"` +} + +type AliParameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` +} + +type AliChatRequest struct { + Model string `json:"model"` + Input AliInput `json:"input"` + 
Parameters AliParameters `json:"parameters,omitempty"` +} + +type AliOutput struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +type AliChatResponse struct { + Output AliOutput `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + +func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if aliResponse.Code != "" { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + } + } + + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: aliResponse.Output.Text, + }, + FinishReason: aliResponse.Output.FinishReason, + } + + fullTextResponse := types.ChatCompletionResponse{ + ID: aliResponse.RequestId, + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []types.ChatCompletionChoice{choice}, + Usage: &types.Usage{ + PromptTokens: aliResponse.Usage.InputTokens, + CompletionTokens: aliResponse.Usage.OutputTokens, + TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, + }, + } + + return fullTextResponse, nil +} + +func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { + messages := make([]AliMessage, 0, len(request.Messages)) + prompt := "" + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + if message.Role == "system" { + messages = append(messages, AliMessage{ + User: message.StringContent(), + Bot: "Okay", + }) + continue + } else { + if i == len(request.Messages)-1 { + prompt = message.StringContent() + break + } + messages = append(messages, AliMessage{ + User: message.StringContent(), + Bot: request.Messages[i+1].StringContent(), + }) + i++ + } + } + return &AliChatRequest{ + Model: request.Model, + Input: AliInput{ + Prompt: prompt, + History: messages, + }, + } +} + +func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + requestBody := p.getChatRequestBody(request) + fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + headers["X-DashScope-SSE"] = "enable" + } + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Stream { + openAIErrorWithStatusCode, usage = p.sendStreamRequest(req) + if openAIErrorWithStatusCode != nil { + return + } + + if usage == nil { + usage = &types.Usage{ + PromptTokens: 0, + CompletionTokens: 0, + TotalTokens: 0, + } + } + + } else { + aliResponse := &AliChatResponse{} + openAIErrorWithStatusCode = p.sendRequest(req, aliResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: aliResponse.Usage.InputTokens, + CompletionTokens: aliResponse.Usage.OutputTokens, + TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, + } + } + return +} + +func (p *AliAIProvider) 
streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
+	var choice types.ChatCompletionStreamChoice
+	choice.Delta.Content = aliResponse.Output.Text
+	if aliResponse.Output.FinishReason != "null" {
+		finishReason := aliResponse.Output.FinishReason
+		choice.FinishReason = &finishReason
+	}
+
+	response := types.ChatCompletionStreamResponse{
+		ID:      aliResponse.RequestId,
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "qwen",
+		Choices: []types.ChatCompletionStreamChoice{choice},
+	}
+	return &response
+}
+
+func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
+	usage = &types.Usage{}
+	// send the request
+	resp, err := common.HttpClient.Do(req)
+	if err != nil {
+		return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
+	}
+
+	if common.IsFailureStatusCode(resp) {
+		return p.handleErrorResp(resp), nil
+	}
+
+	defer resp.Body.Close()
+
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 5 { // ignore blank line or wrong format
+				continue
+			}
+			if data[:5] != "data:" {
+				continue
+			}
+			data = data[5:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	setEventStreamHeaders(p.Context)
+	lastResponseText := ""
+	p.Context.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var aliResponse AliChatResponse
+			err := json.Unmarshal([]byte(data), &aliResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			if aliResponse.Usage.OutputTokens != 0 {
+				usage.PromptTokens = aliResponse.Usage.InputTokens
+				usage.CompletionTokens = aliResponse.Usage.OutputTokens
+				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+			}
+			response := p.streamResponseAli2OpenAI(&aliResponse)
+			// Ali streams cumulative text; emit only the delta since the last chunk
+			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
+			lastResponseText = aliResponse.Output.Text
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+
+	return nil, usage
+}
diff --git a/providers/ali_embeddings.go b/providers/ali_embeddings.go
new file mode 100644
index 00000000..913c371a
--- /dev/null
+++ b/providers/ali_embeddings.go
@@ -0,0 +1,94 @@
+package providers
+
+import (
+	"net/http"
+	"one-api/common"
+	"one-api/types"
+)
+
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings 
[]AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + +func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if aliResponse.Code != "" { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + } + } + + openAIEmbeddingResponse := &types.EmbeddingResponse{ + Object: "list", + Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)), + Model: "text-embedding-v1", + Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens}, + } + + for _, item := range aliResponse.Output.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + + return openAIEmbeddingResponse, nil +} + +func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest { + return &AliEmbeddingRequest{ + Model: "text-embedding-v1", + Input: struct { + Texts []string `json:"texts"` + }{ + Texts: request.ParseInput(), + }, + } +} + +func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + requestBody := p.getEmbeddingsRequestBody(request) + fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + aliEmbeddingResponse := &AliEmbeddingResponse{} + openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse) + if openAIErrorWithStatusCode != nil { + return + } + usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens} + + return usage, nil +} diff --git a/providers/api2d_base.go b/providers/api2d_base.go new file mode 100644 index 00000000..e4b4cc28 --- /dev/null +++ b/providers/api2d_base.go @@ -0,0 +1,14 @@ +package providers + +import "github.com/gin-gonic/gin" + +type Api2dProvider struct { + *OpenAIProvider +} + +// 创建 OpenAIProvider +func CreateApi2dProvider(c *gin.Context) *Api2dProvider { + return &Api2dProvider{ + OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"), + } +} diff --git a/providers/azure_base.go b/providers/azure_base.go new file mode 100644 index 00000000..0f1aa017 --- /dev/null +++ b/providers/azure_base.go @@ -0,0 +1,41 @@ +package providers + +import ( + "github.com/gin-gonic/gin" +) + +type AzureProvider struct { + OpenAIProvider +} + +// 创建 OpenAIProvider +func CreateAzureProvider(c *gin.Context) *AzureProvider { + return &AzureProvider{ + OpenAIProvider: OpenAIProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "", + Completions: "/completions", + ChatCompletions: "/chat/completions", + Embeddings: "/embeddings", + AudioSpeech: "/audio/speech", + AudioTranscriptions: "/audio/transcriptions", + AudioTranslations: "/audio/translations", + Context: c, + }, + isAzure: true, + }, + } +} + +// // 获取完整请求 URL +// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) 
string { +// apiVersion := p.Context.GetString("api_version") +// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion) +// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + +// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { +// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments") +// } + +// return fmt.Sprintf("%s%s", baseURL, requestURL) +// } diff --git a/providers/baidu_base.go b/providers/baidu_base.go new file mode 100644 index 00000000..85cd5ec4 --- /dev/null +++ b/providers/baidu_base.go @@ -0,0 +1,136 @@ +package providers + +import ( + "encoding/json" + "errors" + "fmt" + "one-api/common" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +var baiduTokenStore sync.Map + +type BaiduProvider struct { + ProviderConfig +} + +type BaiduAccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} + +func CreateBaiduProvider(c *gin.Context) *BaiduProvider { + return &BaiduProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "https://aip.baidubce.com", + ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat", + Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings", + Context: c, + }, + } +} + +// 获取完整请求 URL +func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string { + var modelNameMap = map[string]string{ + "ERNIE-Bot": "completions", + "ERNIE-Bot-turbo": "eb-instant", + "ERNIE-Bot-4": "completions_pro", + "BLOOMZ-7B": "bloomz_7b1", + "Embedding-V1": "embedding-v1", + } + + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + apiKey, err := p.getBaiduAccessToken() + if err != nil { + return "" + } + + return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey) +} + +// 获取请求头 +func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } + + return headers +} + +func (p *BaiduProvider) getBaiduAccessToken() (string, error) { + apiKey := p.Context.GetString("api_key") + if val, ok := baiduTokenStore.Load(apiKey); ok { + var accessToken BaiduAccessToken + if accessToken, ok = val.(BaiduAccessToken); ok { + // soon this will expire + if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { + go func() { + _, _ = p.getBaiduAccessTokenHelper(apiKey) + }() + } + return accessToken.AccessToken, nil + } + } + accessToken, err := p.getBaiduAccessTokenHelper(apiKey) + if err != nil { + return "", err + } + if accessToken == nil { + return "", errors.New("getBaiduAccessToken return a nil token") + } + return (*accessToken).AccessToken, nil +} + +func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { + parts := strings.Split(apiKey, "|") + if len(parts) != 2 { + return nil, errors.New("invalid baidu apikey") + } + + client := common.NewClient() + url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1]) + + var headers = map[string]string{ + "Content-Type": "application/json", + "Accept": "application/json", + } + + req, err := 
client.NewRequest("POST", url, common.WithHeader(headers)) + if err != nil { + return nil, err + } + + resp, err := common.HttpClient.Do(req) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + var accessToken BaiduAccessToken + err = json.NewDecoder(resp.Body).Decode(&accessToken) + if err != nil { + return nil, err + } + if accessToken.Error != "" { + return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) + } + if accessToken.AccessToken == "" { + return nil, errors.New("getBaiduAccessTokenHelper get empty access token") + } + accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) + baiduTokenStore.Store(apiKey, accessToken) + return &accessToken, nil +} diff --git a/providers/baidu_chat.go b/providers/baidu_chat.go new file mode 100644 index 00000000..26fafa96 --- /dev/null +++ b/providers/baidu_chat.go @@ -0,0 +1,228 @@ +package providers + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/types" + "strings" +) + +type BaiduMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type BaiduChatRequest struct { + Messages []BaiduMessage `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` +} + +type BaiduChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage *types.Usage `json:"usage"` + BaiduError +} + +func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if baiduResponse.ErrorMsg != "" { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: baiduResponse.ErrorMsg, + Type: "baidu_error", + Param: "", + Code: baiduResponse.ErrorCode, + }, + StatusCode: resp.StatusCode, + } + } + + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: baiduResponse.Result, + }, + FinishReason: "stop", + } + + fullTextResponse := types.ChatCompletionResponse{ + ID: baiduResponse.Id, + Object: "chat.completion", + Created: baiduResponse.Created, + Choices: []types.ChatCompletionChoice{choice}, + Usage: baiduResponse.Usage, + } + + return fullTextResponse, nil +} + +type BaiduChatStreamResponse struct { + BaiduChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type BaiduError struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} + +func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest { + messages := make([]BaiduMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, BaiduMessage{ + Role: "user", + Content: message.StringContent(), + }) + messages = append(messages, BaiduMessage{ + Role: "assistant", + Content: "Okay", + }) + } else { + messages = append(messages, BaiduMessage{ + Role: message.Role, + Content: message.StringContent(), + }) + } + } + return &BaiduChatRequest{ + Messages: messages, + Stream: request.Stream, + } +} + +func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + 
requestBody := p.getChatRequestBody(request)
+	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
+	if fullRequestURL == "" {
+		return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
+	}
+
+	headers := p.GetRequestHeaders()
+	if request.Stream {
+		headers["Accept"] = "text/event-stream"
+	}
+
+	client := common.NewClient()
+	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
+	if err != nil {
+		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+
+	if request.Stream {
+		openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
+		if openAIErrorWithStatusCode != nil {
+			return
+		}
+
+	} else {
+		baiduChatResponse := &BaiduChatResponse{}
+		openAIErrorWithStatusCode = p.sendRequest(req, baiduChatResponse)
+		if openAIErrorWithStatusCode != nil {
+			return
+		}
+
+		usage = baiduChatResponse.Usage
+	}
+	return
+
+}
+
+func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
+	var choice types.ChatCompletionStreamChoice
+	choice.Delta.Content = baiduResponse.Result
+	if baiduResponse.IsEnd {
+		choice.FinishReason = &stopFinishReason
+	}
+
+	response := types.ChatCompletionStreamResponse{
+		ID:      baiduResponse.Id,
+		Object:  "chat.completion.chunk",
+		Created: baiduResponse.Created,
+		Model:   "ernie-bot",
+		Choices: []types.ChatCompletionStreamChoice{choice},
+	}
+	return &response
+}
+
+func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
+	usage = &types.Usage{}
+	// send the request
+	resp, err := common.HttpClient.Do(req)
+	if err != nil {
+		return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
+	}
+
+	if common.IsFailureStatusCode(resp) {
+		return p.handleErrorResp(resp), nil
+	}
+
+	defer resp.Body.Close()
+
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 6 { // ignore blank line or wrong format
+				continue
+			}
+			data = data[6:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	setEventStreamHeaders(p.Context)
+	p.Context.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var baiduResponse BaiduChatStreamResponse
+			err := json.Unmarshal([]byte(data), &baiduResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			// Usage is a pointer on BaiduChatResponse; guard against chunks that omit it
+			if baiduResponse.Usage != nil && baiduResponse.Usage.TotalTokens != 0 {
+				usage.TotalTokens = baiduResponse.Usage.TotalTokens
+				usage.PromptTokens = baiduResponse.Usage.PromptTokens
+				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
+			}
+			response := p.streamResponseBaidu2OpenAI(&baiduResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			p.Context.Render(-1, common.CustomEvent{Data: 
"data: [DONE]"}) + return false + } + }) + + return nil, usage +} diff --git a/providers/baidu_embeddings.go b/providers/baidu_embeddings.go new file mode 100644 index 00000000..9318d3f3 --- /dev/null +++ b/providers/baidu_embeddings.go @@ -0,0 +1,88 @@ +package providers + +import ( + "net/http" + "one-api/common" + "one-api/types" +) + +type BaiduEmbeddingRequest struct { + Input []string `json:"input"` +} + +type BaiduEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type BaiduEmbeddingResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []BaiduEmbeddingData `json:"data"` + Usage types.Usage `json:"usage"` + BaiduError +} + +func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest { + return &BaiduEmbeddingRequest{ + Input: request.ParseInput(), + } +} + +func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if baiduResponse.ErrorMsg != "" { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: baiduResponse.ErrorMsg, + Type: "baidu_error", + Param: "", + Code: baiduResponse.ErrorCode, + }, + StatusCode: resp.StatusCode, + } + } + + openAIEmbeddingResponse := &types.EmbeddingResponse{ + Object: "list", + Data: make([]types.Embedding, 0, len(baiduResponse.Data)), + Model: "text-embedding-v1", + Usage: &baiduResponse.Usage, + } + + for _, item := range baiduResponse.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + + return openAIEmbeddingResponse, nil +} + +func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + requestBody := p.getEmbeddingsRequestBody(request) + fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) + if fullRequestURL == "" { + return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) + } + + headers := p.GetRequestHeaders() + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + baiduEmbeddingResponse := &BaiduEmbeddingResponse{} + openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse) + if openAIErrorWithStatusCode != nil { + return + } + usage = &baiduEmbeddingResponse.Usage + + return usage, nil +} diff --git a/providers/base.go b/providers/base.go new file mode 100644 index 00000000..895b157a --- /dev/null +++ b/providers/base.go @@ -0,0 +1,150 @@ +package providers + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + "one-api/types" + "strconv" + "strings" + + "github.com/gin-gonic/gin" +) + +var stopFinishReason = "stop" + +type ProviderConfig struct { + BaseURL string + Completions string + ChatCompletions string + Embeddings string + AudioSpeech string + AudioTranscriptions string + AudioTranslations string + Proxy string + Context *gin.Context +} + +type BaseProviderAction interface { + GetBaseURL() string + 
GetFullRequestURL(requestURL string, modelName string) string + GetRequestHeaders() (headers map[string]string) +} + +type CompletionProviderAction interface { + BaseProviderAction + CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) +} + +type ChatProviderAction interface { + BaseProviderAction + ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) +} + +type EmbeddingsProviderAction interface { + BaseProviderAction + EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) +} + +type BalanceProviderAction interface { + Balance(channel *model.Channel) (float64, error) +} + +func (p *ProviderConfig) GetBaseURL() string { + if p.Context.GetString("base_url") != "" { + return p.Context.GetString("base_url") + } + + return p.BaseURL +} + +func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf("%s%s", baseURL, requestURL) +} + +func setEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} + +func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: types.OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var errorResponse types.OpenAIErrorResponse + err = json.Unmarshal(responseBody, &errorResponse) + if err != nil { + return + } + if errorResponse.Error.Type != "" { + openAIErrorWithStatusCode.OpenAIError = errorResponse.Error + } else { + openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody) + } + return +} + +// 供应商响应处理函数 +type ProviderResponseHandler interface { + // 请求处理函数 + requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) +} + +// 发送请求 +func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + // 发送请求 + resp, err := common.HttpClient.Do(req) + if err != nil { + return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) + } + + defer resp.Body.Close() + + // 处理响应 + if common.IsFailureStatusCode(resp) { + return p.handleErrorResp(resp) + } + + // 解析响应 + err = common.DecodeResponse(resp.Body, response) + if err != nil { + return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + } + + openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp) + if openAIErrorWithStatusCode != nil { + return + } + + jsonResponse, err := json.Marshal(openAIResponse) + if err != nil { + return 
types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + p.Context.Writer.Header().Set("Content-Type", "application/json") + p.Context.Writer.WriteHeader(resp.StatusCode) + _, err = p.Context.Writer.Write(jsonResponse) + return nil +} diff --git a/providers/claude_base.go b/providers/claude_base.go new file mode 100644 index 00000000..6a94e3f2 --- /dev/null +++ b/providers/claude_base.go @@ -0,0 +1,55 @@ +package providers + +import ( + "github.com/gin-gonic/gin" +) + +type ClaudeProvider struct { + ProviderConfig +} + +type ClaudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +func CreateClaudeProvider(c *gin.Context) *ClaudeProvider { + return &ClaudeProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "https://api.anthropic.com", + ChatCompletions: "/v1/complete", + Context: c, + }, + } +} + +// 获取请求头 +func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + + headers["x-api-key"] = p.Context.GetString("api_key") + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } + + anthropicVersion := p.Context.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + headers["anthropic-version"] = anthropicVersion + + return headers +} + +func stopReasonClaude2OpenAI(reason string) string { + switch reason { + case "stop_sequence": + return "stop" + case "max_tokens": + return "length" + default: + return reason + } +} diff --git a/providers/claude_chat.go b/providers/claude_chat.go new file mode 100644 index 00000000..e6c1d11a --- /dev/null +++ b/providers/claude_chat.go @@ -0,0 +1,232 @@ +package providers + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/types" + "strings" +) + +type ClaudeMetadata struct { + UserId string `json:"user_id"` +} + +type ClaudeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokensToSample int `json:"max_tokens_to_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //ClaudeMetadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type ClaudeResponse struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` + Usage *types.Usage `json:"usage,omitempty"` +} + +func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if claudeResponse.Error.Type != "" { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: claudeResponse.Error.Message, + Type: claudeResponse.Error.Type, + Param: "", + Code: claudeResponse.Error.Type, + }, + StatusCode: resp.StatusCode, + } + } + + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: strings.TrimPrefix(claudeResponse.Completion, " "), + Name: nil, + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + fullTextResponse := types.ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: 
"chat.completion", + Created: common.GetTimestamp(), + Choices: []types.ChatCompletionChoice{choice}, + } + + completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model) + claudeResponse.Usage.CompletionTokens = completionTokens + claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens + + fullTextResponse.Usage = claudeResponse.Usage + + return fullTextResponse, nil +} + +func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) { + claudeRequest := ClaudeRequest{ + Model: request.Model, + Prompt: "", + MaxTokensToSample: request.MaxTokens, + StopSequences: nil, + Temperature: request.Temperature, + TopP: request.TopP, + Stream: request.Stream, + } + if claudeRequest.MaxTokensToSample == 0 { + claudeRequest.MaxTokensToSample = 1000000 + } + prompt := "" + for _, message := range request.Messages { + if message.Role == "user" { + prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) + } else if message.Role == "assistant" { + prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) + } else if message.Role == "system" { + prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + } + } + prompt += "\n\nAssistant:" + claudeRequest.Prompt = prompt + return &claudeRequest +} + +func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + requestBody := p.getChatRequestBody(request) + fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Stream { + var responseText string + openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req) + if openAIErrorWithStatusCode != nil { + return + } + + usage.PromptTokens = promptTokens + usage.CompletionTokens = common.CountTokenText(responseText, request.Model) + usage.TotalTokens = promptTokens + usage.CompletionTokens + + } else { + var claudeResponse = &ClaudeResponse{ + Usage: &types.Usage{ + PromptTokens: promptTokens, + }, + } + openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = claudeResponse.Usage + } + return + +} + +func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse { + var choice types.ChatCompletionStreamChoice + choice.Delta.Content = claudeResponse.Completion + finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var response types.ChatCompletionStreamResponse + response.Object = "chat.completion.chunk" + response.Model = claudeResponse.Model + response.Choices = []types.ChatCompletionStreamChoice{choice} + return &response +} + +func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { + // 发送请求 + resp, err := common.HttpClient.Do(req) + if err != nil { + return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" + } + + if common.IsFailureStatusCode(resp) 
{ + return p.handleErrorResp(resp), "" + } + + defer resp.Body.Close() + + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { + return i + 4, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if !strings.HasPrefix(data, "event: completion") { + continue + } + data = strings.TrimPrefix(data, "event: completion\r\ndata: ") + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(p.Context) + p.Context.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var claudeResponse ClaudeResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + responseText += claudeResponse.Completion + response := p.streamResponseClaude2OpenAI(&claudeResponse) + response.ID = responseId + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + + return nil, responseText +} diff --git a/providers/closeai_proxy_base.go b/providers/closeai_proxy_base.go new file mode 100644 index 00000000..9879bf38 --- /dev/null +++ b/providers/closeai_proxy_base.go @@ -0,0 +1,50 @@ +package providers + +import ( + "fmt" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +type CloseaiProxyProvider struct { + *OpenAIProvider +} + +type OpenAICreditGrants struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalAvailable float64 `json:"total_available"` +} + +// 创建 CloseaiProxyProvider +func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider { + return &CloseaiProxyProvider{ + OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"), + } +} + +func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "") + fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key) + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 发送请求 + var response OpenAICreditGrants + err = client.SendRequest(req, &response) + if err != nil { + return 0, err + } + + channel.UpdateBalance(response.TotalAvailable) + + return response.TotalAvailable, nil +} diff --git a/providers/openai_base.go b/providers/openai_base.go new file mode 100644 index 00000000..1f2cd096 --- /dev/null +++ b/providers/openai_base.go @@ -0,0 +1,215 @@ +package providers + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + 
"one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +type OpenAIProvider struct { + ProviderConfig + isAzure bool +} + +type OpenAIProviderResponseHandler interface { + // 请求处理函数 + requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) +} + +type OpenAIProviderStreamResponseHandler interface { + // 请求流处理函数 + requestStreamHandler() (responseText string) +} + +// 创建 OpenAIProvider +func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { + if baseURL == "" { + baseURL = "https://api.openai.com" + } + + return &OpenAIProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: baseURL, + Completions: "/v1/completions", + ChatCompletions: "/v1/chat/completions", + Embeddings: "/v1/embeddings", + AudioSpeech: "/v1/audio/speech", + AudioTranscriptions: "/v1/audio/transcriptions", + AudioTranslations: "/v1/audio/translations", + Context: c, + }, + isAzure: false, + } +} + +// 获取完整请求 URL +func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + if p.isAzure { + apiVersion := p.Context.GetString("api_version") + requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) + } + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + if p.isAzure { + requestURL = strings.TrimPrefix(requestURL, "/openai/deployments") + } else { + requestURL = strings.TrimPrefix(requestURL, "/v1") + } + } + + return fmt.Sprintf("%s%s", baseURL, requestURL) +} + +// 获取请求头 +func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + if p.isAzure { + headers["api-key"] = p.Context.GetString("api_key") + } else { + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) + } + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json; charset=utf-8" + } + + return headers +} + +// 获取请求体 +func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) { + if isModelMapped { + jsonStr, err := json.Marshal(request) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = p.Context.Request.Body + } + return +} + +// 发送请求 +func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + // 发送请求 + resp, err := common.HttpClient.Do(req) + if err != nil { + return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) + } + + defer resp.Body.Close() + + // 处理响应 + if common.IsFailureStatusCode(resp) { + return p.handleErrorResp(resp) + } + + // 创建一个 bytes.Buffer 来存储响应体 + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + + // 解析响应 + err = common.DecodeResponse(tee, response) + if err != nil { + return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + } + + openAIErrorWithStatusCode = response.requestHandler(resp) + if openAIErrorWithStatusCode != nil { + return + } + + for k, v := range resp.Header { + p.Context.Writer.Header().Set(k, v[0]) + } + + p.Context.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(p.Context.Writer, &buf) + if err != nil { + return types.ErrorWrapper(err, "copy_response_body_failed", 
http.StatusInternalServerError) + } + + return nil +} + +func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { + + resp, err := common.HttpClient.Do(req) + if err != nil { + return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" + } + + if common.IsFailureStatusCode(resp) { + return p.handleErrorResp(resp), "" + } + + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + if data[:6] != "data: " && data[:6] != "[DONE]" { + continue + } + dataChan <- data + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + err := json.Unmarshal([]byte(data), response) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue // just ignore the error + } + responseText += response.requestStreamHandler() + } + } + stopChan <- true + }() + setEventStreamHeaders(p.Context) + p.Context.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + if strings.HasPrefix(data, "data: [DONE]") { + data = data[:12] + } + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + p.Context.Render(-1, common.CustomEvent{Data: data}) + return true + case <-stopChan: + return false + } + }) + + return nil, responseText +} diff --git a/providers/openai_chat.go b/providers/openai_chat.go new file mode 100644 index 00000000..6a7247b3 --- /dev/null +++ b/providers/openai_chat.go @@ -0,0 +1,92 @@ +package providers + +import ( + "net/http" + "one-api/common" + "one-api/types" +) + +type OpenAIProviderChatResponse struct { + types.ChatCompletionResponse + types.OpenAIErrorResponse +} + +type OpenAIProviderChatStreamResponse struct { + types.ChatCompletionStreamResponse + types.OpenAIErrorResponse +} + +func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if c.Error.Type != "" { + openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: c.Error, + StatusCode: resp.StatusCode, + } + return + } + return nil +} + +func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) { + for _, choice := range c.Choices { + responseText += choice.Delta.Content + } + + return +} + +func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + requestBody, err := p.getRequestBody(&request, isModelMapped) + if err != nil { + return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + } + + fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) + headers := p.GetRequestHeaders() + if request.Stream && headers["Accept"] == "" { + headers["Accept"] = "text/event-stream" + } + + client := common.NewClient() + req, err := 
client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Stream { + openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{} + var textResponse string + openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: common.CountTokenText(textResponse, request.Model), + TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model), + } + + } else { + openAIProviderChatResponse := &OpenAIProviderChatResponse{} + openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = openAIProviderChatResponse.Usage + + if usage.TotalTokens == 0 { + completionTokens := 0 + for _, choice := range openAIProviderChatResponse.Choices { + completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model) + } + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + } + } + return +} diff --git a/providers/openai_completion.go b/providers/openai_completion.go new file mode 100644 index 00000000..df99903e --- /dev/null +++ b/providers/openai_completion.go @@ -0,0 +1,87 @@ +package providers + +import ( + "net/http" + "one-api/common" + "one-api/types" +) + +type OpenAIProviderCompletionResponse struct { + types.CompletionResponse + types.OpenAIErrorResponse +} + +func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if c.Error.Type != "" { + openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: c.Error, + StatusCode: resp.StatusCode, + } + return + } + return nil +} + +func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) { + for _, choice := range c.Choices { + responseText += choice.Text + } + + return +} + +func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + requestBody, err := p.getRequestBody(&request, isModelMapped) + if err != nil { + return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + } + + fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model) + headers := p.GetRequestHeaders() + if request.Stream && headers["Accept"] == "" { + headers["Accept"] = "text/event-stream" + } + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{} + if request.Stream { + // TODO + var textResponse string + openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: common.CountTokenText(textResponse, request.Model), + TotalTokens: 
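
When the upstream reply carries no usage block (TotalTokens == 0), the non-stream branch above rebuilds it by re-tokenizing each returned choice. A sketch of that fallback; countTokenText here is a whitespace-splitting stand-in for common.CountTokenText, which uses a real tokenizer:

package main

import (
	"fmt"
	"strings"
)

// Stand-in for common.CountTokenText; illustration only.
func countTokenText(text, model string) int {
	return len(strings.Fields(text))
}

type usage struct{ Prompt, Completion, Total int }

func main() {
	promptTokens := 12
	choices := []string{"Hello there!", "General Kenobi."}

	completion := 0
	for _, text := range choices {
		completion += countTokenText(text, "gpt-3.5-turbo")
	}
	u := usage{Prompt: promptTokens, Completion: completion, Total: promptTokens + completion}
	fmt.Printf("%+v\n", u)
}
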
promptTokens + common.CountTokenText(textResponse, request.Model), + } + + } else { + openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = openAIProviderCompletionResponse.Usage + + if usage.TotalTokens == 0 { + completionTokens := 0 + for _, choice := range openAIProviderCompletionResponse.Choices { + completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model) + } + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + } + } + return +} diff --git a/providers/openai_embeddings.go b/providers/openai_embeddings.go new file mode 100644 index 00000000..00c3cc80 --- /dev/null +++ b/providers/openai_embeddings.go @@ -0,0 +1,50 @@ +package providers + +import ( + "net/http" + "one-api/common" + "one-api/types" +) + +type OpenAIProviderEmbeddingsResponse struct { + types.EmbeddingResponse + types.OpenAIErrorResponse +} + +func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if c.Error.Type != "" { + openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: c.Error, + StatusCode: resp.StatusCode, + } + return + } + return nil +} + +func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + requestBody, err := p.getRequestBody(&request, isModelMapped) + if err != nil { + return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + } + + fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{} + openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = openAIProviderEmbeddingsResponse.Usage + + return +} diff --git a/providers/openaisb_base.go b/providers/openaisb_base.go new file mode 100644 index 00000000..c7f11fb9 --- /dev/null +++ b/providers/openaisb_base.go @@ -0,0 +1,58 @@ +package providers + +import ( + "errors" + "fmt" + "one-api/common" + "one-api/model" + "strconv" + + "github.com/gin-gonic/gin" +) + +type OpenaiSBProvider struct { + *OpenAIProvider +} + +type OpenAISBUsageResponse struct { + Msg string `json:"msg"` + Data *struct { + Credit string `json:"credit"` + } `json:"data"` +} + +// 创建 OpenaiSBProvider +func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider { + return &OpenaiSBProvider{ + OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"), + } +} + +func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "") + fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key) + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 
发送请求 + var response OpenAISBUsageResponse + err = client.SendRequest(req, &response) + if err != nil { + return 0, err + } + + if response.Data == nil { + return 0, errors.New(response.Msg) + } + balance, err := strconv.ParseFloat(response.Data.Credit, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} diff --git a/providers/palm_base.go b/providers/palm_base.go new file mode 100644 index 00000000..40d33ac0 --- /dev/null +++ b/providers/palm_base.go @@ -0,0 +1,43 @@ +package providers + +import ( + "fmt" + "strings" + + "github.com/gin-gonic/gin" +) + +type PalmProvider struct { + ProviderConfig +} + +// 创建 PalmProvider +func CreatePalmProvider(c *gin.Context) *PalmProvider { + return &PalmProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "https://generativelanguage.googleapis.com", + ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage", + Context: c, + }, + } +} + +// 获取请求头 +func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } + + return headers +} + +// 获取完整请求 URL +func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key")) +} diff --git a/providers/palm_chat.go b/providers/palm_chat.go new file mode 100644 index 00000000..37c1fde1 --- /dev/null +++ b/providers/palm_chat.go @@ -0,0 +1,232 @@ +package providers + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/types" +) + +type PaLMChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type PaLMFilter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type PaLMPrompt struct { + Messages []PaLMChatMessage `json:"messages"` +} + +type PaLMChatRequest struct { + Prompt PaLMPrompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type PaLMError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type PaLMChatResponse struct { + Candidates []PaLMChatMessage `json:"candidates"` + Messages []types.ChatCompletionMessage `json:"messages"` + Filters []PaLMFilter `json:"filters"` + Error PaLMError `json:"error"` + Usage *types.Usage `json:"usage,omitempty"` + Model string `json:"model,omitempty"` +} + +func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: palmResponse.Error.Message, + Type: palmResponse.Error.Status, + Param: "", + Code: palmResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + } + } + + fullTextResponse := types.ChatCompletionResponse{ + Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)), + } + for i, candidate := range palmResponse.Candidates { + choice := types.ChatCompletionChoice{ + Index: i, + Message: 
types.ChatCompletionMessage{
+				Role:    "assistant",
+				Content: candidate.Content,
+			},
+			FinishReason: "stop",
+		}
+		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	}
+
+	completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
+	palmResponse.Usage.CompletionTokens = completionTokens
+	palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
+
+	fullTextResponse.Usage = palmResponse.Usage
+
+	return fullTextResponse, nil
+}
+
+func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
+	palmRequest := PaLMChatRequest{
+		Prompt: PaLMPrompt{
+			Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
+		},
+		Temperature:    request.Temperature,
+		CandidateCount: request.N,
+		TopP:           request.TopP,
+		TopK:           request.MaxTokens,
+	}
+	for _, message := range request.Messages {
+		palmMessage := PaLMChatMessage{
+			Content: message.StringContent(),
+		}
+		if message.Role == "user" {
+			palmMessage.Author = "0"
+		} else {
+			palmMessage.Author = "1"
+		}
+		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
+	}
+	return &palmRequest
+}
+
+func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+	requestBody := p.getChatRequestBody(request)
+	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
+	headers := p.GetRequestHeaders()
+	if request.Stream {
+		headers["Accept"] = "text/event-stream"
+	}
+
+	client := common.NewClient()
+	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
+	if err != nil {
+		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+
+	if request.Stream {
+		var responseText string
+		openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
+		if openAIErrorWithStatusCode != nil {
+			return
+		}
+
+		// usage is a nil named return here, so allocate it rather than assigning fields on it
+		completionTokens := common.CountTokenText(responseText, request.Model)
+		usage = &types.Usage{PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens}
+
+	} else {
+		var palmChatResponse = &PaLMChatResponse{
+			Model: request.Model,
+			Usage: &types.Usage{
+				PromptTokens: promptTokens,
+			},
+		}
+		openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
+		if openAIErrorWithStatusCode != nil {
+			return
+		}
+
+		usage = palmChatResponse.Usage
+	}
+	return
+
+}
+
+func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
+	var choice types.ChatCompletionStreamChoice
+	if len(palmResponse.Candidates) > 0 {
+		choice.Delta.Content = palmResponse.Candidates[0].Content
+	}
+	choice.FinishReason = &stopFinishReason
+	var response types.ChatCompletionStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = "palm2"
+	response.Choices = []types.ChatCompletionStreamChoice{choice}
+	return &response
+}
+
+func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
+	// Send the request
+	resp, err := common.HttpClient.Do(req)
+	if err != nil {
+		return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
+	}
+
+	if common.IsFailureStatusCode(resp) {
+		return p.handleErrorResp(resp), ""
+	}
+
+	defer resp.Body.Close()
+
+	responseText := ""
+	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+
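
PaLM's generateMessage API identifies speakers by author id rather than by role name, which is why getChatRequestBody above maps "user" to "0" and everything else to "1" (note it also reuses MaxTokens as TopK). A standalone sketch of that mapping over made-up messages:

package main

import "fmt"

type palmMessage struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}

func main() {
	openAIMessages := []struct{ Role, Content string }{
		{"user", "Hi"},
		{"assistant", "Hello!"},
		{"user", "Tell me a joke."},
	}
	palmMessages := make([]palmMessage, 0, len(openAIMessages))
	for _, m := range openAIMessages {
		author := "1" // assistant/system side
		if m.Role == "user" {
			author = "0"
		}
		palmMessages = append(palmMessages, palmMessage{Author: author, Content: m.Content})
	}
	fmt.Printf("%+v\n", palmMessages)
}

Note that the sendStreamRequest that follows does not consume a real event stream: PaLM replies with one JSON document, which the goroutine reads in full and replays as a single SSE chunk followed by [DONE].
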
createdTime := common.GetTimestamp() + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + stopChan <- true + return + } + err = resp.Body.Close() + if err != nil { + common.SysError("error closing stream response: " + err.Error()) + stopChan <- true + return + } + var palmResponse PaLMChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + stopChan <- true + return + } + fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse) + fullTextResponse.ID = responseId + fullTextResponse.Created = createdTime + if len(palmResponse.Candidates) > 0 { + responseText = palmResponse.Candidates[0].Content + } + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + stopChan <- true + return + } + dataChan <- string(jsonResponse) + stopChan <- true + }() + setEventStreamHeaders(p.Context) + p.Context.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + p.Context.Render(-1, common.CustomEvent{Data: "data: " + data}) + return true + case <-stopChan: + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + + return nil, responseText +} diff --git a/providers/tencent_base.go b/providers/tencent_base.go new file mode 100644 index 00000000..0318bb07 --- /dev/null +++ b/providers/tencent_base.go @@ -0,0 +1,94 @@ +package providers + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/gin-gonic/gin" +) + +type TencentProvider struct { + ProviderConfig +} + +type TencentError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// 创建 TencentProvider +func CreateTencentProvider(c *gin.Context) *TencentProvider { + return &TencentProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "https://hunyuan.cloud.tencent.com", + ChatCompletions: "/hyllm/v1/chat/completions", + Context: c, + }, + } +} + +// 获取请求头 +func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } + + return headers +} + +func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { + parts := strings.Split(config, "|") + if len(parts) != 3 { + err = errors.New("invalid tencent config") + return + } + appId, err = strconv.ParseInt(parts[0], 10, 64) + secretId = parts[1] + secretKey = parts[2] + return +} + +func (p *TencentProvider) getTencentSign(req TencentChatRequest) string { + apiKey := p.Context.GetString("api_key") + appId, secretId, secretKey, err := p.parseTencentConfig(apiKey) + if err != nil { + return "" + } + req.AppId = appId + req.SecretId = secretId + + params := make([]string, 0) + params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) + params = append(params, "secret_id="+req.SecretId) + params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) + params = append(params, "query_id="+req.QueryID) + params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', 
-1, 64)) + params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) + params = append(params, "stream="+strconv.Itoa(req.Stream)) + params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) + + var messageStr string + for _, msg := range req.Messages { + messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) + } + messageStr = strings.TrimSuffix(messageStr, ",") + params = append(params, "messages=["+messageStr+"]") + + sort.Sort(sort.StringSlice(params)) + url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") + mac := hmac.New(sha1.New, []byte(secretKey)) + signURL := url + mac.Write([]byte(signURL)) + sign := mac.Sum([]byte(nil)) + return base64.StdEncoding.EncodeToString(sign) +} diff --git a/controller/relay-tencent.go b/providers/tencent_chat.go similarity index 59% rename from controller/relay-tencent.go rename to providers/tencent_chat.go index f66bf38f..b7ee5eb8 100644 --- a/controller/relay-tencent.go +++ b/providers/tencent_chat.go @@ -1,24 +1,16 @@ -package controller +package providers import ( "bufio" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" "encoding/json" "errors" - "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" - "sort" - "strconv" + "one-api/types" "strings" ) -// https://cloud.tencent.com/document/product/1729/97732 - type TencentMessage struct { Role string `json:"role"` Content string `json:"content"` @@ -50,11 +42,6 @@ type TencentChatRequest struct { Messages []TencentMessage `json:"messages"` } -type TencentError struct { - Code int `json:"code"` - Message string `json:"message"` -} - type TencentUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` @@ -71,13 +58,44 @@ type TencentChatResponse struct { Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 Created string `json:"created,omitempty"` // unix 时间戳的字符串 Id string `json:"id,omitempty"` // 会话 id - Usage Usage `json:"usage,omitempty"` // token 数量 + Usage *types.Usage `json:"usage,omitempty"` // token 数量 Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 Note string `json:"note,omitempty"` // 注释 ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 } -func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { +func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if TencentResponse.Error.Code != 0 { + return &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: TencentResponse.Error.Message, + Code: TencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := types.ChatCompletionResponse{ + Object: "chat.completion", + Created: common.GetTimestamp(), + Usage: TencentResponse.Usage, + } + if len(TencentResponse.Choices) > 0 { + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: TencentResponse.Choices[0].Messages.Content, + }, + FinishReason: TencentResponse.Choices[0].FinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + + return fullTextResponse, nil +} + +func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest { messages := make([]TencentMessage, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] 
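
The signature computed by getTencentSign above is an HMAC-SHA1 over the sorted, &-joined parameter string prefixed with the hunyuan endpoint, base64-encoded. A standalone sketch with made-up credentials and values:

package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"sort"
	"strings"
)

func main() {
	params := []string{
		"app_id=123456",
		"secret_id=demo-secret-id",
		"timestamp=1700000000",
		"query_id=demo-query",
		"temperature=0.7",
		"top_p=1",
		"stream=1",
		"expired=1700086400",
		`messages=[{"role":"user","content":"hi"}]`,
	}
	sort.Strings(params) // canonical order before signing
	signURL := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")

	mac := hmac.New(sha1.New, []byte("demo-secret-key"))
	mac.Write([]byte(signURL))
	fmt.Println(base64.StdEncoding.EncodeToString(mac.Sum(nil)))
}
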
@@ -112,34 +130,58 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 	}
 }
 
-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
-		Object:  "chat.completion",
-		Created: common.GetTimestamp(),
-		Usage:   response.Usage,
+func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+	requestBody := p.getChatRequestBody(request)
+	sign := p.getTencentSign(*requestBody)
+	if sign == "" {
+		return nil, types.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
 	}
-	if len(response.Choices) > 0 {
-		choice := OpenAITextResponseChoice{
-			Index: 0,
-			Message: Message{
-				Role:    "assistant",
-				Content: response.Choices[0].Messages.Content,
-			},
-			FinishReason: response.Choices[0].FinishReason,
+
+	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
+	headers := p.GetRequestHeaders()
+	headers["Authorization"] = sign
+	if request.Stream {
+		headers["Accept"] = "text/event-stream"
+	}
+
+	client := common.NewClient()
+	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
+	if err != nil {
+		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+
+	if request.Stream {
+		var responseText string
+		openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
+		if openAIErrorWithStatusCode != nil {
+			return
 		}
-		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+
+		// usage is a nil named return here, so allocate it rather than assigning fields on it
+		completionTokens := common.CountTokenText(responseText, request.Model)
+		usage = &types.Usage{PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens}
+
+	} else {
+		tencentResponse := &TencentChatResponse{}
+		openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
+		if openAIErrorWithStatusCode != nil {
+			return
+		}
+
+		usage = tencentResponse.Usage
 	}
-	return &fullTextResponse
+	return
+
 }
 
-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
-	response := ChatCompletionsStreamResponse{
+func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
+	response := types.ChatCompletionStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "tencent-hunyuan",
 	}
 	if len(TencentResponse.Choices) > 0 {
-		var choice ChatCompletionsStreamResponseChoice
+		var choice types.ChatCompletionStreamChoice
 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
 		if TencentResponse.Choices[0].FinishReason == "stop" {
 			choice.FinishReason = &stopFinishReason
@@ -149,7 +191,19 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCom
 	return &response
 }
 
-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
+	// Send the request
+	resp, err := common.HttpClient.Do(req)
+	if err != nil {
+		return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
+	}
+
+	if common.IsFailureStatusCode(resp) {
+		return p.handleErrorResp(resp), ""
+	}
+
+	defer resp.Body.Close()
+
 	var responseText string
 	scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -180,8 +234,8 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith } stopChan <- true }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { + setEventStreamHeaders(p.Context) + p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: var TencentResponse TencentChatResponse @@ -190,7 +244,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith common.SysError("error unmarshalling stream response: " + err.Error()) return true } - response := streamResponseTencent2OpenAI(&TencentResponse) + response := p.streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { responseText += response.Choices[0].Delta.Content } @@ -199,89 +253,13 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith common.SysError("error marshalling stream response: " + err.Error()) return true } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } + return nil, responseText } - -func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var TencentResponse TencentChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &TencentResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if TencentResponse.Error.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: TencentResponse.Error.Message, - Code: TencentResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseTencent2OpenAI(&TencentResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} - -func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { - parts := strings.Split(config, "|") - if len(parts) != 3 { - err = errors.New("invalid tencent config") - return - } - appId, err = strconv.ParseInt(parts[0], 10, 64) - secretId = parts[1] - secretKey = parts[2] - return -} - -func getTencentSign(req TencentChatRequest, secretKey string) string { - params := make([]string, 0) - params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) - params = append(params, "secret_id="+req.SecretId) - params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) - params = append(params, "query_id="+req.QueryID) - params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) - params = 
append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) - params = append(params, "stream="+strconv.Itoa(req.Stream)) - params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) - - var messageStr string - for _, msg := range req.Messages { - messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) - } - messageStr = strings.TrimSuffix(messageStr, ",") - params = append(params, "messages=["+messageStr+"]") - - sort.Sort(sort.StringSlice(params)) - url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") - mac := hmac.New(sha1.New, []byte(secretKey)) - signURL := url - mac.Write([]byte(signURL)) - sign := mac.Sum([]byte(nil)) - return base64.StdEncoding.EncodeToString(sign) -} diff --git a/providers/xunfei_base.go b/providers/xunfei_base.go new file mode 100644 index 00000000..7d2c4083 --- /dev/null +++ b/providers/xunfei_base.go @@ -0,0 +1,96 @@ +package providers + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" + "one-api/common" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// https://www.xfyun.cn/doc/spark/Web.html +type XunfeiProvider struct { + ProviderConfig + domain string + apiId string +} + +// 创建 XunfeiProvider +func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider { + return &XunfeiProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "wss://spark-api.xf-yun.com", + ChatCompletions: "", + Context: c, + }, + } +} + +// 获取请求头 +func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + return headers +} + +// 获取完整请求 URL +func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string { + splits := strings.Split(p.Context.GetString("api_key"), "|") + if len(splits) != 3 { + return "" + } + domain, authUrl := p.getXunfeiAuthUrl(splits[2], splits[1]) + + p.domain = domain + p.apiId = splits[0] + + return authUrl +} + +func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (string, string) { + query := p.Context.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = p.Context.GetString("api_version") + } + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) + } + domain := "general" + if apiVersion != "v1.1" { + domain += strings.Split(apiVersion, ".")[0] + } + authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret) + return domain, authUrl +} + +func (p *XunfeiProvider) buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { + HmacWithShaToBase64 := func(algorithm, data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) + } + ul, err := url.Parse(hostUrl) + if err != nil { + fmt.Println(err) + } + date := time.Now().UTC().Format(time.RFC1123) + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + sign := strings.Join(signString, "\n") + sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + callUrl 
:= hostUrl + "?" + v.Encode() + return callUrl +} diff --git a/controller/relay-xunfei.go b/providers/xunfei_chat.go similarity index 56% rename from controller/relay-xunfei.go rename to providers/xunfei_chat.go index 00ec8981..ffec9097 100644 --- a/controller/relay-xunfei.go +++ b/providers/xunfei_chat.go @@ -1,23 +1,15 @@ -package controller +package providers import ( - "crypto/hmac" - "crypto/sha256" - "encoding/base64" "encoding/json" - "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" - "net/url" "one-api/common" - "strings" + "one-api/types" "time" -) -// https://console.xfyun.cn/services/cbm -// https://www.xfyun.cn/doc/spark/Web.html + "github.com/gorilla/websocket" +) type XunfeiMessage struct { Role string `json:"role"` @@ -70,150 +62,28 @@ type XunfeiChatResponse struct { // CompletionTokens string `json:"completion_tokens"` // TotalTokens string `json:"total_tokens"` //} `json:"text"` - Text Usage `json:"text"` + Text types.Usage `json:"text"` } `json:"usage"` } `json:"payload"` } -func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { - messages := make([]XunfeiMessage, 0, len(request.Messages)) - for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, XunfeiMessage{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, XunfeiMessage{ - Role: "assistant", - Content: "Okay", - }) - } else { - messages = append(messages, XunfeiMessage{ - Role: message.Role, - Content: message.StringContent(), - }) - } +func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model) + + if request.Stream { + return p.sendStreamRequest(request, authUrl) + } else { + return p.sendRequest(request, authUrl) } - xunfeiRequest := XunfeiChatRequest{} - xunfeiRequest.Header.AppId = xunfeiAppId - xunfeiRequest.Parameter.Chat.Domain = domain - xunfeiRequest.Parameter.Chat.Temperature = request.Temperature - xunfeiRequest.Parameter.Chat.TopK = request.N - xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens - xunfeiRequest.Payload.Message.Text = messages - return &xunfeiRequest } -func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { - if len(response.Payload.Choices.Text) == 0 { - response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ - { - Content: "", - }, - } - } - choice := OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, - }, - FinishReason: stopFinishReason, - } - fullTextResponse := OpenAITextResponse{ - Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - Usage: response.Payload.Usage.Text, - } - return &fullTextResponse -} - -func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { - if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ - { - Content: "", - }, - } - } - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content - if xunfeiResponse.Payload.Choices.Status == 2 { - choice.FinishReason = &stopFinishReason - } - response := ChatCompletionsStreamResponse{ - Object: 
"chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "SparkDesk", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } - return &response -} - -func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { - HmacWithShaToBase64 := func(algorithm, data, key string) string { - mac := hmac.New(sha256.New, []byte(key)) - mac.Write([]byte(data)) - encodeData := mac.Sum(nil) - return base64.StdEncoding.EncodeToString(encodeData) - } - ul, err := url.Parse(hostUrl) +func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + usage = &types.Usage{} + dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl) if err != nil { - fmt.Println(err) + return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError) } - date := time.Now().UTC().Format(time.RFC1123) - signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} - sign := strings.Join(signString, "\n") - sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) - authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, - "hmac-sha256", "host date request-line", sha) - authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) - v := url.Values{} - v.Add("host", ul.Host) - v.Add("date", date) - v.Add("authorization", authorization) - callUrl := hostUrl + "?" + v.Encode() - return callUrl -} -func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { - domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) - dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) - if err != nil { - return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil - } - setEventStreamHeaders(c) - var usage Usage - c.Stream(func(w io.Writer) bool { - select { - case xunfeiResponse := <-dataChan: - usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens - usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens - usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := streamResponseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - return nil, &usage -} - -func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { - domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) - dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) - if err != nil { - return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil - } - var usage Usage var content string var xunfeiResponse XunfeiChatResponse stop := false @@ -233,17 +103,100 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin xunfeiResponse.Payload.Choices.Text[0].Content = content - response := responseXunfei2OpenAI(&xunfeiResponse) + response := p.responseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err 
!= nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) } - c.Writer.Header().Set("Content-Type", "application/json") - _, _ = c.Writer.Write(jsonResponse) - return nil, &usage + p.Context.Writer.Header().Set("Content-Type", "application/json") + _, _ = p.Context.Writer.Write(jsonResponse) + return usage, nil } -func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { +func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + usage = &types.Usage{} + dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl) + if err != nil { + return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError) + } + setEventStreamHeaders(p.Context) + p.Context.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := p.streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return usage, nil +} + +func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest { + messages := make([]XunfeiMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, XunfeiMessage{ + Role: "user", + Content: message.StringContent(), + }) + messages = append(messages, XunfeiMessage{ + Role: "assistant", + Content: "Okay", + }) + } else { + messages = append(messages, XunfeiMessage{ + Role: message.Role, + Content: message.StringContent(), + }) + } + } + xunfeiRequest := XunfeiChatRequest{} + xunfeiRequest.Header.AppId = p.apiId + xunfeiRequest.Parameter.Chat.Domain = p.domain + xunfeiRequest.Parameter.Chat.Temperature = request.Temperature + xunfeiRequest.Parameter.Chat.TopK = request.N + xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Payload.Message.Text = messages + return &xunfeiRequest +} + +func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *types.ChatCompletionResponse { + if len(response.Payload.Choices.Text) == 0 { + response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + }, + FinishReason: stopFinishReason, + } + fullTextResponse := types.ChatCompletionResponse{ + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []types.ChatCompletionChoice{choice}, + Usage: &response.Payload.Usage.Text, + } + return &fullTextResponse +} + +func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, 
chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } @@ -251,7 +204,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId if err != nil || resp.StatusCode != 101 { return nil, nil, err } - data := requestOpenAI2Xunfei(textRequest, appId, domain) + data := p.requestOpenAI2Xunfei(textRequest) err = conn.WriteJSON(data) if err != nil { return nil, nil, err @@ -287,20 +240,24 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId return dataChan, stopChan, nil } -func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") +func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *types.ChatCompletionStreamResponse { + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } } - if apiVersion == "" { - apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + var choice types.ChatCompletionStreamChoice + choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + if xunfeiResponse.Payload.Choices.Status == 2 { + choice.FinishReason = &stopFinishReason } - domain := "general" - if apiVersion != "v1.1" { - domain += strings.Split(apiVersion, ".")[0] + response := types.ChatCompletionStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "SparkDesk", + Choices: []types.ChatCompletionStreamChoice{choice}, } - authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) - return domain, authUrl + return &response } diff --git a/providers/zhipu_base.go b/providers/zhipu_base.go new file mode 100644 index 00000000..70eb4288 --- /dev/null +++ b/providers/zhipu_base.go @@ -0,0 +1,104 @@ +package providers + +import ( + "fmt" + "one-api/common" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" +) + +var zhipuTokens sync.Map +var expSeconds int64 = 24 * 3600 + +type ZhipuProvider struct { + ProviderConfig +} + +type zhipuTokenData struct { + Token string + ExpiryTime time.Time +} + +// 创建 ZhipuProvider +func CreateZhipuProvider(c *gin.Context) *ZhipuProvider { + return &ZhipuProvider{ + ProviderConfig: ProviderConfig{ + BaseURL: "https://open.bigmodel.cn", + ChatCompletions: "/api/paas/v3/model-api", + Context: c, + }, + } +} + +// 获取请求头 +func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + + headers["Authorization"] = p.getZhipuToken() + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } + + return headers +} + +// 获取完整请求 URL +func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName) +} + +func (p *ZhipuProvider) getZhipuToken() string { + apikey := p.Context.GetString("api_key") + data, ok := zhipuTokens.Load(apikey) + if ok { + tokenData := data.(zhipuTokenData) + if time.Now().Before(tokenData.ExpiryTime) { + return tokenData.Token + } + } + + split := strings.Split(apikey, ".") + if 
len(split) != 2 { + common.SysError("invalid zhipu key: " + apikey) + return "" + } + + id := split[0] + secret := split[1] + + expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 + expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) + + timestamp := time.Now().UnixNano() / 1e6 + + payload := jwt.MapClaims{ + "api_key": id, + "exp": expMillis, + "timestamp": timestamp, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + + token.Header["alg"] = "HS256" + token.Header["sign_type"] = "SIGN" + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return "" + } + + zhipuTokens.Store(apikey, zhipuTokenData{ + Token: tokenString, + ExpiryTime: expiryTime, + }) + + return tokenString +} diff --git a/providers/zhipu_chat.go b/providers/zhipu_chat.go new file mode 100644 index 00000000..4e7f1711 --- /dev/null +++ b/providers/zhipu_chat.go @@ -0,0 +1,260 @@ +package providers + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/types" + "strings" +) + +type ZhipuMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ZhipuRequest struct { + Prompt []ZhipuMessage `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ZhipuResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []ZhipuMessage `json:"choices"` + types.Usage `json:"usage"` +} + +type ZhipuResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ZhipuResponseData `json:"data"` +} + +type ZhipuStreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + types.Usage `json:"usage"` +} + +func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + if !zhipuResponse.Success { + return &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: zhipuResponse.Msg, + Type: "zhipu_error", + Param: "", + Code: zhipuResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := types.ChatCompletionResponse{ + ID: zhipuResponse.Data.TaskId, + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)), + Usage: &zhipuResponse.Data.Usage, + } + for i, choice := range zhipuResponse.Data.Choices { + openaiChoice := types.ChatCompletionChoice{ + Index: i, + Message: types.ChatCompletionMessage{ + Role: choice.Role, + Content: strings.Trim(choice.Content, "\""), + }, + FinishReason: "", + } + if i == len(zhipuResponse.Data.Choices)-1 { + openaiChoice.FinishReason = "stop" + } + fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) + } + return fullTextResponse, nil + +} + +func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest { + messages := make([]ZhipuMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, ZhipuMessage{ + Role: "system", + Content: message.StringContent(), + }) + messages = append(messages, ZhipuMessage{ + Role: "user", + Content: 
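
getZhipuToken above splits the "{id}.{secret}" key, signs an HS256 JWT whose header carries Zhipu's non-standard sign_type field, and caches the result in a sync.Map until expiry. A standalone sketch of the token construction with a made-up key:

package main

import (
	"fmt"
	"strings"
	"time"

	"github.com/golang-jwt/jwt"
)

func main() {
	apikey := "demo-id.demo-secret" // made-up "{id}.{secret}" key
	parts := strings.Split(apikey, ".")
	id, secret := parts[0], parts[1]

	now := time.Now()
	payload := jwt.MapClaims{
		"api_key":   id,
		"exp":       now.Add(24*time.Hour).UnixNano() / 1e6, // expiry in milliseconds
		"timestamp": now.UnixNano() / 1e6,
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	token.Header["alg"] = "HS256"
	token.Header["sign_type"] = "SIGN" // Zhipu-specific header field

	signed, err := token.SignedString([]byte(secret))
	if err != nil {
		panic(err)
	}
	fmt.Println(signed)
}
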
"Okay", + }) + } else { + messages = append(messages, ZhipuMessage{ + Role: message.Role, + Content: message.StringContent(), + }) + } + } + return &ZhipuRequest{ + Prompt: messages, + Temperature: request.Temperature, + TopP: request.TopP, + Incremental: false, + } +} + +func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + requestBody := p.getChatRequestBody(request) + fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + fullRequestURL += "/sse-invoke" + } else { + fullRequestURL += "/invoke" + } + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Stream { + openAIErrorWithStatusCode, usage = p.sendStreamRequest(req) + if openAIErrorWithStatusCode != nil { + return + } + + } else { + zhipuResponse := &ZhipuResponse{} + openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse) + if openAIErrorWithStatusCode != nil { + return + } + + usage = &zhipuResponse.Data.Usage + } + return + +} + +func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse { + var choice types.ChatCompletionStreamChoice + choice.Delta.Content = zhipuResponse + response := types.ChatCompletionStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "chatglm", + Choices: []types.ChatCompletionStreamChoice{choice}, + } + return &response +} + +func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) { + var choice types.ChatCompletionStreamChoice + choice.Delta.Content = "" + choice.FinishReason = &stopFinishReason + response := types.ChatCompletionStreamResponse{ + ID: zhipuResponse.RequestId, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "chatglm", + Choices: []types.ChatCompletionStreamChoice{choice}, + } + return &response, &zhipuResponse.Usage +} + +func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) { + // 发送请求 + resp, err := common.HttpClient.Do(req) + if err != nil { + return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil + } + + if common.IsFailureStatusCode(resp) { + return p.handleErrorResp(resp), nil + } + + defer resp.Body.Close() + + var usage *types.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + metaChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + lines := strings.Split(data, "\n") + for i, line := range lines { + if len(line) < 5 { + continue + } + if line[:5] == "data:" { + dataChan <- line[5:] + if i != len(lines)-1 { + dataChan <- "\n" + } + } else if 
line[:5] == "meta:" {
+					metaChan <- line[5:]
+				}
+			}
+		}
+		stopChan <- true
+	}()
+	setEventStreamHeaders(p.Context)
+	p.Context.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			response := p.streamResponseZhipu2OpenAI(data)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case data := <-metaChan:
+			var zhipuResponse ZhipuStreamMetaResponse
+			err := json.Unmarshal([]byte(data), &zhipuResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			usage = zhipuUsage
+			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+
+	return nil, usage
+}
diff --git a/types/assistant.go b/types/assistant.go
new file mode 100644
index 00000000..a95dc829
--- /dev/null
+++ b/types/assistant.go
@@ -0,0 +1,53 @@
+package types
+
+type Assistant struct {
+	ID           string         `json:"id"`
+	Object       string         `json:"object"`
+	CreatedAt    int64          `json:"created_at"`
+	Name         *string        `json:"name,omitempty"`
+	Description  *string        `json:"description,omitempty"`
+	Model        string         `json:"model"`
+	Instructions *string        `json:"instructions,omitempty"`
+	Tools        any            `json:"tools,omitempty"`
+	FileIDs      []string       `json:"file_ids,omitempty"`
+	Metadata     map[string]any `json:"metadata,omitempty"`
+}
+
+type AssistantRequest struct {
+	Model        string         `json:"model"`
+	Name         *string        `json:"name,omitempty"`
+	Description  *string        `json:"description,omitempty"`
+	Instructions *string        `json:"instructions,omitempty"`
+	Tools        any            `json:"tools,omitempty"`
+	FileIDs      []string       `json:"file_ids,omitempty"`
+	Metadata     map[string]any `json:"metadata,omitempty"`
+}
+
+// AssistantsList is a list of assistants.
+type AssistantsList struct {
+	Assistants []Assistant `json:"data"`
+	LastID     *string     `json:"last_id"`
+	FirstID    *string     `json:"first_id"`
+	HasMore    bool        `json:"has_more"`
+}
+
+type AssistantDeleteResponse struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Deleted bool   `json:"deleted"`
+}
+
+type AssistantFile struct {
+	ID          string `json:"id"`
+	Object      string `json:"object"`
+	CreatedAt   int64  `json:"created_at"`
+	AssistantID string `json:"assistant_id"`
+}
+
+type AssistantFileRequest struct {
+	FileID string `json:"file_id"`
+}
+
+type AssistantFilesList struct {
+	AssistantFiles []AssistantFile `json:"data"`
+}
diff --git a/types/audio.go b/types/audio.go
new file mode 100644
index 00000000..075cbf7a
--- /dev/null
+++ b/types/audio.go
@@ -0,0 +1,9 @@
+package types
+
+type SpeechAudioRequest struct {
+	Model          string  `json:"model"`
+	Input          string  `json:"input"`
+	Voice          string  `json:"voice"`
+	ResponseFormat string  `json:"response_format,omitempty"`
+	Speed          float64 `json:"speed,omitempty"`
+}
diff --git a/types/chat.go b/types/chat.go
new file mode 100644
index 00000000..5902f56a
--- /dev/null
+++ b/types/chat.go
@@ -0,0 +1,109 @@
+package types
+
+type ChatCompletionMessage struct {
+	Role         string  `json:"role"`
+	Content      any     `json:"content"`
+	Name         *string `json:"name,omitempty"`
+	FunctionCall any     `json:"function_call,omitempty"`
+	ToolCalls    any     `json:"tool_calls,omitempty"`
+	ToolCallID   string  `json:"tool_call_id,omitempty"`
+}
+
+func (m ChatCompletionMessage) StringContent() string {
+	content, ok := m.Content.(string)
+	if ok {
+		return content
+	}
+	contentList, ok := m.Content.([]any)
+	if ok {
+		var contentStr string
+		for _, contentItem := range contentList {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			if contentMap["type"] == "text" {
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentStr += subStr
+				}
+			}
+		}
+		return contentStr
+	}
+	return ""
+}
+
+type ChatMessageImageURL struct {
+	URL    string `json:"url,omitempty"`
+	Detail string `json:"detail,omitempty"`
+}
+
+type ChatMessagePart struct {
+	Type     string               `json:"type,omitempty"`
+	Text     string               `json:"text,omitempty"`
+	ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
+}
+
+type ChatCompletionResponseFormat struct {
+	Type string `json:"type,omitempty"`
+}
+
+type ChatCompletionRequest struct {
+	Model            string                        `json:"model"`
+	Messages         []ChatCompletionMessage       `json:"messages"`
+	MaxTokens        int                           `json:"max_tokens,omitempty"`
+	Temperature      float64                       `json:"temperature,omitempty"`
+	TopP             float64                       `json:"top_p,omitempty"`
+	N                int                           `json:"n,omitempty"`
+	Stream           bool                          `json:"stream,omitempty"`
+	Stop             []string                      `json:"stop,omitempty"`
+	PresencePenalty  float64                       `json:"presence_penalty,omitempty"`
+	ResponseFormat   *ChatCompletionResponseFormat `json:"response_format,omitempty"`
+	Seed             *int                          `json:"seed,omitempty"`
+	FrequencyPenalty float64                       `json:"frequency_penalty,omitempty"`
+	LogitBias        any                           `json:"logit_bias,omitempty"`
+	User             string                        `json:"user,omitempty"`
+	Functions        any                           `json:"functions,omitempty"`
+	FunctionCall     any                           `json:"function_call,omitempty"`
+	Tools            any                           `json:"tools,omitempty"`
+	ToolChoice       any                           `json:"tool_choice,omitempty"`
+}
+
+type ChatCompletionChoice struct {
+	Index        int                   `json:"index"`
+	Message      ChatCompletionMessage `json:"message"`
+	FinishReason any                   `json:"finish_reason,omitempty"`
+}
+
+type ChatCompletionResponse struct {
+	ID                string                 `json:"id"`
+	Object            string                 `json:"object"`
+	Created           int64                  `json:"created"`
+	Model             string                 `json:"model"`
+	Choices           []ChatCompletionChoice `json:"choices"`
+	Usage             *Usage                 `json:"usage,omitempty"`
+	SystemFingerprint string                 `json:"system_fingerprint,omitempty"`
+}
+
+type ChatCompletionStreamChoiceDelta struct {
+	Content      string `json:"content,omitempty"`
+	Role         string `json:"role,omitempty"`
+	FunctionCall any    `json:"function_call,omitempty"`
+	ToolCalls    any    `json:"tool_calls,omitempty"`
+}
+
+type ChatCompletionStreamChoice struct {
+	Index                int                             `json:"index"`
+	Delta                ChatCompletionStreamChoiceDelta `json:"delta"`
+	FinishReason         any                             `json:"finish_reason"`
+	ContentFilterResults any                             `json:"content_filter_results,omitempty"`
+}
+
+type ChatCompletionStreamResponse struct {
+	ID                string                       `json:"id"`
+	Object            string                       `json:"object"`
+	Created           int64                        `json:"created"`
+	Model             string                       `json:"model"`
+	Choices           []ChatCompletionStreamChoice `json:"choices"`
+	PromptAnnotations any                          `json:"prompt_annotations,omitempty"`
+}
diff --git a/types/common.go b/types/common.go
new file mode 100644
index 00000000..b955f388
--- /dev/null
+++ b/types/common.go
@@ -0,0 +1,40 @@
+package types
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type OpenAIError struct {
+	Code       any    `json:"code,omitempty"`
+	Message    string `json:"message"`
+	Param      string `json:"param,omitempty"`
+	Type       string `json:"type"`
+	InnerError any    `json:"innererror,omitempty"`
+}
+
+type OpenAIErrorWithStatusCode struct {
+	OpenAIError
+	StatusCode int `json:"status_code"`
+}
+
+type OpenAIErrorResponse struct {
+	Error OpenAIError `json:"error,omitempty"`
+}
+
+func ErrorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
+	openAIError := OpenAIError{
+		Message: err.Error(),
+		Type:    "one_api_error",
+		Code:    code,
+	}
+	return &OpenAIErrorWithStatusCode{
+		OpenAIError: openAIError,
+		StatusCode:  statusCode,
+	}
+}
+
+// type GeneralErrorHandling interface {
+// 	HandleError(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode)
+// }
diff --git a/types/completion.go b/types/completion.go
new file mode 100644
index 00000000..0bac8ceb
--- /dev/null
+++ b/types/completion.go
@@ -0,0 +1,36 @@
+package types
+
+type CompletionRequest struct {
+	Model            string   `json:"model"`
+	Prompt           any      `json:"prompt,omitempty"`
+	Suffix           string   `json:"suffix,omitempty"`
+	MaxTokens        int      `json:"max_tokens,omitempty"`
+	Temperature      float32  `json:"temperature,omitempty"`
+	TopP             float32  `json:"top_p,omitempty"`
+	N                int      `json:"n,omitempty"`
+	Stream           bool     `json:"stream,omitempty"`
+	LogProbs         int      `json:"logprobs,omitempty"`
+	Echo             bool     `json:"echo,omitempty"`
+	Stop             []string `json:"stop,omitempty"`
+	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
+	BestOf           int      `json:"best_of,omitempty"`
+	LogitBias        any      `json:"logit_bias,omitempty"`
+	User             string   `json:"user,omitempty"`
+}
+
+type CompletionChoice struct {
+	Text         string `json:"text"`
+	Index        int    `json:"index"`
+	FinishReason string `json:"finish_reason"`
+	LogProbs     any    `json:"logprobs,omitempty"`
+}
+
+type CompletionResponse struct {
+	ID      string             `json:"id"`
+	Object  string             `json:"object"`
+	Created int64              `json:"created"`
+	Model   string             `json:"model"`
+	Choices []CompletionChoice `json:"choices"`
+	Usage   *Usage             `json:"usage,omitempty"`
+}
diff --git a/types/embeddings.go b/types/embeddings.go
new file mode 100644
index 00000000..05df9241
--- /dev/null
+++ b/types/embeddings.go
@@ -0,0 +1,40 @@
+package types
+
+type EmbeddingRequest struct {
+	Model          string `json:"model"`
+	Input          any    `json:"input"`
+	EncodingFormat string `json:"encoding_format,omitempty"`
+	User           string `json:"user,omitempty"`
+}
+
+type Embedding struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+type EmbeddingResponse struct {
+	Object string      `json:"object"`
+	Data   []Embedding `json:"data"`
+	Model  string      `json:"model"`
+	Usage  *Usage      `json:"usage,omitempty"`
+}
+
+func (r EmbeddingRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch v := r.Input.(type) {
+	case string:
+		input = []string{v}
+	case []any:
+		input = make([]string, 0, len(v))
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
+}
diff --git a/types/image.go b/types/image.go
new file mode 100644
index 00000000..a3254769
--- /dev/null
+++ b/types/image.go
@@ -0,0 +1,23 @@
+package types
+
+type ImageRequest struct {
+	Prompt         string `json:"prompt,omitempty"`
+	Model          string `json:"model,omitempty"`
+	N              int    `json:"n,omitempty"`
+	Quality        string `json:"quality,omitempty"`
+	Size           string `json:"size,omitempty"`
+	Style          string `json:"style,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+	User           string `json:"user,omitempty"`
+}
+
+type ImageResponse struct {
+	Created int64                    `json:"created,omitempty"`
+	Data    []ImageResponseDataInner `json:"data,omitempty"`
+}
+
+type ImageResponseDataInner struct {
+	URL           string `json:"url,omitempty"`
+	B64JSON       string `json:"b64_json,omitempty"`
+	RevisedPrompt string `json:"revised_prompt,omitempty"`
+}
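-- 

Usage note: a minimal, self-contained sketch of how the new request types
compose once this patch is applied. The import path "one-api/types" and the
model names are assumptions for illustration only; adjust to the actual
module path.

package main

import (
	"encoding/json"
	"fmt"

	"one-api/types"
)

func main() {
	// ChatCompletionMessage.Content is typed as `any`, so a plain string
	// and the multi-part []any form share a single message type.
	req := types.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []types.ChatCompletionMessage{
			{Role: "user", Content: "Hello!"},
		},
	}

	body, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))

	// StringContent flattens either content form to plain text, which is
	// what providers without multi-part support consume.
	fmt.Println(req.Messages[0].StringContent()) // Hello!

	// ParseInput normalizes the polymorphic embeddings input to []string.
	embReq := types.EmbeddingRequest{
		Model: "text-embedding-ada-002",
		Input: []any{"first chunk", "second chunk"},
	}
	fmt.Println(embReq.ParseInput()) // [first chunk second chunk]
}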