🎨 Adjust the provider directory structure and merge the text output functions

This commit is contained in:
Martial BE 2023-11-29 16:07:09 +08:00
parent 902c2faa2c
commit 544f20cc73
51 changed files with 1062 additions and 1146 deletions

View File

@ -6,6 +6,8 @@ import (
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
var HttpClient *http.Client
@ -124,3 +126,11 @@ func DecodeString(body io.Reader, output *string) error {
*output = string(b)
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
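
Usage note (not part of this commit; route and payload are illustrative): a minimal, self-contained sketch of how a streaming handler would call the helper that just moved into common.SetEventStreamHeaders before writing SSE chunks with gin. The helper is copied locally here so the sketch compiles on its own.

package main

import (
    "fmt"
    "io"

    "github.com/gin-gonic/gin"
)

// Local copy of the helper this commit moves into common.SetEventStreamHeaders.
func setEventStreamHeaders(c *gin.Context) {
    c.Writer.Header().Set("Content-Type", "text/event-stream")
    c.Writer.Header().Set("Cache-Control", "no-cache")
    c.Writer.Header().Set("Connection", "keep-alive")
    c.Writer.Header().Set("Transfer-Encoding", "chunked")
    c.Writer.Header().Set("X-Accel-Buffering", "no")
}

func main() {
    r := gin.Default()
    r.GET("/demo-stream", func(c *gin.Context) {
        setEventStreamHeaders(c)
        chunks := []string{"hello", "world", "[DONE]"}
        i := 0
        c.Stream(func(w io.Writer) bool {
            if i >= len(chunks) {
                return false // stop streaming once all chunks are written
            }
            fmt.Fprintf(w, "data: %s\n\n", chunks[i])
            i++
            return true
        })
    })
    r.Run(":8080")
}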

View File

@ -70,7 +70,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
}
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
_, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&request, isModelMapped, promptTokens)
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
if openAIErrorWithStatusCode != nil {
return nil, &openAIErrorWithStatusCode.OpenAIError
}

View File

@ -2,6 +2,7 @@ package controller
import (
"fmt"
"one-api/types"
"github.com/gin-gonic/gin"
)
@ -541,7 +542,7 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
openAIError := types.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",

View File

@ -1,127 +0,0 @@
package controller
import (
"context"
"errors"
"net/http"
"one-api/common"
"one-api/model"
"one-api/providers"
"one-api/types"
"github.com/gin-gonic/gin"
)
func relayChatHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
// Get request parameters
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
// consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
// Get the provider
chatProvider := GetChatProvider(channelType, c)
if chatProvider == nil {
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
}
// Get the request body
var chatRequest types.ChatCompletionRequest
err := common.UnmarshalBodyReusable(c, &chatRequest)
if err != nil {
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// Check model mapping
isModelMapped := false
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
if err != nil {
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap != nil && modelMap[chatRequest.Model] != "" {
chatRequest.Model = modelMap[chatRequest.Model]
isModelMapped = true
}
// Start counting tokens
var promptTokens int
promptTokens = common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
// Calculate the pre-paid quota
quotaInfo := &QuotaInfo{
modelName: chatRequest.Model,
promptTokens: promptTokens,
userId: userId,
channelId: channelId,
tokenId: tokenId,
}
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return quota_err
}
usage, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&chatRequest, isModelMapped, promptTokens)
if openAIErrorWithStatusCode != nil {
if quotaInfo.preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return openAIErrorWithStatusCode
}
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
go func() {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}()
}(c.Request.Context())
return nil
}
func GetChatProvider(channelType int, c *gin.Context) providers.ChatProviderAction {
switch channelType {
case common.ChannelTypeOpenAI:
return providers.CreateOpenAIProvider(c, "")
case common.ChannelTypeAzure:
return providers.CreateAzureProvider(c)
case common.ChannelTypeAli:
return providers.CreateAliAIProvider(c)
case common.ChannelTypeTencent:
return providers.CreateTencentProvider(c)
case common.ChannelTypeBaidu:
return providers.CreateBaiduProvider(c)
case common.ChannelTypeAnthropic:
return providers.CreateClaudeProvider(c)
case common.ChannelTypePaLM:
return providers.CreatePalmProvider(c)
case common.ChannelTypeZhipu:
return providers.CreateZhipuProvider(c)
case common.ChannelTypeXunfei:
return providers.CreateXunfeiProvider(c)
}
baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
if baseURL != "" {
return providers.CreateOpenAIProvider(c, baseURL)
}
return nil
}

View File

@ -1,113 +0,0 @@
package controller
import (
"context"
"errors"
"net/http"
"one-api/common"
"one-api/model"
"one-api/providers"
"one-api/types"
"github.com/gin-gonic/gin"
)
func relayCompletionHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
// Get request parameters
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
// consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
// Get the provider
completionProvider := GetCompletionProvider(channelType, c)
if completionProvider == nil {
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
}
// Get the request body
var completionRequest types.CompletionRequest
err := common.UnmarshalBodyReusable(c, &completionRequest)
if err != nil {
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// Check model mapping
isModelMapped := false
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
if err != nil {
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap != nil && modelMap[completionRequest.Model] != "" {
completionRequest.Model = modelMap[completionRequest.Model]
isModelMapped = true
}
// Start counting tokens
var promptTokens int
promptTokens = common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
// Calculate the pre-paid quota
quotaInfo := &QuotaInfo{
modelName: completionRequest.Model,
promptTokens: promptTokens,
userId: userId,
channelId: channelId,
tokenId: tokenId,
}
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return quota_err
}
usage, openAIErrorWithStatusCode := completionProvider.CompleteResponse(&completionRequest, isModelMapped, promptTokens)
if openAIErrorWithStatusCode != nil {
if quotaInfo.preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return openAIErrorWithStatusCode
}
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
go func() {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}()
}(c.Request.Context())
return nil
}
func GetCompletionProvider(channelType int, c *gin.Context) providers.CompletionProviderAction {
switch channelType {
case common.ChannelTypeOpenAI:
return providers.CreateOpenAIProvider(c, "")
case common.ChannelTypeAzure:
return providers.CreateAzureProvider(c)
}
baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
if baseURL != "" {
return providers.CreateOpenAIProvider(c, baseURL)
}
return nil
}

View File

@ -1,117 +0,0 @@
package controller
import (
"context"
"errors"
"net/http"
"one-api/common"
"one-api/model"
"one-api/providers"
"one-api/types"
"github.com/gin-gonic/gin"
)
func relayEmbeddingsHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
// Get request parameters
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
// consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
// Get the provider
embeddingsProvider := GetEmbeddingsProvider(channelType, c)
if embeddingsProvider == nil {
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
}
// Get the request body
var embeddingsRequest types.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
if err != nil {
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// Check model mapping
isModelMapped := false
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
if err != nil {
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
isModelMapped = true
}
// Start counting tokens
var promptTokens int
promptTokens = common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
// Calculate the pre-paid quota
quotaInfo := &QuotaInfo{
modelName: embeddingsRequest.Model,
promptTokens: promptTokens,
userId: userId,
channelId: channelId,
tokenId: tokenId,
}
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return quota_err
}
usage, openAIErrorWithStatusCode := embeddingsProvider.EmbeddingsResponse(&embeddingsRequest, isModelMapped, promptTokens)
if openAIErrorWithStatusCode != nil {
if quotaInfo.preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return openAIErrorWithStatusCode
}
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
go func() {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}()
}(c.Request.Context())
return nil
}
func GetEmbeddingsProvider(channelType int, c *gin.Context) providers.EmbeddingsProviderAction {
switch channelType {
case common.ChannelTypeOpenAI:
return providers.CreateOpenAIProvider(c, "")
case common.ChannelTypeAzure:
return providers.CreateAzureProvider(c)
case common.ChannelTypeAli:
return providers.CreateAliAIProvider(c)
case common.ChannelTypeBaidu:
return providers.CreateBaiduProvider(c)
}
baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
if baseURL != "" {
return providers.CreateOpenAIProvider(c, baseURL)
}
return nil
}

160
controller/relay-text.go Normal file
View File

@ -0,0 +1,160 @@
package controller
import (
"context"
"errors"
"net/http"
"one-api/common"
"one-api/model"
"one-api/providers"
providers_base "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode {
// Get request parameters
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
// Get the provider
provider := providers.GetProvider(channelType, c)
if provider == nil {
return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
if err != nil {
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
var promptTokens int
quotaInfo := &QuotaInfo{
modelName: "",
promptTokens: promptTokens,
userId: userId,
channelId: channelId,
tokenId: tokenId,
}
var usage *types.Usage
var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeChatCompletions:
usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
case RelayModeCompletions:
usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
case RelayModeEmbeddings:
usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
default:
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
}
if openAIErrorWithStatusCode != nil {
if quotaInfo.preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return openAIErrorWithStatusCode
}
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
go func() {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}()
}(c.Request.Context())
return nil
}
func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var chatRequest types.ChatCompletionRequest
isModelMapped := false
chatProvider, ok := provider.(providers_base.ChatInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &chatRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if modelMap != nil && modelMap[chatRequest.Model] != "" {
chatRequest.Model = modelMap[chatRequest.Model]
isModelMapped = true
}
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
quotaInfo.modelName = chatRequest.Model
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
return chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
}
func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var completionRequest types.CompletionRequest
isModelMapped := false
completionProvider, ok := provider.(providers_base.CompletionInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &completionRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if modelMap != nil && modelMap[completionRequest.Model] != "" {
completionRequest.Model = modelMap[completionRequest.Model]
isModelMapped = true
}
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
quotaInfo.modelName = completionRequest.Model
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
return completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
}
func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var embeddingsRequest types.EmbeddingRequest
isModelMapped := false
embeddingsProvider, ok := provider.(providers_base.EmbeddingsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
isModelMapped = true
}
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
quotaInfo.modelName = embeddingsRequest.Model
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
}
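
For orientation, a self-contained sketch of the control flow relayTextHelper now centralizes: pre-consume quota, run the per-mode handler (handleChatCompletions, handleCompletions, handleEmbeddings), refund on error, settle usage on success. The fakeQuota type and its methods below are illustrative stand-ins, not one-api types.

package main

import (
    "errors"
    "fmt"
)

// fakeQuota stands in for QuotaInfo; only the lifecycle matters here.
type fakeQuota struct{ preConsumed int }

func (q *fakeQuota) preConsume(tokens int) { q.preConsumed = tokens }
func (q *fakeQuota) refund()               { fmt.Println("refunded", q.preConsumed) }
func (q *fakeQuota) settle(total int)      { fmt.Println("settled", total) }

// relayText mirrors the shape of relayTextHelper: one generic wrapper around a per-mode handler.
func relayText(q *fakeQuota, handler func() (int, error)) error {
    q.preConsume(100) // preQuotaConsumption in the real helper
    usage, err := handler()
    if err != nil {
        q.refund() // the real helper returns the pre-consumed quota in a goroutine
        return err
    }
    q.settle(usage) // completedQuotaConsumption runs deferred/asynchronously in the real helper
    return nil
}

func main() {
    q := &fakeQuota{}
    _ = relayText(q, func() (int, error) { return 120, nil })
    _ = relayText(q, func() (int, error) { return 0, errors.New("upstream failed") })
}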

View File

@ -237,19 +237,19 @@ type CompletionsStreamResponse struct {
func Relay(c *gin.Context) {
var err *types.OpenAIErrorWithStatusCode
// relayMode := RelayModeUnknown
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
err = relayChatHelper(c)
// relayMode = RelayModeChatCompletions
// err = relayChatHelper(c)
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
err = relayCompletionHelper(c)
// relayMode = RelayModeCompletions
// err = relayCompletionHelper(c)
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
err = relayEmbeddingsHelper(c)
// err = relayEmbeddingsHelper(c)
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
}
// relayMode = RelayModeEmbeddings
// } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
// relayMode = RelayModeEmbeddings
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
// relayMode = RelayModeModerations
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
@ -263,7 +263,7 @@ func Relay(c *gin.Context) {
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
// relayMode = RelayModeAudioTranslation
// }
// switch relayMode {
switch relayMode {
// case RelayModeImagesGenerations:
// err = relayImageHelper(c, relayMode)
// case RelayModeAudioSpeech:
@ -272,9 +272,9 @@ func Relay(c *gin.Context) {
// fallthrough
// case RelayModeAudioTranscription:
// err = relayAudioHelper(c, relayMode)
// default:
// err = relayTextHelper(c, relayMode)
// }
default:
err = relayTextHelper(c, relayMode)
}
if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")

35
providers/ali/base.go Normal file
View File

@ -0,0 +1,35 @@
package ali
import (
"fmt"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
type AliProvider struct {
base.BaseProvider
}
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
// Create AliAIProvider
func CreateAliAIProvider(c *gin.Context) *AliProvider {
return &AliProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
Context: c,
},
}
}
// Get request headers
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
return headers
}

View File

@ -1,4 +1,4 @@
package providers
package ali
import (
"bufio"
@ -10,43 +10,10 @@ import (
"strings"
)
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Handle the Aliyun response
func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
return nil, &types.OpenAIErrorWithStatusCode{
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
@ -55,6 +22,8 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
},
StatusCode: resp.StatusCode,
}
return
}
choice := types.ChatCompletionChoice{
@ -66,7 +35,7 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
FinishReason: aliResponse.Output.FinishReason,
}
fullTextResponse := types.ChatCompletionResponse{
OpenAIResponse = types.ChatCompletionResponse{
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
@ -78,10 +47,11 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
},
}
return fullTextResponse, nil
return
}
func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
// Get the chat request body
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
@ -113,7 +83,8 @@ func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest)
}
}
func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Chat
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
@ -130,8 +101,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
}
if request.Stream {
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
usage, errWithCode = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
@ -145,8 +116,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
} else {
aliResponse := &AliChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, aliResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, aliResponse)
if errWithCode != nil {
return
}
@ -159,7 +130,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
return
}
func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
// Convert the Aliyun response to an OpenAI response
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
@ -177,16 +149,17 @@ func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *
return &response
}
func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
// Send a streaming request
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), nil
return nil, p.HandleErrorResp(resp)
}
defer resp.Body.Close()
@ -220,7 +193,7 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
lastResponseText := ""
p.Context.Stream(func(w io.Writer) bool {
select {
@ -252,5 +225,5 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
}
})
return nil, usage
return
}

View File

@ -1,4 +1,4 @@
package providers
package ali
import (
"net/http"
@ -6,30 +6,8 @@ import (
"one-api/types"
)
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Handle the embeddings response
func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -60,7 +38,8 @@ func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (Op
return openAIEmbeddingResponse, nil
}
func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
// Get the embeddings request body
func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
@ -71,7 +50,7 @@ func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest
}
}
func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
@ -84,8 +63,8 @@ func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo
}
aliEmbeddingResponse := &AliEmbeddingResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, aliEmbeddingResponse)
if errWithCode != nil {
return
}
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}

70
providers/ali/type.go Normal file
View File

@ -0,0 +1,70 @@
package ali
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}

View File

@ -1,50 +0,0 @@
package providers
import (
"fmt"
"github.com/gin-gonic/gin"
)
type AliAIProvider struct {
ProviderConfig
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
// Create AliAIProvider
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
func CreateAliAIProvider(c *gin.Context) *AliAIProvider {
return &AliAIProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
Context: c,
},
}
}
// Get request headers
func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}

18
providers/api2d/base.go Normal file
View File

@ -0,0 +1,18 @@
package api2d
import (
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Api2dProvider struct {
*openai.OpenAIProvider
}
// Create Api2dProvider
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
return &Api2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
}
}
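
The api2d (and closeai) providers now reuse the OpenAI provider purely through struct embedding, overriding nothing but the base URL passed to the constructor. A small sketch of that pattern with stand-in types (not one-api's):

package main

import "fmt"

// openAILike stands in for openai.OpenAIProvider.
type openAILike struct{ BaseURL string }

func (p *openAILike) Endpoint(path string) string { return p.BaseURL + path }

// api2dLike mirrors Api2dProvider: embed the OpenAI provider and inherit every method.
type api2dLike struct{ *openAILike }

func newApi2dLike() *api2dLike {
    return &api2dLike{openAILike: &openAILike{BaseURL: "https://oa.api2d.net"}}
}

func main() {
    p := newApi2dLike()
    fmt.Println(p.Endpoint("/v1/chat/completions")) // inherited from the embedded provider
}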

View File

@ -1,14 +0,0 @@
package providers
import "github.com/gin-gonic/gin"
type Api2dProvider struct {
*OpenAIProvider
}
// Create OpenAIProvider
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
return &Api2dProvider{
OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"),
}
}

31
providers/azure/base.go Normal file
View File

@ -0,0 +1,31 @@
package azure
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type AzureProvider struct {
openai.OpenAIProvider
}
// Create AzureProvider
func CreateAzureProvider(c *gin.Context) *AzureProvider {
return &AzureProvider{
OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{
BaseURL: "",
Completions: "/completions",
ChatCompletions: "/chat/completions",
Embeddings: "/embeddings",
AudioSpeech: "/audio/speech",
AudioTranscriptions: "/audio/transcriptions",
AudioTranslations: "/audio/translations",
Context: c,
},
IsAzure: true,
},
}
}
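
With IsAzure set, the embedded OpenAIProvider.GetFullRequestURL (see the providers/openai/base.go hunk below) builds deployment-style URLs from the plain paths configured above. A sketch of the resulting URL shape; resource name, deployment, and api-version are illustrative values:

package main

import "fmt"

// azureURL reproduces the IsAzure branch of OpenAIProvider.GetFullRequestURL.
func azureURL(baseURL, model, requestURL, apiVersion string) string {
    return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", baseURL, model, requestURL, apiVersion)
}

func main() {
    // Prints: https://my-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2023-07-01-preview
    fmt.Println(azureURL("https://my-resource.openai.azure.com", "my-deployment", "/chat/completions", "2023-07-01-preview"))
}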

View File

@ -1,41 +0,0 @@
package providers
import (
"github.com/gin-gonic/gin"
)
type AzureProvider struct {
OpenAIProvider
}
// Create OpenAIProvider
func CreateAzureProvider(c *gin.Context) *AzureProvider {
return &AzureProvider{
OpenAIProvider: OpenAIProvider{
ProviderConfig: ProviderConfig{
BaseURL: "",
Completions: "/completions",
ChatCompletions: "/chat/completions",
Embeddings: "/embeddings",
AudioSpeech: "/audio/speech",
AudioTranscriptions: "/audio/transcriptions",
AudioTranslations: "/audio/translations",
Context: c,
},
isAzure: true,
},
}
}
// // Get the full request URL
// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string {
// apiVersion := p.Context.GetString("api_version")
// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion)
// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
// }
// return fmt.Sprintf("%s%s", baseURL, requestURL)
// }

View File

@ -1,10 +1,11 @@
package providers
package baidu
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
"one-api/providers/base"
"strings"
"sync"
"time"
@ -15,20 +16,12 @@ import (
var baiduTokenStore sync.Map
type BaiduProvider struct {
ProviderConfig
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
base.BaseProvider
}
func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
return &BaiduProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: "https://aip.baidubce.com",
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
@ -59,12 +52,7 @@ func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) s
// Get request headers
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
p.CommonRequestHeaders(headers)
return headers
}

View File

@ -1,4 +1,4 @@
package providers
package baidu
import (
"bufio"
@ -6,33 +6,12 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/types"
"strings"
)
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"`
BaiduError
}
func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if baiduResponse.ErrorMsg != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -54,7 +33,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope
FinishReason: "stop",
}
fullTextResponse := types.ChatCompletionResponse{
OpenAIResponse = types.ChatCompletionResponse{
ID: baiduResponse.Id,
Object: "chat.completion",
Created: baiduResponse.Created,
@ -62,18 +41,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope
Usage: baiduResponse.Usage,
}
return fullTextResponse, nil
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
return
}
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
@ -101,7 +69,7 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
}
}
func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
if fullRequestURL == "" {
@ -120,15 +88,15 @@ func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionReques
}
if request.Stream {
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
usage, errWithCode = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
} else {
baiduChatRequest := &BaiduChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, baiduChatRequest)
if errWithCode != nil {
return
}
@ -142,7 +110,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason
choice.FinishReason = &base.StopFinishReason
}
response := types.ChatCompletionStreamResponse{
@ -155,16 +123,16 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
return &response
}
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), nil
return nil, p.HandleErrorResp(resp)
}
defer resp.Body.Close()
@ -195,7 +163,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@ -224,5 +192,5 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
}
})
return nil, usage
return usage, nil
}

View File

@ -1,4 +1,4 @@
package providers
package baidu
import (
"net/http"
@ -6,32 +6,13 @@ import (
"one-api/types"
)
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage types.Usage `json:"usage"`
BaiduError
}
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
}
func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if baiduResponse.ErrorMsg != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -62,7 +43,7 @@ func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response)
return openAIEmbeddingResponse, nil
}
func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
@ -78,8 +59,8 @@ func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo
}
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, baiduEmbeddingResponse)
if errWithCode != nil {
return
}
usage = &baiduEmbeddingResponse.Usage

66
providers/baidu/type.go Normal file
View File

@ -0,0 +1,66 @@
package baidu
import (
"one-api/types"
"time"
)
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"`
BaiduError
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage types.Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}

View File

@ -1,4 +1,4 @@
package providers
package base
import (
"encoding/json"
@ -6,7 +6,6 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"strconv"
"strings"
@ -14,9 +13,9 @@ import (
"github.com/gin-gonic/gin"
)
var stopFinishReason = "stop"
var StopFinishReason = "stop"
type ProviderConfig struct {
type BaseProvider struct {
BaseURL string
Completions string
ChatCompletions string
@ -28,32 +27,8 @@ type ProviderConfig struct {
Context *gin.Context
}
type BaseProviderAction interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
}
type CompletionProviderAction interface {
BaseProviderAction
CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type ChatProviderAction interface {
BaseProviderAction
ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type EmbeddingsProviderAction interface {
BaseProviderAction
EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type BalanceProviderAction interface {
Balance(channel *model.Channel) (float64, error)
}
func (p *ProviderConfig) GetBaseURL() string {
// Get the base URL
func (p *BaseProvider) GetBaseURL() string {
if p.Context.GetString("base_url") != "" {
return p.Context.GetString("base_url")
}
@ -61,21 +36,66 @@ func (p *ProviderConfig) GetBaseURL() string {
return p.BaseURL
}
func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string {
// Get the full request URL
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s", baseURL, requestURL)
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// Set common request headers
func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
}
func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Send the request
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
// Handle the response
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp)
}
// Parse the response
err = common.DecodeResponse(resp.Body, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if err != nil {
return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
}
// Handle an error response
func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
@ -105,46 +125,3 @@ func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithSt
}
return
}
// Provider response handler
type ProviderResponseHandler interface {
// Request handler
requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
// Send the request
func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
// Handle the response
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp)
}
// Parse the response
err = common.DecodeResponse(resp.Body, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
return nil
}
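
A stand-alone sketch of the contract the renamed SendRequest relies on: the upstream body is decoded into the provider's response struct, and that struct's ResponseHandler converts it into an OpenAI-shaped payload or an error. Types here are stand-ins, not one-api's, and the real method also receives *http.Response and returns an OpenAI error wrapper.

package main

import (
    "encoding/json"
    "fmt"
)

// upstreamChatResponse stands in for a provider response type such as AliChatResponse.
type upstreamChatResponse struct {
    Text string `json:"text"`
    Code string `json:"code"`
}

// ResponseHandler mirrors the base.ProviderResponseHandler contract in simplified form.
func (r *upstreamChatResponse) ResponseHandler() (any, error) {
    if r.Code != "" {
        return nil, fmt.Errorf("upstream error: %s", r.Code)
    }
    return map[string]any{"object": "chat.completion", "content": r.Text}, nil
}

func main() {
    // SendRequest decodes the upstream body into the struct, calls ResponseHandler,
    // and writes the converted JSON back to the client; here we just print it.
    body := []byte(`{"text":"hi there"}`)
    var r upstreamChatResponse
    if err := json.Unmarshal(body, &r); err != nil {
        panic(err)
    }
    converted, err := r.ResponseHandler()
    fmt.Println(converted, err)
}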

View File

@ -0,0 +1,42 @@
package base
import (
"net/http"
"one-api/model"
"one-api/types"
)
// Base interface
type ProviderInterface interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
}
// Completion interface
type CompletionInterface interface {
ProviderInterface
CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// Chat interface
type ChatInterface interface {
ProviderInterface
ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// Embeddings interface
type EmbeddingsInterface interface {
ProviderInterface
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// Balance interface
type BalanceInterface interface {
BalanceAction(channel *model.Channel) (float64, error)
}
type ProviderResponseHandler interface {
// Response handler
ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
}
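
A minimal sketch of how this new interface split is consumed: a hypothetical provider implements the base interface plus the chat interface, and the caller discovers chat support via a type assertion, just as handleChatCompletions does. All types below are trimmed stand-ins, not the real one-api declarations.

package main

import "fmt"

type chatRequest struct{ Model string }
type usage struct{ TotalTokens int }

// Trimmed stand-ins for the interfaces in providers/base/interface.go.
type providerInterface interface {
    GetBaseURL() string
}

type chatInterface interface {
    providerInterface
    ChatAction(req *chatRequest) (*usage, error)
}

// demoProvider is a hypothetical provider that supports chat.
type demoProvider struct{ baseURL string }

func (p *demoProvider) GetBaseURL() string { return p.baseURL }
func (p *demoProvider) ChatAction(req *chatRequest) (*usage, error) {
    return &usage{TotalTokens: 42}, nil
}

func main() {
    var provider providerInterface = &demoProvider{baseURL: "https://example.invalid"}
    // handleChatCompletions does the same discovery via a type assertion.
    if chat, ok := provider.(chatInterface); ok {
        u, _ := chat.ChatAction(&chatRequest{Model: "demo-model"})
        fmt.Println("tokens:", u.TotalTokens)
    } else {
        fmt.Println("channel not implemented")
    }
}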

View File

@ -1,21 +1,18 @@
package providers
package claude
import (
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
type ClaudeProvider struct {
ProviderConfig
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
base.BaseProvider
}
func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
return &ClaudeProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: "https://api.anthropic.com",
ChatCompletions: "/v1/complete",
Context: c,
@ -26,14 +23,9 @@ func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
// Get request headers
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["x-api-key"] = p.Context.GetString("api_key")
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"

View File

@ -1,4 +1,4 @@
package providers
package claude
import (
"bufio"
@ -11,31 +11,7 @@ import (
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
}
func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if claudeResponse.Error.Type != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -101,7 +77,7 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest
return &claudeRequest
}
func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
@ -117,8 +93,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque
if request.Stream {
var responseText string
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
errWithCode, responseText = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
@ -132,8 +108,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque
PromptTokens: promptTokens,
},
}
openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, claudeResponse)
if errWithCode != nil {
return
}
@ -165,7 +141,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
return p.HandleErrorResp(resp), ""
}
defer resp.Body.Close()
@ -199,7 +175,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

32
providers/claude/type.go Normal file
View File

@ -0,0 +1,32 @@
package claude
import "one-api/types"
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
}

View File

@ -1,31 +1,11 @@
package providers
package closeai
import (
"fmt"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
type CloseaiProxyProvider struct {
*OpenAIProvider
}
type OpenAICreditGrants struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalAvailable float64 `json:"total_available"`
}
// Create CloseaiProxyProvider
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
return &CloseaiProxyProvider{
OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
}
}
func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)

18
providers/closeai/base.go Normal file
View File

@ -0,0 +1,18 @@
package closeai
import (
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type CloseaiProxyProvider struct {
*openai.OpenAIProvider
}
// Create CloseaiProxyProvider
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
return &CloseaiProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
}
}

View File

@ -0,0 +1,8 @@
package closeai
type OpenAICreditGrants struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalAvailable float64 `json:"total_available"`
}

View File

@ -1,4 +1,4 @@
package providers
package openai
import (
"bufio"
@ -11,32 +11,25 @@ import (
"one-api/types"
"strings"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
type OpenAIProvider struct {
ProviderConfig
isAzure bool
}
type OpenAIProviderResponseHandler interface {
// Request handler
requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type OpenAIProviderStreamResponseHandler interface {
// Request stream handler
requestStreamHandler() (responseText string)
base.BaseProvider
IsAzure bool
}
// Create OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
if baseURL == "" {
baseURL = "https://api.openai.com"
}
return &OpenAIProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: baseURL,
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
@ -46,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
AudioTranslations: "/v1/audio/translations",
Context: c,
},
isAzure: false,
IsAzure: false,
}
}
@ -54,13 +47,13 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
if p.isAzure {
if p.IsAzure {
apiVersion := p.Context.GetString("api_version")
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
}
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
if p.isAzure {
if p.IsAzure {
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
} else {
requestURL = strings.TrimPrefix(requestURL, "/v1")
@ -73,16 +66,12 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
// Get request headers
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
if p.isAzure {
p.CommonRequestHeaders(headers)
if p.IsAzure {
headers["api-key"] = p.Context.GetString("api_key")
} else {
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
}
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json; charset=utf-8"
}
return headers
}
@ -114,7 +103,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
// Handle the response
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp)
return p.HandleErrorResp(resp)
}
// Create a bytes.Buffer to hold the response body
@ -127,7 +116,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIErrorWithStatusCode = response.requestHandler(resp)
openAIErrorWithStatusCode = response.responseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
@ -145,6 +134,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
return nil
}
// Send a streaming request
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
resp, err := common.HttpClient.Do(req)
@ -153,7 +143,7 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
return p.HandleErrorResp(resp), ""
}
defer resp.Body.Close()
@ -190,12 +180,12 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
responseText += response.requestStreamHandler()
responseText += response.responseStreamHandler()
}
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
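The stream relay loop is cut off by the diff view at this point. For orientation, a generic sketch of how such a gin SSE loop can finish the select over dataChan and stopChan follows; this is an illustration only, not the repository's exact code, and assumes the fmt and io imports already present in the surrounding file:

p.Context.Stream(func(w io.Writer) bool {
	select {
	case data := <-dataChan:
		// forward one chunk to the client as an SSE event
		fmt.Fprintf(w, "data: %s\n\n", data)
		return true
	case <-stopChan:
		// tell the client the stream is finished
		fmt.Fprintf(w, "data: [DONE]\n\n")
		return false
	}
})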


@ -1,4 +1,4 @@
package providers
package openai
import (
"net/http"
@ -6,19 +6,9 @@ import (
"one-api/types"
)
type OpenAIProviderChatResponse struct {
types.ChatCompletionResponse
types.OpenAIErrorResponse
}
type OpenAIProviderChatStreamResponse struct {
types.ChatCompletionStreamResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
@ -27,7 +17,7 @@ func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAI
return nil
}
func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) {
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Delta.Content
}
@ -35,7 +25,7 @@ func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText
return
}
func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
@ -56,8 +46,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque
if request.Stream {
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
var textResponse string
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
if openAIErrorWithStatusCode != nil {
errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
if errWithCode != nil {
return
}
@ -69,8 +59,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque
} else {
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.sendRequest(req, openAIProviderChatResponse)
if errWithCode != nil {
return
}


@ -1,4 +1,4 @@
package providers
package openai
import (
"net/http"
@ -6,14 +6,9 @@ import (
"one-api/types"
)
type OpenAIProviderCompletionResponse struct {
types.CompletionResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
@ -22,7 +17,7 @@ func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (
return nil
}
func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) {
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Text
}
@ -30,7 +25,7 @@ func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText
return
}
func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
@ -52,8 +47,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo
if request.Stream {
// TODO
var textResponse string
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
if openAIErrorWithStatusCode != nil {
errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
if errWithCode != nil {
return
}
@ -64,8 +59,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo
}
} else {
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.sendRequest(req, openAIProviderCompletionResponse)
if errWithCode != nil {
return
}


@ -1,4 +1,4 @@
package providers
package openai
import (
"net/http"
@ -6,14 +6,9 @@ import (
"one-api/types"
)
type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
@ -22,7 +17,7 @@ func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (
return nil
}
func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
@ -39,8 +34,8 @@ func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isM
}
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
if errWithCode != nil {
return
}


@ -0,0 +1,16 @@
package openai
import (
"net/http"
"one-api/types"
)
type OpenAIProviderResponseHandler interface {
// Response handler function
responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode)
}
type OpenAIProviderStreamResponseHandler interface {
// Streaming response handler function
responseStreamHandler() (responseText string)
}
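Any endpoint-specific response type plugs into the shared sendRequest/sendStreamRequest plumbing by satisfying one of these interfaces. A minimal sketch of a hypothetical new handler, modeled on the chat and completion handlers above; the moderation type is an assumption for illustration and is not part of this commit:

package openai

import (
	"net/http"
	"one-api/types"
)

// Hypothetical response wrapper; the embedded error response is what
// responseHandler inspects, exactly as the chat and completion handlers do.
type OpenAIProviderModerationResponse struct {
	types.OpenAIErrorResponse
}

func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
	if c.Error.Type != "" {
		errWithCode = &types.OpenAIErrorWithStatusCode{
			OpenAIError: c.Error,
			StatusCode:  resp.StatusCode,
		}
	}
	return
}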

providers/openai/type.go Normal file

@ -0,0 +1,23 @@
package openai
import "one-api/types"
type OpenAIProviderChatResponse struct {
types.ChatCompletionResponse
types.OpenAIErrorResponse
}
type OpenAIProviderChatStreamResponse struct {
types.ChatCompletionStreamResponse
types.OpenAIErrorResponse
}
type OpenAIProviderCompletionResponse struct {
types.CompletionResponse
types.OpenAIErrorResponse
}
type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
}


@ -1,4 +1,4 @@
package providers
package openaisb
import (
"errors"
@ -6,28 +6,8 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
type OpenaiSBProvider struct {
*OpenAIProvider
}
type OpenAISBUsageResponse struct {
Msg string `json:"msg"`
Data *struct {
Credit string `json:"credit"`
} `json:"data"`
}
// Create an OpenaiSBProvider
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
return &OpenaiSBProvider{
OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"),
}
}
func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)


@ -0,0 +1,18 @@
package openaisb
import (
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type OpenaiSBProvider struct {
*openai.OpenAIProvider
}
// Create an OpenaiSBProvider
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
return &OpenaiSBProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
}
}


@ -0,0 +1,8 @@
package openaisb
type OpenAISBUsageResponse struct {
Msg string `json:"msg"`
Data *struct {
Credit string `json:"credit"`
} `json:"data"`
}
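OpenAI-SB reports the remaining credit as a string, so the balance check has to parse it into a number. A sketch of that conversion under the type above; the helper name is illustrative, the real parsing sits inside the provider's Balance method:

package openaisb

import (
	"errors"
	"strconv"
)

// Illustrative helper: turn the credit string from the status endpoint into a float64 balance.
func parseCredit(response OpenAISBUsageResponse) (float64, error) {
	if response.Data == nil {
		return 0, errors.New("invalid response: " + response.Msg)
	}
	return strconv.ParseFloat(response.Data.Credit, 64)
}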


@ -1,20 +1,21 @@
package providers
package palm
import (
"fmt"
"one-api/providers/base"
"strings"
"github.com/gin-gonic/gin"
)
type PalmProvider struct {
ProviderConfig
base.BaseProvider
}
// Create a PalmProvider
func CreatePalmProvider(c *gin.Context) *PalmProvider {
return &PalmProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
Context: c,
@ -25,12 +26,7 @@ func CreatePalmProvider(c *gin.Context) *PalmProvider {
// Get the request headers
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
p.CommonRequestHeaders(headers)
return headers
}
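The header lines removed here, and in the other providers in this commit, are what BaseProvider.CommonRequestHeaders now centralizes. A sketch of that shared helper, reconstructed from the code it replaces; the actual implementation lives in providers/base and may differ in detail:

// Reconstructed sketch of the shared helper in providers/base.
func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
	headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
	headers["Accept"] = p.Context.Request.Header.Get("Accept")
	if headers["Content-Type"] == "" {
		headers["Content-Type"] = "application/json; charset=utf-8"
	}
}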


@ -1,4 +1,4 @@
package providers
package palm
import (
"encoding/json"
@ -6,47 +6,11 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/types"
)
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []types.ChatCompletionMessage `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
}
func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -107,7 +71,7 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest)
return &palmRequest
}
func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
@ -123,8 +87,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest
if request.Stream {
var responseText string
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
errWithCode, responseText = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
@ -139,8 +103,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest
PromptTokens: promptTokens,
},
}
openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, palmChatResponse)
if errWithCode != nil {
return
}
@ -155,7 +119,7 @@ func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse)
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &stopFinishReason
choice.FinishReason = &base.StopFinishReason
var response types.ChatCompletionStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
@ -171,7 +135,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
return p.HandleErrorResp(resp), ""
}
defer resp.Body.Close()
@ -216,7 +180,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
dataChan <- string(jsonResponse)
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

providers/palm/type.go Normal file

@ -0,0 +1,40 @@
package palm
import "one-api/types"
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []types.ChatCompletionMessage `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
}

providers/providers.go Normal file

@ -0,0 +1,50 @@
package providers
import (
"one-api/common"
"one-api/providers/ali"
"one-api/providers/azure"
"one-api/providers/baidu"
"one-api/providers/base"
"one-api/providers/claude"
"one-api/providers/openai"
"one-api/providers/palm"
"one-api/providers/tencent"
"one-api/providers/xunfei"
"one-api/providers/zhipu"
"github.com/gin-gonic/gin"
)
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
switch channelType {
case common.ChannelTypeOpenAI:
return openai.CreateOpenAIProvider(c, "")
case common.ChannelTypeAzure:
return azure.CreateAzureProvider(c)
case common.ChannelTypeAli:
return ali.CreateAliAIProvider(c)
case common.ChannelTypeTencent:
return tencent.CreateTencentProvider(c)
case common.ChannelTypeBaidu:
return baidu.CreateBaiduProvider(c)
case common.ChannelTypeAnthropic:
return claude.CreateClaudeProvider(c)
case common.ChannelTypePaLM:
return palm.CreatePalmProvider(c)
case common.ChannelTypeZhipu:
return zhipu.CreateZhipuProvider(c)
case common.ChannelTypeXunfei:
return xunfei.CreateXunfeiProvider(c)
default:
baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
if baseURL != "" {
return openai.CreateOpenAIProvider(c, baseURL)
}
return nil
}
}
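GetProvider replaces the per-channel switch that relay handlers used to carry themselves; callers resolve a provider and then type-assert to the capability they need. A usage sketch under the assumption that the base package exposes a ChatInterface declaring ChatAction; the interface and function names here are assumed for illustration:

package controller

import (
	"errors"
	"net/http"
	"one-api/providers"
	"one-api/providers/base"
	"one-api/types"

	"github.com/gin-gonic/gin"
)

// Sketch: resolve a chat-capable provider for the current request.
func getChatProvider(channelType int, c *gin.Context) (base.ChatInterface, *types.OpenAIErrorWithStatusCode) {
	provider := providers.GetProvider(channelType, c)
	if provider == nil {
		return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
	}
	chatProvider, ok := provider.(base.ChatInterface)
	if !ok {
		return nil, types.ErrorWrapper(errors.New("chat not supported by this channel"), "chat_not_supported", http.StatusNotImplemented)
	}
	return chatProvider, nil
}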


@ -1,4 +1,4 @@
package providers
package tencent
import (
"crypto/hmac"
@ -6,6 +6,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"one-api/providers/base"
"sort"
"strconv"
"strings"
@ -14,18 +15,13 @@ import (
)
type TencentProvider struct {
ProviderConfig
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
base.BaseProvider
}
// Create a TencentProvider
func CreateTencentProvider(c *gin.Context) *TencentProvider {
return &TencentProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: "https://hunyuan.cloud.tencent.com",
ChatCompletions: "/hyllm/v1/chat/completions",
Context: c,
@ -36,12 +32,7 @@ func CreateTencentProvider(c *gin.Context) *TencentProvider {
// Get the request headers
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
p.CommonRequestHeaders(headers)
return headers
}


@ -1,4 +1,4 @@
package providers
package tencent
import (
"bufio"
@ -7,64 +7,12 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/types"
"strings"
)
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
SecretId string `json:"secret_id"` // SecretId from the Tencent Cloud console
// Timestamp: the current UNIX timestamp in seconds, recording when the API request was issued.
// For example 1529223702; if it differs too much from the current time, a signature-expired error is raised.
Timestamp int64 `json:"timestamp"`
// Expired: expiry of the signature, a UNIX Epoch timestamp in seconds.
// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days.
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` // request id, used for troubleshooting
// Temperature: higher values make the output more random, lower values make it more focused and deterministic.
// Default 1.0, range [0.0,2.0]; not recommended unless necessary, unreasonable values degrade the results.
// Set only one of this parameter and top_p; do not change both at the same time.
Temperature float64 `json:"temperature"`
// TopP: controls the diversity of the output text; larger values produce more diverse text.
// Default 1.0, range [0.0, 1.0]; not recommended unless necessary, unreasonable values degrade the results.
// Set only one of this parameter and temperature; do not change both at the same time.
TopP float64 `json:"top_p"`
// Stream: 0 for synchronous, 1 for streaming (SSE protocol by default).
// Synchronous requests time out after 60s; streaming is recommended for long content.
Stream int `json:"stream"`
// Messages: conversation content, at most 40 entries, ordered in the array from oldest to newest.
// The total input content supports at most 3000 tokens.
Messages []TencentMessage `json:"messages"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // stream end flag; "stop" marks the final packet
Messages TencentMessage `json:"messages,omitempty"` // content returned in synchronous mode, null in streaming mode; output supports at most 1024 tokens
Delta TencentMessage `json:"delta,omitempty"` // content returned in streaming mode, null in synchronous mode; output supports at most 1024 tokens
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // results
Created string `json:"created,omitempty"` // unix timestamp as a string
Id string `json:"id,omitempty"` // conversation id
Usage *types.Usage `json:"usage,omitempty"` // token usage
Error TencentError `json:"error,omitempty"` // error information; note: this field may be null, meaning no valid value was obtained
Note string `json:"note,omitempty"` // remark
ReqID string `json:"req_id,omitempty"` // unique request id, returned with every request; used when reporting problems with the request
}
func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if TencentResponse.Error.Code != 0 {
return &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -130,7 +78,7 @@ func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionReques
}
}
func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
sign := p.getTencentSign(*requestBody)
if sign == "" {
@ -152,8 +100,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ
if request.Stream {
var responseText string
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
errWithCode, responseText = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
@ -163,8 +111,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ
} else {
tencentResponse := &TencentChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, tencentResponse)
if errWithCode != nil {
return
}
@ -184,7 +132,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
choice.FinishReason = &base.StopFinishReason
}
response.Choices = append(response.Choices, choice)
}
@ -199,7 +147,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
return p.HandleErrorResp(resp), ""
}
defer resp.Body.Close()
@ -234,7 +182,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

providers/tencent/type.go Normal file

@ -0,0 +1,61 @@
package tencent
import "one-api/types"
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
SecretId string `json:"secret_id"` // SecretId from the Tencent Cloud console
// Timestamp: the current UNIX timestamp in seconds, recording when the API request was issued.
// For example 1529223702; if it differs too much from the current time, a signature-expired error is raised.
Timestamp int64 `json:"timestamp"`
// Expired: expiry of the signature, a UNIX Epoch timestamp in seconds.
// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days.
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` // request id, used for troubleshooting
// Temperature: higher values make the output more random, lower values make it more focused and deterministic.
// Default 1.0, range [0.0,2.0]; not recommended unless necessary, unreasonable values degrade the results.
// Set only one of this parameter and top_p; do not change both at the same time.
Temperature float64 `json:"temperature"`
// TopP: controls the diversity of the output text; larger values produce more diverse text.
// Default 1.0, range [0.0, 1.0]; not recommended unless necessary, unreasonable values degrade the results.
// Set only one of this parameter and temperature; do not change both at the same time.
TopP float64 `json:"top_p"`
// Stream: 0 for synchronous, 1 for streaming (SSE protocol by default).
// Synchronous requests time out after 60s; streaming is recommended for long content.
Stream int `json:"stream"`
// Messages: conversation content, at most 40 entries, ordered in the array from oldest to newest.
// The total input content supports at most 3000 tokens.
Messages []TencentMessage `json:"messages"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // stream end flag; "stop" marks the final packet
Messages TencentMessage `json:"messages,omitempty"` // content returned in synchronous mode, null in streaming mode; output supports at most 1024 tokens
Delta TencentMessage `json:"delta,omitempty"` // content returned in streaming mode, null in synchronous mode; output supports at most 1024 tokens
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // results
Created string `json:"created,omitempty"` // unix timestamp as a string
Id string `json:"id,omitempty"` // conversation id
Usage *types.Usage `json:"usage,omitempty"` // token usage
Error TencentError `json:"error,omitempty"` // error information; note: this field may be null, meaning no valid value was obtained
Note string `json:"note,omitempty"` // remark
ReqID string `json:"req_id,omitempty"` // unique request id, returned with every request; used when reporting problems with the request
}


@ -1,4 +1,4 @@
package providers
package xunfei
import (
"crypto/hmac"
@ -7,6 +7,7 @@ import (
"fmt"
"net/url"
"one-api/common"
"one-api/providers/base"
"strings"
"time"
@ -15,7 +16,7 @@ import (
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiProvider struct {
ProviderConfig
base.BaseProvider
domain string
apiId string
}
@ -23,7 +24,7 @@ type XunfeiProvider struct {
// Create a XunfeiProvider
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
return &XunfeiProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: "wss://spark-api.xf-yun.com",
ChatCompletions: "",
Context: c,


@ -1,73 +1,18 @@
package providers
package xunfei
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/types"
"time"
"github.com/gorilla/websocket"
)
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text types.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
if request.Stream {
@ -77,7 +22,7 @@ func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionReque
}
}
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
@ -113,13 +58,13 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU
return usage, nil
}
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
}
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@ -185,7 +130,7 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: stopFinishReason,
FinishReason: base.StopFinishReason,
}
fullTextResponse := types.ChatCompletionResponse{
Object: "chat.completion",
@ -251,7 +196,7 @@ func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatR
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &stopFinishReason
choice.FinishReason = &base.StopFinishReason
}
response := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",

providers/xunfei/type.go Normal file

@ -0,0 +1,59 @@
package xunfei
import "one-api/types"
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text types.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}


@ -1,8 +1,9 @@
package providers
package zhipu
import (
"fmt"
"one-api/common"
"one-api/providers/base"
"strings"
"sync"
"time"
@ -15,18 +16,13 @@ var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
type ZhipuProvider struct {
ProviderConfig
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
base.BaseProvider
}
// Create a ZhipuProvider
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
return &ZhipuProvider{
ProviderConfig: ProviderConfig{
BaseProvider: base.BaseProvider{
BaseURL: "https://open.bigmodel.cn",
ChatCompletions: "/api/paas/v3/model-api",
Context: c,
@ -37,13 +33,8 @@ func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
// Get the request headers
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = p.getZhipuToken()
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}


@ -1,4 +1,4 @@
package providers
package zhipu
import (
"bufio"
@ -6,46 +6,12 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/types"
"strings"
)
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
types.Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
types.Usage `json:"usage"`
}
func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if !zhipuResponse.Success {
return &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
@ -110,7 +76,7 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest)
}
}
func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
@ -128,15 +94,15 @@ func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionReques
}
if request.Stream {
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
errWithCode, usage = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
} else {
zhipuResponse := &ZhipuResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse)
if openAIErrorWithStatusCode != nil {
errWithCode = p.SendRequest(req, zhipuResponse)
if errWithCode != nil {
return
}
@ -161,7 +127,7 @@ func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = ""
choice.FinishReason = &stopFinishReason
choice.FinishReason = &base.StopFinishReason
response := types.ChatCompletionStreamResponse{
ID: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
@ -180,7 +146,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), nil
return p.HandleErrorResp(resp), nil
}
defer resp.Body.Close()
@ -222,7 +188,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

providers/zhipu/type.go Normal file

@ -0,0 +1,46 @@
package zhipu
import (
"one-api/types"
"time"
)
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
types.Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
types.Usage `json:"usage"`
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
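zhipuTokenData pairs a signed token with its expiry, which is what the zhipuTokens sync.Map and the 24-hour expSeconds in base.go cache. A hypothetical sketch of that lookup; the real getZhipuToken also derives and signs the JWT from the API key, which is omitted here:

// Hypothetical cache-lookup sketch; producing the signed token itself is omitted.
func cachedZhipuToken(apikey string, sign func(string) string) string {
	if data, ok := zhipuTokens.Load(apikey); ok {
		tokenData := data.(zhipuTokenData)
		if time.Now().Before(tokenData.ExpiryTime) {
			return tokenData.Token
		}
	}
	token := sign(apikey)
	zhipuTokens.Store(apikey, zhipuTokenData{
		Token:      token,
		ExpiryTime: time.Now().Add(time.Duration(expSeconds) * time.Second),
	})
	return token
}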