diff --git a/common/client.go b/common/client.go index 5fb596b0..48fd0ebc 100644 --- a/common/client.go +++ b/common/client.go @@ -6,6 +6,8 @@ import ( "io" "net/http" "time" + + "github.com/gin-gonic/gin" ) var HttpClient *http.Client @@ -124,3 +126,11 @@ func DecodeString(body io.Reader, output *string) error { *output = string(b) return nil } + +func SetEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 29e7360a..a1656470 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -70,7 +70,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e } promptTokens := common.CountTokenMessages(request.Messages, request.Model) - _, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&request, isModelMapped, promptTokens) + _, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens) if openAIErrorWithStatusCode != nil { return nil, &openAIErrorWithStatusCode.OpenAIError } diff --git a/controller/model.go b/controller/model.go index 59ea22e8..1bae0f74 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -541,7 +542,7 @@ func RetrieveModel(c *gin.Context) { if model, ok := openAIModelsMap[modelId]; ok { c.JSON(200, model) } else { - openAIError := OpenAIError{ + openAIError := types.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", diff --git a/controller/relay-chat.go b/controller/relay-chat.go deleted file mode 100644 index 746947e0..00000000 --- a/controller/relay-chat.go +++ /dev/null @@ -1,127 +0,0 @@ -package controller - -import ( - "context" - "errors" - "net/http" - "one-api/common" - "one-api/model" - "one-api/providers" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func relayChatHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode { - - // 获取请求参数 - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - // consumeQuota := c.GetBool("consume_quota") - group := c.GetString("group") - - // 获取 Provider - chatProvider := GetChatProvider(channelType, c) - if chatProvider == nil { - return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented) - } - - // 获取请求体 - var chatRequest types.ChatCompletionRequest - err := common.UnmarshalBodyReusable(c, &chatRequest) - if err != nil { - return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - // 检查模型映射 - isModelMapped := false - modelMap, err := parseModelMapping(c.GetString("model_mapping")) - if err != nil { - return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap != nil && modelMap[chatRequest.Model] != "" { - chatRequest.Model = modelMap[chatRequest.Model] - isModelMapped = true - } - - // 开始计算Tokens - var promptTokens int - promptTokens = common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) - - // 计算预付费配额 - quotaInfo := &QuotaInfo{ - modelName: chatRequest.Model, - promptTokens: promptTokens, - userId: userId, - channelId: channelId, - tokenId: tokenId, - } - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return quota_err - } - - usage, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&chatRequest, isModelMapped, promptTokens) - - if openAIErrorWithStatusCode != nil { - if quotaInfo.preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return openAIErrorWithStatusCode - } - - tokenName := c.GetString("token_name") - defer func(ctx context.Context) { - go func() { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }() - }(c.Request.Context()) - - return nil -} - -func GetChatProvider(channelType int, c *gin.Context) providers.ChatProviderAction { - switch channelType { - case common.ChannelTypeOpenAI: - return providers.CreateOpenAIProvider(c, "") - case common.ChannelTypeAzure: - return providers.CreateAzureProvider(c) - case common.ChannelTypeAli: - return providers.CreateAliAIProvider(c) - case common.ChannelTypeTencent: - return providers.CreateTencentProvider(c) - case common.ChannelTypeBaidu: - return providers.CreateBaiduProvider(c) - case common.ChannelTypeAnthropic: - return providers.CreateClaudeProvider(c) - case common.ChannelTypePaLM: - return providers.CreatePalmProvider(c) - case common.ChannelTypeZhipu: - return providers.CreateZhipuProvider(c) - case common.ChannelTypeXunfei: - return providers.CreateXunfeiProvider(c) - } - - baseURL := common.ChannelBaseURLs[channelType] - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - - if baseURL != "" { - return providers.CreateOpenAIProvider(c, baseURL) - } - - return nil -} diff --git a/controller/relay-completion.go b/controller/relay-completion.go deleted file mode 100644 index 6087adfa..00000000 --- a/controller/relay-completion.go +++ /dev/null @@ -1,113 +0,0 @@ -package controller - -import ( - "context" - "errors" - "net/http" - "one-api/common" - "one-api/model" - "one-api/providers" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func relayCompletionHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode { - - // 获取请求参数 - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - // consumeQuota := c.GetBool("consume_quota") - group := c.GetString("group") - - // 获取 Provider - completionProvider := GetCompletionProvider(channelType, c) - if completionProvider == nil { - return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented) - } - - // 获取请求体 - var completionRequest types.CompletionRequest - err := common.UnmarshalBodyReusable(c, &completionRequest) - if err != nil { - return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - // 检查模型映射 - isModelMapped := false - modelMap, err := parseModelMapping(c.GetString("model_mapping")) - if err != nil { - return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap != nil && modelMap[completionRequest.Model] != "" { - completionRequest.Model = modelMap[completionRequest.Model] - isModelMapped = true - } - - // 开始计算Tokens - var promptTokens int - promptTokens = common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) - - // 计算预付费配额 - quotaInfo := &QuotaInfo{ - modelName: completionRequest.Model, - promptTokens: promptTokens, - userId: userId, - channelId: channelId, - tokenId: tokenId, - } - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return quota_err - } - - usage, openAIErrorWithStatusCode := completionProvider.CompleteResponse(&completionRequest, isModelMapped, promptTokens) - - if openAIErrorWithStatusCode != nil { - if quotaInfo.preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return openAIErrorWithStatusCode - } - - tokenName := c.GetString("token_name") - defer func(ctx context.Context) { - go func() { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }() - }(c.Request.Context()) - - return nil -} - -func GetCompletionProvider(channelType int, c *gin.Context) providers.CompletionProviderAction { - switch channelType { - case common.ChannelTypeOpenAI: - return providers.CreateOpenAIProvider(c, "") - case common.ChannelTypeAzure: - return providers.CreateAzureProvider(c) - } - - baseURL := common.ChannelBaseURLs[channelType] - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - - if baseURL != "" { - return providers.CreateOpenAIProvider(c, baseURL) - } - - return nil -} diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go deleted file mode 100644 index 86189ba7..00000000 --- a/controller/relay-embeddings.go +++ /dev/null @@ -1,117 +0,0 @@ -package controller - -import ( - "context" - "errors" - "net/http" - "one-api/common" - "one-api/model" - "one-api/providers" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func relayEmbeddingsHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode { - - // 获取请求参数 - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - // consumeQuota := c.GetBool("consume_quota") - group := c.GetString("group") - - // 获取 Provider - embeddingsProvider := GetEmbeddingsProvider(channelType, c) - if embeddingsProvider == nil { - return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented) - } - - // 获取请求体 - var embeddingsRequest types.EmbeddingRequest - err := common.UnmarshalBodyReusable(c, &embeddingsRequest) - if err != nil { - return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - // 检查模型映射 - isModelMapped := false - modelMap, err := parseModelMapping(c.GetString("model_mapping")) - if err != nil { - return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { - embeddingsRequest.Model = modelMap[embeddingsRequest.Model] - isModelMapped = true - } - - // 开始计算Tokens - var promptTokens int - promptTokens = common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) - - // 计算预付费配额 - quotaInfo := &QuotaInfo{ - modelName: embeddingsRequest.Model, - promptTokens: promptTokens, - userId: userId, - channelId: channelId, - tokenId: tokenId, - } - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return quota_err - } - - usage, openAIErrorWithStatusCode := embeddingsProvider.EmbeddingsResponse(&embeddingsRequest, isModelMapped, promptTokens) - - if openAIErrorWithStatusCode != nil { - if quotaInfo.preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return openAIErrorWithStatusCode - } - - tokenName := c.GetString("token_name") - defer func(ctx context.Context) { - go func() { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }() - }(c.Request.Context()) - - return nil -} - -func GetEmbeddingsProvider(channelType int, c *gin.Context) providers.EmbeddingsProviderAction { - switch channelType { - case common.ChannelTypeOpenAI: - return providers.CreateOpenAIProvider(c, "") - case common.ChannelTypeAzure: - return providers.CreateAzureProvider(c) - case common.ChannelTypeAli: - return providers.CreateAliAIProvider(c) - case common.ChannelTypeBaidu: - return providers.CreateBaiduProvider(c) - } - - baseURL := common.ChannelBaseURLs[channelType] - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - - if baseURL != "" { - return providers.CreateOpenAIProvider(c, baseURL) - } - - return nil -} diff --git a/controller/relay-text.go b/controller/relay-text.go new file mode 100644 index 00000000..528b0690 --- /dev/null +++ b/controller/relay-text.go @@ -0,0 +1,160 @@ +package controller + +import ( + "context" + "errors" + "net/http" + "one-api/common" + "one-api/model" + "one-api/providers" + providers_base "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode { + // 获取请求参数 + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + + // 获取 Provider + provider := providers.GetProvider(channelType, c) + if provider == nil { + return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + + modelMap, err := parseModelMapping(c.GetString("model_mapping")) + if err != nil { + return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + + var promptTokens int + quotaInfo := &QuotaInfo{ + modelName: "", + promptTokens: promptTokens, + userId: userId, + channelId: channelId, + tokenId: tokenId, + } + + var usage *types.Usage + var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode + + switch relayMode { + case RelayModeChatCompletions: + usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group) + case RelayModeCompletions: + usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group) + case RelayModeEmbeddings: + usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group) + default: + return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest) + } + + if openAIErrorWithStatusCode != nil { + if quotaInfo.preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) + if err != nil { + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(c.Request.Context()) + } + return openAIErrorWithStatusCode + } + + tokenName := c.GetString("token_name") + defer func(ctx context.Context) { + go func() { + err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) + if err != nil { + common.LogError(ctx, err.Error()) + } + }() + }(c.Request.Context()) + + return nil +} + +func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { + var chatRequest types.ChatCompletionRequest + isModelMapped := false + chatProvider, ok := provider.(providers_base.ChatInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + err := common.UnmarshalBodyReusable(c, &chatRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + if modelMap != nil && modelMap[chatRequest.Model] != "" { + chatRequest.Model = modelMap[chatRequest.Model] + isModelMapped = true + } + promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) + + quotaInfo.modelName = chatRequest.Model + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return nil, quota_err + } + return chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens) +} + +func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { + var completionRequest types.CompletionRequest + isModelMapped := false + completionProvider, ok := provider.(providers_base.CompletionInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + err := common.UnmarshalBodyReusable(c, &completionRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + if modelMap != nil && modelMap[completionRequest.Model] != "" { + completionRequest.Model = modelMap[completionRequest.Model] + isModelMapped = true + } + promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) + + quotaInfo.modelName = completionRequest.Model + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return nil, quota_err + } + return completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens) +} + +func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { + var embeddingsRequest types.EmbeddingRequest + isModelMapped := false + embeddingsProvider, ok := provider.(providers_base.EmbeddingsInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + err := common.UnmarshalBodyReusable(c, &embeddingsRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { + embeddingsRequest.Model = modelMap[embeddingsRequest.Model] + isModelMapped = true + } + promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) + + quotaInfo.modelName = embeddingsRequest.Model + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return nil, quota_err + } + return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens) +} diff --git a/controller/relay.go b/controller/relay.go index 01519269..dfa185e8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -237,19 +237,19 @@ type CompletionsStreamResponse struct { func Relay(c *gin.Context) { var err *types.OpenAIErrorWithStatusCode - // relayMode := RelayModeUnknown + relayMode := RelayModeUnknown if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - err = relayChatHelper(c) - // relayMode = RelayModeChatCompletions + // err = relayChatHelper(c) + relayMode = RelayModeChatCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - err = relayCompletionHelper(c) - // relayMode = RelayModeCompletions + // err = relayCompletionHelper(c) + relayMode = RelayModeCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { - err = relayEmbeddingsHelper(c) + // err = relayEmbeddingsHelper(c) + relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + relayMode = RelayModeEmbeddings } - // relayMode = RelayModeEmbeddings - // } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - // relayMode = RelayModeEmbeddings // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { // relayMode = RelayModeModerations // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { @@ -263,7 +263,7 @@ func Relay(c *gin.Context) { // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { // relayMode = RelayModeAudioTranslation // } - // switch relayMode { + switch relayMode { // case RelayModeImagesGenerations: // err = relayImageHelper(c, relayMode) // case RelayModeAudioSpeech: @@ -272,9 +272,9 @@ func Relay(c *gin.Context) { // fallthrough // case RelayModeAudioTranscription: // err = relayAudioHelper(c, relayMode) - // default: - // err = relayTextHelper(c, relayMode) - // } + default: + err = relayTextHelper(c, relayMode) + } if err != nil { requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") diff --git a/providers/ali/base.go b/providers/ali/base.go new file mode 100644 index 00000000..249abf05 --- /dev/null +++ b/providers/ali/base.go @@ -0,0 +1,35 @@ +package ali + +import ( + "fmt" + + "one-api/providers/base" + + "github.com/gin-gonic/gin" +) + +type AliProvider struct { + base.BaseProvider +} + +// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation +// 创建 AliAIProvider +func CreateAliAIProvider(c *gin.Context) *AliProvider { + return &AliProvider{ + BaseProvider: base.BaseProvider{ + BaseURL: "https://dashscope.aliyuncs.com", + ChatCompletions: "/api/v1/services/aigc/text-generation/generation", + Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding", + Context: c, + }, + } +} + +// 获取请求头 +func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) + + return headers +} diff --git a/providers/ali_chat.go b/providers/ali/chat.go similarity index 71% rename from providers/ali_chat.go rename to providers/ali/chat.go index 12a66313..6a8a41ba 100644 --- a/providers/ali_chat.go +++ b/providers/ali/chat.go @@ -1,4 +1,4 @@ -package providers +package ali import ( "bufio" @@ -10,43 +10,10 @@ import ( "strings" ) -type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` -} - -type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` -} - -type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` -} - -type AliChatRequest struct { - Model string `json:"model"` - Input AliInput `json:"input"` - Parameters AliParameters `json:"parameters,omitempty"` -} - -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type AliChatResponse struct { - Output AliOutput `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +// 阿里云响应处理 +func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if aliResponse.Code != "" { - return nil, &types.OpenAIErrorWithStatusCode{ + errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ Message: aliResponse.Message, Type: aliResponse.Code, @@ -55,6 +22,8 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR }, StatusCode: resp.StatusCode, } + + return } choice := types.ChatCompletionChoice{ @@ -66,7 +35,7 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR FinishReason: aliResponse.Output.FinishReason, } - fullTextResponse := types.ChatCompletionResponse{ + OpenAIResponse = types.ChatCompletionResponse{ ID: aliResponse.RequestId, Object: "chat.completion", Created: common.GetTimestamp(), @@ -78,10 +47,11 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR }, } - return fullTextResponse, nil + return } -func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { +// 获取聊天请求体 +func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) prompt := "" for i := 0; i < len(request.Messages); i++ { @@ -113,7 +83,8 @@ func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) } } -func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +// 聊天 +func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) @@ -130,8 +101,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques } if request.Stream { - openAIErrorWithStatusCode, usage = p.sendStreamRequest(req) - if openAIErrorWithStatusCode != nil { + usage, errWithCode = p.sendStreamRequest(req) + if errWithCode != nil { return } @@ -145,8 +116,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques } else { aliResponse := &AliChatResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, aliResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, aliResponse) + if errWithCode != nil { return } @@ -159,7 +130,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques return } -func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse { +// 阿里云响应转OpenAI响应 +func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse { var choice types.ChatCompletionStreamChoice choice.Delta.Content = aliResponse.Output.Text if aliResponse.Output.FinishReason != "null" { @@ -177,16 +149,17 @@ func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) * return &response } -func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) { +// 发送流请求 +func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { usage = &types.Usage{} // 发送请求 resp, err := common.HttpClient.Do(req) if err != nil { - return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil + return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), nil + return nil, p.HandleErrorResp(resp) } defer resp.Body.Close() @@ -220,7 +193,7 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta } stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) lastResponseText := "" p.Context.Stream(func(w io.Writer) bool { select { @@ -252,5 +225,5 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta } }) - return nil, usage + return } diff --git a/providers/ali_embeddings.go b/providers/ali/embeddings.go similarity index 59% rename from providers/ali_embeddings.go rename to providers/ali/embeddings.go index 913c371a..b3ce200f 100644 --- a/providers/ali_embeddings.go +++ b/providers/ali/embeddings.go @@ -1,4 +1,4 @@ -package providers +package ali import ( "net/http" @@ -6,30 +6,8 @@ import ( "one-api/types" ) -type AliEmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -type AliEmbedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type AliEmbeddingResponse struct { - Output struct { - Embeddings []AliEmbedding `json:"embeddings"` - } `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +// 嵌入请求处理 +func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) { if aliResponse.Code != "" { return nil, &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -60,7 +38,8 @@ func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (Op return openAIEmbeddingResponse, nil } -func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest { +// 获取嵌入请求体 +func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest { return &AliEmbeddingRequest{ Model: "text-embedding-v1", Input: struct { @@ -71,7 +50,7 @@ func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest } } -func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getEmbeddingsRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) @@ -84,8 +63,8 @@ func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo } aliEmbeddingResponse := &AliEmbeddingResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, aliEmbeddingResponse) + if errWithCode != nil { return } usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens} diff --git a/providers/ali/type.go b/providers/ali/type.go new file mode 100644 index 00000000..e4c5d3d2 --- /dev/null +++ b/providers/ali/type.go @@ -0,0 +1,70 @@ +package ali + +type AliError struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type AliUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type AliMessage struct { + User string `json:"user"` + Bot string `json:"bot"` +} + +type AliInput struct { + Prompt string `json:"prompt"` + History []AliMessage `json:"history"` +} + +type AliParameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` +} + +type AliChatRequest struct { + Model string `json:"model"` + Input AliInput `json:"input"` + Parameters AliParameters `json:"parameters,omitempty"` +} + +type AliOutput struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +type AliChatResponse struct { + Output AliOutput `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + +type AliEmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type AliEmbedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type AliEmbeddingResponse struct { + Output struct { + Embeddings []AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} diff --git a/providers/ali_base.go b/providers/ali_base.go deleted file mode 100644 index df5e4812..00000000 --- a/providers/ali_base.go +++ /dev/null @@ -1,50 +0,0 @@ -package providers - -import ( - "fmt" - - "github.com/gin-gonic/gin" -) - -type AliAIProvider struct { - ProviderConfig -} - -type AliError struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type AliUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// 创建 AliAIProvider -// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation -func CreateAliAIProvider(c *gin.Context) *AliAIProvider { - return &AliAIProvider{ - ProviderConfig: ProviderConfig{ - BaseURL: "https://dashscope.aliyuncs.com", - ChatCompletions: "/api/v1/services/aigc/text-generation/generation", - Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding", - Context: c, - }, - } -} - -// 获取请求头 -func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) { - headers = make(map[string]string) - headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) - - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json" - } - - return headers -} diff --git a/providers/api2d/base.go b/providers/api2d/base.go new file mode 100644 index 00000000..b81d2371 --- /dev/null +++ b/providers/api2d/base.go @@ -0,0 +1,18 @@ +package api2d + +import ( + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type Api2dProvider struct { + *openai.OpenAIProvider +} + +// 创建 Api2dProvider +func CreateApi2dProvider(c *gin.Context) *Api2dProvider { + return &Api2dProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"), + } +} diff --git a/providers/api2d_base.go b/providers/api2d_base.go deleted file mode 100644 index e4b4cc28..00000000 --- a/providers/api2d_base.go +++ /dev/null @@ -1,14 +0,0 @@ -package providers - -import "github.com/gin-gonic/gin" - -type Api2dProvider struct { - *OpenAIProvider -} - -// 创建 OpenAIProvider -func CreateApi2dProvider(c *gin.Context) *Api2dProvider { - return &Api2dProvider{ - OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"), - } -} diff --git a/providers/azure/base.go b/providers/azure/base.go new file mode 100644 index 00000000..8a3def4e --- /dev/null +++ b/providers/azure/base.go @@ -0,0 +1,31 @@ +package azure + +import ( + "one-api/providers/base" + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type AzureProvider struct { + openai.OpenAIProvider +} + +// 创建 OpenAIProvider +func CreateAzureProvider(c *gin.Context) *AzureProvider { + return &AzureProvider{ + OpenAIProvider: openai.OpenAIProvider{ + BaseProvider: base.BaseProvider{ + BaseURL: "", + Completions: "/completions", + ChatCompletions: "/chat/completions", + Embeddings: "/embeddings", + AudioSpeech: "/audio/speech", + AudioTranscriptions: "/audio/transcriptions", + AudioTranslations: "/audio/translations", + Context: c, + }, + IsAzure: true, + }, + } +} diff --git a/providers/azure_base.go b/providers/azure_base.go deleted file mode 100644 index 0f1aa017..00000000 --- a/providers/azure_base.go +++ /dev/null @@ -1,41 +0,0 @@ -package providers - -import ( - "github.com/gin-gonic/gin" -) - -type AzureProvider struct { - OpenAIProvider -} - -// 创建 OpenAIProvider -func CreateAzureProvider(c *gin.Context) *AzureProvider { - return &AzureProvider{ - OpenAIProvider: OpenAIProvider{ - ProviderConfig: ProviderConfig{ - BaseURL: "", - Completions: "/completions", - ChatCompletions: "/chat/completions", - Embeddings: "/embeddings", - AudioSpeech: "/audio/speech", - AudioTranscriptions: "/audio/transcriptions", - AudioTranslations: "/audio/translations", - Context: c, - }, - isAzure: true, - }, - } -} - -// // 获取完整请求 URL -// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string { -// apiVersion := p.Context.GetString("api_version") -// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion) -// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") - -// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { -// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments") -// } - -// return fmt.Sprintf("%s%s", baseURL, requestURL) -// } diff --git a/providers/baidu_base.go b/providers/baidu/base.go similarity index 84% rename from providers/baidu_base.go rename to providers/baidu/base.go index 85cd5ec4..1a5005fc 100644 --- a/providers/baidu_base.go +++ b/providers/baidu/base.go @@ -1,10 +1,11 @@ -package providers +package baidu import ( "encoding/json" "errors" "fmt" "one-api/common" + "one-api/providers/base" "strings" "sync" "time" @@ -15,20 +16,12 @@ import ( var baiduTokenStore sync.Map type BaiduProvider struct { - ProviderConfig -} - -type BaiduAccessToken struct { - AccessToken string `json:"access_token"` - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - ExpiresAt time.Time `json:"-"` + base.BaseProvider } func CreateBaiduProvider(c *gin.Context) *BaiduProvider { return &BaiduProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: "https://aip.baidubce.com", ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat", Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings", @@ -59,12 +52,7 @@ func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) s // 获取请求头 func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) - - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json" - } + p.CommonRequestHeaders(headers) return headers } diff --git a/providers/baidu_chat.go b/providers/baidu/chat.go similarity index 69% rename from providers/baidu_chat.go rename to providers/baidu/chat.go index 26fafa96..0cf19751 100644 --- a/providers/baidu_chat.go +++ b/providers/baidu/chat.go @@ -1,4 +1,4 @@ -package providers +package baidu import ( "bufio" @@ -6,33 +6,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/providers/base" "one-api/types" "strings" ) -type BaiduMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type BaiduChatRequest struct { - Messages []BaiduMessage `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` -} - -type BaiduChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage *types.Usage `json:"usage"` - BaiduError -} - -func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if baiduResponse.ErrorMsg != "" { return nil, &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -54,7 +33,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope FinishReason: "stop", } - fullTextResponse := types.ChatCompletionResponse{ + OpenAIResponse = types.ChatCompletionResponse{ ID: baiduResponse.Id, Object: "chat.completion", Created: baiduResponse.Created, @@ -62,18 +41,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope Usage: baiduResponse.Usage, } - return fullTextResponse, nil -} - -type BaiduChatStreamResponse struct { - BaiduChatResponse - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` -} - -type BaiduError struct { - ErrorCode int `json:"error_code"` - ErrorMsg string `json:"error_msg"` + return } func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest { @@ -101,7 +69,7 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) } } -func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) if fullRequestURL == "" { @@ -120,15 +88,15 @@ func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionReques } if request.Stream { - openAIErrorWithStatusCode, usage = p.sendStreamRequest(req) - if openAIErrorWithStatusCode != nil { + usage, errWithCode = p.sendStreamRequest(req) + if errWithCode != nil { return } } else { baiduChatRequest := &BaiduChatResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, baiduChatRequest) + if errWithCode != nil { return } @@ -142,7 +110,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea var choice types.ChatCompletionStreamChoice choice.Delta.Content = baiduResponse.Result if baiduResponse.IsEnd { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &base.StopFinishReason } response := types.ChatCompletionStreamResponse{ @@ -155,16 +123,16 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea return &response } -func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) { +func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { usage = &types.Usage{} // 发送请求 resp, err := common.HttpClient.Do(req) if err != nil { - return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil + return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), nil + return nil, p.HandleErrorResp(resp) } defer resp.Body.Close() @@ -195,7 +163,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta } stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -224,5 +192,5 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta } }) - return nil, usage + return usage, nil } diff --git a/providers/baidu_embeddings.go b/providers/baidu/embeddings.go similarity index 62% rename from providers/baidu_embeddings.go rename to providers/baidu/embeddings.go index 9318d3f3..9a26e1a5 100644 --- a/providers/baidu_embeddings.go +++ b/providers/baidu/embeddings.go @@ -1,4 +1,4 @@ -package providers +package baidu import ( "net/http" @@ -6,32 +6,13 @@ import ( "one-api/types" ) -type BaiduEmbeddingRequest struct { - Input []string `json:"input"` -} - -type BaiduEmbeddingData struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -type BaiduEmbeddingResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Data []BaiduEmbeddingData `json:"data"` - Usage types.Usage `json:"usage"` - BaiduError -} - func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest { return &BaiduEmbeddingRequest{ Input: request.ParseInput(), } } -func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if baiduResponse.ErrorMsg != "" { return nil, &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -62,7 +43,7 @@ func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) return openAIEmbeddingResponse, nil } -func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getEmbeddingsRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) @@ -78,8 +59,8 @@ func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo } baiduEmbeddingResponse := &BaiduEmbeddingResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, baiduEmbeddingResponse) + if errWithCode != nil { return } usage = &baiduEmbeddingResponse.Usage diff --git a/providers/baidu/type.go b/providers/baidu/type.go new file mode 100644 index 00000000..30e6f4c4 --- /dev/null +++ b/providers/baidu/type.go @@ -0,0 +1,66 @@ +package baidu + +import ( + "one-api/types" + "time" +) + +type BaiduAccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} + +type BaiduMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type BaiduChatRequest struct { + Messages []BaiduMessage `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` +} + +type BaiduChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage *types.Usage `json:"usage"` + BaiduError +} + +type BaiduEmbeddingRequest struct { + Input []string `json:"input"` +} + +type BaiduEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type BaiduEmbeddingResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []BaiduEmbeddingData `json:"data"` + Usage types.Usage `json:"usage"` + BaiduError +} + +type BaiduChatStreamResponse struct { + BaiduChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type BaiduError struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} diff --git a/providers/base.go b/providers/base/common.go similarity index 51% rename from providers/base.go rename to providers/base/common.go index 895b157a..c02a2f38 100644 --- a/providers/base.go +++ b/providers/base/common.go @@ -1,4 +1,4 @@ -package providers +package base import ( "encoding/json" @@ -6,7 +6,6 @@ import ( "io" "net/http" "one-api/common" - "one-api/model" "one-api/types" "strconv" "strings" @@ -14,9 +13,9 @@ import ( "github.com/gin-gonic/gin" ) -var stopFinishReason = "stop" +var StopFinishReason = "stop" -type ProviderConfig struct { +type BaseProvider struct { BaseURL string Completions string ChatCompletions string @@ -28,32 +27,8 @@ type ProviderConfig struct { Context *gin.Context } -type BaseProviderAction interface { - GetBaseURL() string - GetFullRequestURL(requestURL string, modelName string) string - GetRequestHeaders() (headers map[string]string) -} - -type CompletionProviderAction interface { - BaseProviderAction - CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) -} - -type ChatProviderAction interface { - BaseProviderAction - ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) -} - -type EmbeddingsProviderAction interface { - BaseProviderAction - EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) -} - -type BalanceProviderAction interface { - Balance(channel *model.Channel) (float64, error) -} - -func (p *ProviderConfig) GetBaseURL() string { +// 获取基础URL +func (p *BaseProvider) GetBaseURL() string { if p.Context.GetString("base_url") != "" { return p.Context.GetString("base_url") } @@ -61,21 +36,66 @@ func (p *ProviderConfig) GetBaseURL() string { return p.BaseURL } -func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string { +// 获取完整请求URL +func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") return fmt.Sprintf("%s%s", baseURL, requestURL) } -func setEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") +// 获取请求头 +func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) { + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + if headers["Content-Type"] == "" { + headers["Content-Type"] = "application/json" + } } -func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +// 发送请求 +func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + + // 发送请求 + resp, err := common.HttpClient.Do(req) + if err != nil { + return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) + } + + defer resp.Body.Close() + + // 处理响应 + if common.IsFailureStatusCode(resp) { + return p.HandleErrorResp(resp) + } + + // 解析响应 + err = common.DecodeResponse(resp.Body, response) + if err != nil { + return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + } + + openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp) + if openAIErrorWithStatusCode != nil { + return + } + + jsonResponse, err := json.Marshal(openAIResponse) + if err != nil { + return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + p.Context.Writer.Header().Set("Content-Type", "application/json") + p.Context.Writer.WriteHeader(resp.StatusCode) + _, err = p.Context.Writer.Write(jsonResponse) + + if err != nil { + return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) + } + + return nil +} + +// 处理错误响应 +func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, OpenAIError: types.OpenAIError{ @@ -105,46 +125,3 @@ func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithSt } return } - -// 供应商响应处理函数 -type ProviderResponseHandler interface { - // 请求处理函数 - requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) -} - -// 发送请求 -func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { - - // 发送请求 - resp, err := common.HttpClient.Do(req) - if err != nil { - return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) - } - - defer resp.Body.Close() - - // 处理响应 - if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp) - } - - // 解析响应 - err = common.DecodeResponse(resp.Body, response) - if err != nil { - return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) - } - - openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp) - if openAIErrorWithStatusCode != nil { - return - } - - jsonResponse, err := json.Marshal(openAIResponse) - if err != nil { - return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) - } - p.Context.Writer.Header().Set("Content-Type", "application/json") - p.Context.Writer.WriteHeader(resp.StatusCode) - _, err = p.Context.Writer.Write(jsonResponse) - return nil -} diff --git a/providers/base/interface.go b/providers/base/interface.go new file mode 100644 index 00000000..f0f8e4a0 --- /dev/null +++ b/providers/base/interface.go @@ -0,0 +1,42 @@ +package base + +import ( + "net/http" + "one-api/model" + "one-api/types" +) + +// 基础接口 +type ProviderInterface interface { + GetBaseURL() string + GetFullRequestURL(requestURL string, modelName string) string + GetRequestHeaders() (headers map[string]string) +} + +// 完成接口 +type CompletionInterface interface { + ProviderInterface + CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + +// 聊天接口 +type ChatInterface interface { + ProviderInterface + ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + +// 嵌入接口 +type EmbeddingsInterface interface { + ProviderInterface + EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + +// 余额接口 +type BalanceInterface interface { + BalanceAction(channel *model.Channel) (float64, error) +} + +type ProviderResponseHandler interface { + // 响应处理函数 + ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) +} diff --git a/providers/claude_base.go b/providers/claude/base.go similarity index 69% rename from providers/claude_base.go rename to providers/claude/base.go index 6a94e3f2..f77a5c2f 100644 --- a/providers/claude_base.go +++ b/providers/claude/base.go @@ -1,21 +1,18 @@ -package providers +package claude import ( + "one-api/providers/base" + "github.com/gin-gonic/gin" ) type ClaudeProvider struct { - ProviderConfig -} - -type ClaudeError struct { - Type string `json:"type"` - Message string `json:"message"` + base.BaseProvider } func CreateClaudeProvider(c *gin.Context) *ClaudeProvider { return &ClaudeProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: "https://api.anthropic.com", ChatCompletions: "/v1/complete", Context: c, @@ -26,14 +23,9 @@ func CreateClaudeProvider(c *gin.Context) *ClaudeProvider { // 获取请求头 func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) + p.CommonRequestHeaders(headers) headers["x-api-key"] = p.Context.GetString("api_key") - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json" - } - anthropicVersion := p.Context.Request.Header.Get("anthropic-version") if anthropicVersion == "" { anthropicVersion = "2023-06-01" diff --git a/providers/claude_chat.go b/providers/claude/chat.go similarity index 79% rename from providers/claude_chat.go rename to providers/claude/chat.go index e6c1d11a..660db19a 100644 --- a/providers/claude_chat.go +++ b/providers/claude/chat.go @@ -1,4 +1,4 @@ -package providers +package claude import ( "bufio" @@ -11,31 +11,7 @@ import ( "strings" ) -type ClaudeMetadata struct { - UserId string `json:"user_id"` -} - -type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokensToSample int `json:"max_tokens_to_sample"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type ClaudeResponse struct { - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` - Usage *types.Usage `json:"usage,omitempty"` -} - -func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if claudeResponse.Error.Type != "" { return nil, &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -101,7 +77,7 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest return &claudeRequest } -func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) headers := p.GetRequestHeaders() @@ -117,8 +93,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque if request.Stream { var responseText string - openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req) - if openAIErrorWithStatusCode != nil { + errWithCode, responseText = p.sendStreamRequest(req) + if errWithCode != nil { return } @@ -132,8 +108,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque PromptTokens: promptTokens, }, } - openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, claudeResponse) + if errWithCode != nil { return } @@ -165,7 +141,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), "" + return p.HandleErrorResp(resp), "" } defer resp.Body.Close() @@ -199,7 +175,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro } stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/providers/claude/type.go b/providers/claude/type.go new file mode 100644 index 00000000..8a920c73 --- /dev/null +++ b/providers/claude/type.go @@ -0,0 +1,32 @@ +package claude + +import "one-api/types" + +type ClaudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type ClaudeMetadata struct { + UserId string `json:"user_id"` +} + +type ClaudeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokensToSample int `json:"max_tokens_to_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //ClaudeMetadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type ClaudeResponse struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` + Usage *types.Usage `json:"usage,omitempty"` +} diff --git a/providers/closeai_proxy_base.go b/providers/closeai/balance.go similarity index 56% rename from providers/closeai_proxy_base.go rename to providers/closeai/balance.go index 9879bf38..85b22c02 100644 --- a/providers/closeai_proxy_base.go +++ b/providers/closeai/balance.go @@ -1,31 +1,11 @@ -package providers +package closeai import ( "fmt" "one-api/common" "one-api/model" - - "github.com/gin-gonic/gin" ) -type CloseaiProxyProvider struct { - *OpenAIProvider -} - -type OpenAICreditGrants struct { - Object string `json:"object"` - TotalGranted float64 `json:"total_granted"` - TotalUsed float64 `json:"total_used"` - TotalAvailable float64 `json:"total_available"` -} - -// 创建 CloseaiProxyProvider -func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider { - return &CloseaiProxyProvider{ - OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"), - } -} - func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) { fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "") fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key) diff --git a/providers/closeai/base.go b/providers/closeai/base.go new file mode 100644 index 00000000..0d64d6e7 --- /dev/null +++ b/providers/closeai/base.go @@ -0,0 +1,18 @@ +package closeai + +import ( + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type CloseaiProxyProvider struct { + *openai.OpenAIProvider +} + +// 创建 CloseaiProxyProvider +func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider { + return &CloseaiProxyProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"), + } +} diff --git a/providers/closeai/type.go b/providers/closeai/type.go new file mode 100644 index 00000000..81f64b8c --- /dev/null +++ b/providers/closeai/type.go @@ -0,0 +1,8 @@ +package closeai + +type OpenAICreditGrants struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalAvailable float64 `json:"total_available"` +} diff --git a/providers/openai_base.go b/providers/openai/base.go similarity index 84% rename from providers/openai_base.go rename to providers/openai/base.go index 1f2cd096..bd4e9cd0 100644 --- a/providers/openai_base.go +++ b/providers/openai/base.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "bufio" @@ -11,32 +11,25 @@ import ( "one-api/types" "strings" + "one-api/providers/base" + "github.com/gin-gonic/gin" ) type OpenAIProvider struct { - ProviderConfig - isAzure bool -} - -type OpenAIProviderResponseHandler interface { - // 请求处理函数 - requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) -} - -type OpenAIProviderStreamResponseHandler interface { - // 请求流处理函数 - requestStreamHandler() (responseText string) + base.BaseProvider + IsAzure bool } // 创建 OpenAIProvider +// https://platform.openai.com/docs/api-reference/introduction func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { if baseURL == "" { baseURL = "https://api.openai.com" } return &OpenAIProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: baseURL, Completions: "/v1/completions", ChatCompletions: "/v1/chat/completions", @@ -46,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { AudioTranslations: "/v1/audio/translations", Context: c, }, - isAzure: false, + IsAzure: false, } } @@ -54,13 +47,13 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") - if p.isAzure { + if p.IsAzure { apiVersion := p.Context.GetString("api_version") requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) } if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - if p.isAzure { + if p.IsAzure { requestURL = strings.TrimPrefix(requestURL, "/openai/deployments") } else { requestURL = strings.TrimPrefix(requestURL, "/v1") @@ -73,16 +66,12 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) // 获取请求头 func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) - if p.isAzure { + p.CommonRequestHeaders(headers) + if p.IsAzure { headers["api-key"] = p.Context.GetString("api_key") } else { headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) } - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json; charset=utf-8" - } return headers } @@ -114,7 +103,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR // 处理响应 if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp) + return p.HandleErrorResp(resp) } // 创建一个 bytes.Buffer 来存储响应体 @@ -127,7 +116,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) } - openAIErrorWithStatusCode = response.requestHandler(resp) + openAIErrorWithStatusCode = response.responseHandler(resp) if openAIErrorWithStatusCode != nil { return } @@ -145,6 +134,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR return nil } +// 发送流式请求 func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { resp, err := common.HttpClient.Do(req) @@ -153,7 +143,7 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), "" + return p.HandleErrorResp(resp), "" } defer resp.Body.Close() @@ -190,12 +180,12 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro common.SysError("error unmarshalling stream response: " + err.Error()) continue // just ignore the error } - responseText += response.requestStreamHandler() + responseText += response.responseStreamHandler() } } stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/providers/openai_chat.go b/providers/openai/chat.go similarity index 64% rename from providers/openai_chat.go rename to providers/openai/chat.go index 6a7247b3..a937d004 100644 --- a/providers/openai_chat.go +++ b/providers/openai/chat.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "net/http" @@ -6,19 +6,9 @@ import ( "one-api/types" ) -type OpenAIProviderChatResponse struct { - types.ChatCompletionResponse - types.OpenAIErrorResponse -} - -type OpenAIProviderChatStreamResponse struct { - types.ChatCompletionStreamResponse - types.OpenAIErrorResponse -} - -func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { - openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, StatusCode: resp.StatusCode, } @@ -27,7 +17,7 @@ func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAI return nil } -func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) { +func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) { for _, choice := range c.Choices { responseText += choice.Delta.Content } @@ -35,7 +25,7 @@ func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText return } -func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody, err := p.getRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) @@ -56,8 +46,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque if request.Stream { openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{} var textResponse string - openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse) - if openAIErrorWithStatusCode != nil { + errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse) + if errWithCode != nil { return } @@ -69,8 +59,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque } else { openAIProviderChatResponse := &OpenAIProviderChatResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.sendRequest(req, openAIProviderChatResponse) + if errWithCode != nil { return } diff --git a/providers/openai_completion.go b/providers/openai/completion.go similarity index 66% rename from providers/openai_completion.go rename to providers/openai/completion.go index df99903e..e446c93b 100644 --- a/providers/openai_completion.go +++ b/providers/openai/completion.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "net/http" @@ -6,14 +6,9 @@ import ( "one-api/types" ) -type OpenAIProviderCompletionResponse struct { - types.CompletionResponse - types.OpenAIErrorResponse -} - -func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { - openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, StatusCode: resp.StatusCode, } @@ -22,7 +17,7 @@ func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) ( return nil } -func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) { +func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) { for _, choice := range c.Choices { responseText += choice.Text } @@ -30,7 +25,7 @@ func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText return } -func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody, err := p.getRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) @@ -52,8 +47,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo if request.Stream { // TODO var textResponse string - openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse) - if openAIErrorWithStatusCode != nil { + errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse) + if errWithCode != nil { return } @@ -64,8 +59,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo } } else { - openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.sendRequest(req, openAIProviderCompletionResponse) + if errWithCode != nil { return } diff --git a/providers/openai_embeddings.go b/providers/openai/embeddings.go similarity index 57% rename from providers/openai_embeddings.go rename to providers/openai/embeddings.go index 00c3cc80..641caa49 100644 --- a/providers/openai_embeddings.go +++ b/providers/openai/embeddings.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "net/http" @@ -6,14 +6,9 @@ import ( "one-api/types" ) -type OpenAIProviderEmbeddingsResponse struct { - types.EmbeddingResponse - types.OpenAIErrorResponse -} - -func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) { if c.Error.Type != "" { - openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ + errWithCode = &types.OpenAIErrorWithStatusCode{ OpenAIError: c.Error, StatusCode: resp.StatusCode, } @@ -22,7 +17,7 @@ func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) ( return nil } -func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody, err := p.getRequestBody(&request, isModelMapped) if err != nil { @@ -39,8 +34,8 @@ func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isM } openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse) + if errWithCode != nil { return } diff --git a/providers/openai/interface.go b/providers/openai/interface.go new file mode 100644 index 00000000..1695be8c --- /dev/null +++ b/providers/openai/interface.go @@ -0,0 +1,16 @@ +package openai + +import ( + "net/http" + "one-api/types" +) + +type OpenAIProviderResponseHandler interface { + // 请求处理函数 + responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) +} + +type OpenAIProviderStreamResponseHandler interface { + // 请求流处理函数 + responseStreamHandler() (responseText string) +} diff --git a/providers/openai/type.go b/providers/openai/type.go new file mode 100644 index 00000000..f8fee787 --- /dev/null +++ b/providers/openai/type.go @@ -0,0 +1,23 @@ +package openai + +import "one-api/types" + +type OpenAIProviderChatResponse struct { + types.ChatCompletionResponse + types.OpenAIErrorResponse +} + +type OpenAIProviderChatStreamResponse struct { + types.ChatCompletionStreamResponse + types.OpenAIErrorResponse +} + +type OpenAIProviderCompletionResponse struct { + types.CompletionResponse + types.OpenAIErrorResponse +} + +type OpenAIProviderEmbeddingsResponse struct { + types.EmbeddingResponse + types.OpenAIErrorResponse +} diff --git a/providers/openaisb_base.go b/providers/openaisb/balance.go similarity index 66% rename from providers/openaisb_base.go rename to providers/openaisb/balance.go index c7f11fb9..72ea530e 100644 --- a/providers/openaisb_base.go +++ b/providers/openaisb/balance.go @@ -1,4 +1,4 @@ -package providers +package openaisb import ( "errors" @@ -6,28 +6,8 @@ import ( "one-api/common" "one-api/model" "strconv" - - "github.com/gin-gonic/gin" ) -type OpenaiSBProvider struct { - *OpenAIProvider -} - -type OpenAISBUsageResponse struct { - Msg string `json:"msg"` - Data *struct { - Credit string `json:"credit"` - } `json:"data"` -} - -// 创建 OpenaiSBProvider -func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider { - return &OpenaiSBProvider{ - OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"), - } -} - func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) { fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "") fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key) diff --git a/providers/openaisb/base.go b/providers/openaisb/base.go new file mode 100644 index 00000000..f5f46dfb --- /dev/null +++ b/providers/openaisb/base.go @@ -0,0 +1,18 @@ +package openaisb + +import ( + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type OpenaiSBProvider struct { + *openai.OpenAIProvider +} + +// 创建 OpenaiSBProvider +func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider { + return &OpenaiSBProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"), + } +} diff --git a/providers/openaisb/type.go b/providers/openaisb/type.go new file mode 100644 index 00000000..d4f10ead --- /dev/null +++ b/providers/openaisb/type.go @@ -0,0 +1,8 @@ +package openaisb + +type OpenAISBUsageResponse struct { + Msg string `json:"msg"` + Data *struct { + Credit string `json:"credit"` + } `json:"data"` +} diff --git a/providers/palm_base.go b/providers/palm/base.go similarity index 72% rename from providers/palm_base.go rename to providers/palm/base.go index 40d33ac0..ee10709e 100644 --- a/providers/palm_base.go +++ b/providers/palm/base.go @@ -1,20 +1,21 @@ -package providers +package palm import ( "fmt" + "one-api/providers/base" "strings" "github.com/gin-gonic/gin" ) type PalmProvider struct { - ProviderConfig + base.BaseProvider } // 创建 PalmProvider func CreatePalmProvider(c *gin.Context) *PalmProvider { return &PalmProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: "https://generativelanguage.googleapis.com", ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage", Context: c, @@ -25,12 +26,7 @@ func CreatePalmProvider(c *gin.Context) *PalmProvider { // 获取请求头 func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) - - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json" - } + p.CommonRequestHeaders(headers) return headers } diff --git a/providers/palm_chat.go b/providers/palm/chat.go similarity index 74% rename from providers/palm_chat.go rename to providers/palm/chat.go index 37c1fde1..3159bdf0 100644 --- a/providers/palm_chat.go +++ b/providers/palm/chat.go @@ -1,4 +1,4 @@ -package providers +package palm import ( "encoding/json" @@ -6,47 +6,11 @@ import ( "io" "net/http" "one-api/common" + "one-api/providers/base" "one-api/types" ) -type PaLMChatMessage struct { - Author string `json:"author"` - Content string `json:"content"` -} - -type PaLMFilter struct { - Reason string `json:"reason"` - Message string `json:"message"` -} - -type PaLMPrompt struct { - Messages []PaLMChatMessage `json:"messages"` -} - -type PaLMChatRequest struct { - Prompt PaLMPrompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` -} - -type PaLMError struct { - Code int `json:"code"` - Message string `json:"message"` - Status string `json:"status"` -} - -type PaLMChatResponse struct { - Candidates []PaLMChatMessage `json:"candidates"` - Messages []types.ChatCompletionMessage `json:"messages"` - Filters []PaLMFilter `json:"filters"` - Error PaLMError `json:"error"` - Usage *types.Usage `json:"usage,omitempty"` - Model string `json:"model,omitempty"` -} - -func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { return nil, &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -107,7 +71,7 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) return &palmRequest } -func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) headers := p.GetRequestHeaders() @@ -123,8 +87,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest if request.Stream { var responseText string - openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req) - if openAIErrorWithStatusCode != nil { + errWithCode, responseText = p.sendStreamRequest(req) + if errWithCode != nil { return } @@ -139,8 +103,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest PromptTokens: promptTokens, }, } - openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, palmChatResponse) + if errWithCode != nil { return } @@ -155,7 +119,7 @@ func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) if len(palmResponse.Candidates) > 0 { choice.Delta.Content = palmResponse.Candidates[0].Content } - choice.FinishReason = &stopFinishReason + choice.FinishReason = &base.StopFinishReason var response types.ChatCompletionStreamResponse response.Object = "chat.completion.chunk" response.Model = "palm2" @@ -171,7 +135,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), "" + return p.HandleErrorResp(resp), "" } defer resp.Body.Close() @@ -216,7 +180,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW dataChan <- string(jsonResponse) stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/providers/palm/type.go b/providers/palm/type.go new file mode 100644 index 00000000..76eadded --- /dev/null +++ b/providers/palm/type.go @@ -0,0 +1,40 @@ +package palm + +import "one-api/types" + +type PaLMChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type PaLMFilter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type PaLMPrompt struct { + Messages []PaLMChatMessage `json:"messages"` +} + +type PaLMChatRequest struct { + Prompt PaLMPrompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type PaLMError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type PaLMChatResponse struct { + Candidates []PaLMChatMessage `json:"candidates"` + Messages []types.ChatCompletionMessage `json:"messages"` + Filters []PaLMFilter `json:"filters"` + Error PaLMError `json:"error"` + Usage *types.Usage `json:"usage,omitempty"` + Model string `json:"model,omitempty"` +} diff --git a/providers/providers.go b/providers/providers.go new file mode 100644 index 00000000..9eeaa34e --- /dev/null +++ b/providers/providers.go @@ -0,0 +1,50 @@ +package providers + +import ( + "one-api/common" + "one-api/providers/ali" + "one-api/providers/azure" + "one-api/providers/baidu" + "one-api/providers/base" + "one-api/providers/claude" + "one-api/providers/openai" + "one-api/providers/palm" + "one-api/providers/tencent" + "one-api/providers/xunfei" + "one-api/providers/zhipu" + + "github.com/gin-gonic/gin" +) + +func GetProvider(channelType int, c *gin.Context) base.ProviderInterface { + switch channelType { + case common.ChannelTypeOpenAI: + return openai.CreateOpenAIProvider(c, "") + case common.ChannelTypeAzure: + return azure.CreateAzureProvider(c) + case common.ChannelTypeAli: + return ali.CreateAliAIProvider(c) + case common.ChannelTypeTencent: + return tencent.CreateTencentProvider(c) + case common.ChannelTypeBaidu: + return baidu.CreateBaiduProvider(c) + case common.ChannelTypeAnthropic: + return claude.CreateClaudeProvider(c) + case common.ChannelTypePaLM: + return palm.CreatePalmProvider(c) + case common.ChannelTypeZhipu: + return zhipu.CreateZhipuProvider(c) + case common.ChannelTypeXunfei: + return xunfei.CreateXunfeiProvider(c) + default: + baseURL := common.ChannelBaseURLs[channelType] + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + if baseURL != "" { + return openai.CreateOpenAIProvider(c, baseURL) + } + + return nil + } +} diff --git a/providers/tencent_base.go b/providers/tencent/base.go similarity index 85% rename from providers/tencent_base.go rename to providers/tencent/base.go index 0318bb07..ab7b2b49 100644 --- a/providers/tencent_base.go +++ b/providers/tencent/base.go @@ -1,4 +1,4 @@ -package providers +package tencent import ( "crypto/hmac" @@ -6,6 +6,7 @@ import ( "encoding/base64" "errors" "fmt" + "one-api/providers/base" "sort" "strconv" "strings" @@ -14,18 +15,13 @@ import ( ) type TencentProvider struct { - ProviderConfig -} - -type TencentError struct { - Code int `json:"code"` - Message string `json:"message"` + base.BaseProvider } // 创建 TencentProvider func CreateTencentProvider(c *gin.Context) *TencentProvider { return &TencentProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: "https://hunyuan.cloud.tencent.com", ChatCompletions: "/hyllm/v1/chat/completions", Context: c, @@ -36,12 +32,7 @@ func CreateTencentProvider(c *gin.Context) *TencentProvider { // 获取请求头 func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) - - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json" - } + p.CommonRequestHeaders(headers) return headers } diff --git a/providers/tencent_chat.go b/providers/tencent/chat.go similarity index 58% rename from providers/tencent_chat.go rename to providers/tencent/chat.go index b7ee5eb8..52608630 100644 --- a/providers/tencent_chat.go +++ b/providers/tencent/chat.go @@ -1,4 +1,4 @@ -package providers +package tencent import ( "bufio" @@ -7,64 +7,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/providers/base" "one-api/types" "strings" ) -type TencentMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type TencentChatRequest struct { - AppId int64 `json:"app_id"` // 腾讯云账号的 APPID - SecretId string `json:"secret_id"` // 官网 SecretId - // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 - // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 - Timestamp int64 `json:"timestamp"` - // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, - // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 - Expired int64 `json:"expired"` - QueryID string `json:"query_id"` //请求 Id,用于问题排查 - // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 - // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 - // 建议该参数和 top_p 只设置1个,不要同时更改 top_p - Temperature float64 `json:"temperature"` - // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 - // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 - // 建议该参数和 temperature 只设置1个,不要同时更改 - TopP float64 `json:"top_p"` - // Stream 0:同步,1:流式 (默认,协议:SSE) - // 同步请求超时:60s,如果内容较长建议使用流式 - Stream int `json:"stream"` - // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 - // 输入 content 总数最大支持 3000 token。 - Messages []TencentMessage `json:"messages"` -} - -type TencentUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type TencentResponseChoices struct { - FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 - Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 - Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 -} - -type TencentChatResponse struct { - Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 - Created string `json:"created,omitempty"` // unix 时间戳的字符串 - Id string `json:"id,omitempty"` // 会话 id - Usage *types.Usage `json:"usage,omitempty"` // token 数量 - Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 - Note string `json:"note,omitempty"` // 注释 - ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 -} - -func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if TencentResponse.Error.Code != 0 { return &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -130,7 +78,7 @@ func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionReques } } -func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) sign := p.getTencentSign(*requestBody) if sign == "" { @@ -152,8 +100,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ if request.Stream { var responseText string - openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req) - if openAIErrorWithStatusCode != nil { + errWithCode, responseText = p.sendStreamRequest(req) + if errWithCode != nil { return } @@ -163,8 +111,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ } else { tencentResponse := &TencentChatResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, tencentResponse) + if errWithCode != nil { return } @@ -184,7 +132,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC var choice types.ChatCompletionStreamChoice choice.Delta.Content = TencentResponse.Choices[0].Delta.Content if TencentResponse.Choices[0].FinishReason == "stop" { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &base.StopFinishReason } response.Choices = append(response.Choices, choice) } @@ -199,7 +147,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), "" + return p.HandleErrorResp(resp), "" } defer resp.Body.Close() @@ -234,7 +182,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr } stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/providers/tencent/type.go b/providers/tencent/type.go new file mode 100644 index 00000000..300ba3af --- /dev/null +++ b/providers/tencent/type.go @@ -0,0 +1,61 @@ +package tencent + +import "one-api/types" + +type TencentError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type TencentMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type TencentChatRequest struct { + AppId int64 `json:"app_id"` // 腾讯云账号的 APPID + SecretId string `json:"secret_id"` // 官网 SecretId + // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 + // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 + Timestamp int64 `json:"timestamp"` + // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, + // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 + Expired int64 `json:"expired"` + QueryID string `json:"query_id"` //请求 Id,用于问题排查 + // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 + // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 + // 建议该参数和 top_p 只设置1个,不要同时更改 top_p + Temperature float64 `json:"temperature"` + // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 + // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 + // 建议该参数和 temperature 只设置1个,不要同时更改 + TopP float64 `json:"top_p"` + // Stream 0:同步,1:流式 (默认,协议:SSE) + // 同步请求超时:60s,如果内容较长建议使用流式 + Stream int `json:"stream"` + // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 + // 输入 content 总数最大支持 3000 token。 + Messages []TencentMessage `json:"messages"` +} + +type TencentUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type TencentResponseChoices struct { + FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 + Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 + Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 +} + +type TencentChatResponse struct { + Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 + Created string `json:"created,omitempty"` // unix 时间戳的字符串 + Id string `json:"id,omitempty"` // 会话 id + Usage *types.Usage `json:"usage,omitempty"` // token 数量 + Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"note,omitempty"` // 注释 + ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 +} diff --git a/providers/xunfei_base.go b/providers/xunfei/base.go similarity index 96% rename from providers/xunfei_base.go rename to providers/xunfei/base.go index 7d2c4083..c8b37a94 100644 --- a/providers/xunfei_base.go +++ b/providers/xunfei/base.go @@ -1,4 +1,4 @@ -package providers +package xunfei import ( "crypto/hmac" @@ -7,6 +7,7 @@ import ( "fmt" "net/url" "one-api/common" + "one-api/providers/base" "strings" "time" @@ -15,7 +16,7 @@ import ( // https://www.xfyun.cn/doc/spark/Web.html type XunfeiProvider struct { - ProviderConfig + base.BaseProvider domain string apiId string } @@ -23,7 +24,7 @@ type XunfeiProvider struct { // 创建 XunfeiProvider func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider { return &XunfeiProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: "wss://spark-api.xf-yun.com", ChatCompletions: "", Context: c, diff --git a/providers/xunfei_chat.go b/providers/xunfei/chat.go similarity index 73% rename from providers/xunfei_chat.go rename to providers/xunfei/chat.go index ffec9097..445bcc71 100644 --- a/providers/xunfei_chat.go +++ b/providers/xunfei/chat.go @@ -1,73 +1,18 @@ -package providers +package xunfei import ( "encoding/json" "io" "net/http" "one-api/common" + "one-api/providers/base" "one-api/types" "time" "github.com/gorilla/websocket" ) -type XunfeiMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type XunfeiChatRequest struct { - Header struct { - AppId string `json:"app_id"` - } `json:"header"` - Parameter struct { - Chat struct { - Domain string `json:"domain,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Auditing bool `json:"auditing,omitempty"` - } `json:"chat"` - } `json:"parameter"` - Payload struct { - Message struct { - Text []XunfeiMessage `json:"text"` - } `json:"message"` - } `json:"payload"` -} - -type XunfeiChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` -} - -type XunfeiChatResponse struct { - Header struct { - Code int `json:"code"` - Message string `json:"message"` - Sid string `json:"sid"` - Status int `json:"status"` - } `json:"header"` - Payload struct { - Choices struct { - Status int `json:"status"` - Seq int `json:"seq"` - Text []XunfeiChatResponseTextItem `json:"text"` - } `json:"choices"` - Usage struct { - //Text struct { - // QuestionTokens string `json:"question_tokens"` - // PromptTokens string `json:"prompt_tokens"` - // CompletionTokens string `json:"completion_tokens"` - // TotalTokens string `json:"total_tokens"` - //} `json:"text"` - Text types.Usage `json:"text"` - } `json:"usage"` - } `json:"payload"` -} - -func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model) if request.Stream { @@ -77,7 +22,7 @@ func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionReque } } -func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { usage = &types.Usage{} dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl) if err != nil { @@ -113,13 +58,13 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU return usage, nil } -func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { usage = &types.Usage{} dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl) if err != nil { return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError) } - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: @@ -185,7 +130,7 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, - FinishReason: stopFinishReason, + FinishReason: base.StopFinishReason, } fullTextResponse := types.ChatCompletionResponse{ Object: "chat.completion", @@ -251,7 +196,7 @@ func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatR var choice types.ChatCompletionStreamChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content if xunfeiResponse.Payload.Choices.Status == 2 { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &base.StopFinishReason } response := types.ChatCompletionStreamResponse{ Object: "chat.completion.chunk", diff --git a/providers/xunfei/type.go b/providers/xunfei/type.go new file mode 100644 index 00000000..23fdef6a --- /dev/null +++ b/providers/xunfei/type.go @@ -0,0 +1,59 @@ +package xunfei + +import "one-api/types" + +type XunfeiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type XunfeiChatRequest struct { + Header struct { + AppId string `json:"app_id"` + } `json:"header"` + Parameter struct { + Chat struct { + Domain string `json:"domain,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` + } `json:"chat"` + } `json:"parameter"` + Payload struct { + Message struct { + Text []XunfeiMessage `json:"text"` + } `json:"message"` + } `json:"payload"` +} + +type XunfeiChatResponseTextItem struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +type XunfeiChatResponse struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []XunfeiChatResponseTextItem `json:"text"` + } `json:"choices"` + Usage struct { + //Text struct { + // QuestionTokens string `json:"question_tokens"` + // PromptTokens string `json:"prompt_tokens"` + // CompletionTokens string `json:"completion_tokens"` + // TotalTokens string `json:"total_tokens"` + //} `json:"text"` + Text types.Usage `json:"text"` + } `json:"usage"` + } `json:"payload"` +} diff --git a/providers/zhipu_base.go b/providers/zhipu/base.go similarity index 84% rename from providers/zhipu_base.go rename to providers/zhipu/base.go index 70eb4288..1b59e1a3 100644 --- a/providers/zhipu_base.go +++ b/providers/zhipu/base.go @@ -1,8 +1,9 @@ -package providers +package zhipu import ( "fmt" "one-api/common" + "one-api/providers/base" "strings" "sync" "time" @@ -15,18 +16,13 @@ var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 type ZhipuProvider struct { - ProviderConfig -} - -type zhipuTokenData struct { - Token string - ExpiryTime time.Time + base.BaseProvider } // 创建 ZhipuProvider func CreateZhipuProvider(c *gin.Context) *ZhipuProvider { return &ZhipuProvider{ - ProviderConfig: ProviderConfig{ + BaseProvider: base.BaseProvider{ BaseURL: "https://open.bigmodel.cn", ChatCompletions: "/api/paas/v3/model-api", Context: c, @@ -37,13 +33,8 @@ func CreateZhipuProvider(c *gin.Context) *ZhipuProvider { // 获取请求头 func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) - + p.CommonRequestHeaders(headers) headers["Authorization"] = p.getZhipuToken() - headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") - headers["Accept"] = p.Context.Request.Header.Get("Accept") - if headers["Content-Type"] == "" { - headers["Content-Type"] = "application/json" - } return headers } diff --git a/providers/zhipu_chat.go b/providers/zhipu/chat.go similarity index 77% rename from providers/zhipu_chat.go rename to providers/zhipu/chat.go index 4e7f1711..c4d24509 100644 --- a/providers/zhipu_chat.go +++ b/providers/zhipu/chat.go @@ -1,4 +1,4 @@ -package providers +package zhipu import ( "bufio" @@ -6,46 +6,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/providers/base" "one-api/types" "strings" ) -type ZhipuMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ZhipuRequest struct { - Prompt []ZhipuMessage `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - RequestId string `json:"request_id,omitempty"` - Incremental bool `json:"incremental,omitempty"` -} - -type ZhipuResponseData struct { - TaskId string `json:"task_id"` - RequestId string `json:"request_id"` - TaskStatus string `json:"task_status"` - Choices []ZhipuMessage `json:"choices"` - types.Usage `json:"usage"` -} - -type ZhipuResponse struct { - Code int `json:"code"` - Msg string `json:"msg"` - Success bool `json:"success"` - Data ZhipuResponseData `json:"data"` -} - -type ZhipuStreamMetaResponse struct { - RequestId string `json:"request_id"` - TaskId string `json:"task_id"` - TaskStatus string `json:"task_status"` - types.Usage `json:"usage"` -} - -func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if !zhipuResponse.Success { return &types.OpenAIErrorWithStatusCode{ OpenAIError: types.OpenAIError{ @@ -110,7 +76,7 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) } } -func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { +func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) headers := p.GetRequestHeaders() @@ -128,15 +94,15 @@ func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionReques } if request.Stream { - openAIErrorWithStatusCode, usage = p.sendStreamRequest(req) - if openAIErrorWithStatusCode != nil { + errWithCode, usage = p.sendStreamRequest(req) + if errWithCode != nil { return } } else { zhipuResponse := &ZhipuResponse{} - openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse) - if openAIErrorWithStatusCode != nil { + errWithCode = p.SendRequest(req, zhipuResponse) + if errWithCode != nil { return } @@ -161,7 +127,7 @@ func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types. func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) { var choice types.ChatCompletionStreamChoice choice.Delta.Content = "" - choice.FinishReason = &stopFinishReason + choice.FinishReason = &base.StopFinishReason response := types.ChatCompletionStreamResponse{ ID: zhipuResponse.RequestId, Object: "chat.completion.chunk", @@ -180,7 +146,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError } if common.IsFailureStatusCode(resp) { - return p.handleErrorResp(resp), nil + return p.HandleErrorResp(resp), nil } defer resp.Body.Close() @@ -222,7 +188,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError } stopChan <- true }() - setEventStreamHeaders(p.Context) + common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/providers/zhipu/type.go b/providers/zhipu/type.go new file mode 100644 index 00000000..5a5942e7 --- /dev/null +++ b/providers/zhipu/type.go @@ -0,0 +1,46 @@ +package zhipu + +import ( + "one-api/types" + "time" +) + +type ZhipuMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ZhipuRequest struct { + Prompt []ZhipuMessage `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ZhipuResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []ZhipuMessage `json:"choices"` + types.Usage `json:"usage"` +} + +type ZhipuResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ZhipuResponseData `json:"data"` +} + +type ZhipuStreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + types.Usage `json:"usage"` +} + +type zhipuTokenData struct { + Token string + ExpiryTime time.Time +}