🎨 调整供应商目录结构,合并文本输出函数
This commit is contained in:
parent
902c2faa2c
commit
544f20cc73
@ -6,6 +6,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var HttpClient *http.Client
|
||||
@ -124,3 +126,11 @@ func DecodeString(body io.Reader, output *string) error {
|
||||
*output = string(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
}
|
||||
|
||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&request, isModelMapped, promptTokens)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return nil, &openAIErrorWithStatusCode.OpenAIError
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@ -541,7 +542,7 @@ func RetrieveModel(c *gin.Context) {
|
||||
if model, ok := openAIModelsMap[modelId]; ok {
|
||||
c.JSON(200, model)
|
||||
} else {
|
||||
openAIError := OpenAIError{
|
||||
openAIError := types.OpenAIError{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
Type: "invalid_request_error",
|
||||
Param: "model",
|
||||
|
@ -1,127 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func relayChatHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
|
||||
|
||||
// 获取请求参数
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
// consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
|
||||
// 获取 Provider
|
||||
chatProvider := GetChatProvider(channelType, c)
|
||||
if chatProvider == nil {
|
||||
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// 获取请求体
|
||||
var chatRequest types.ChatCompletionRequest
|
||||
err := common.UnmarshalBodyReusable(c, &chatRequest)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// 检查模型映射
|
||||
isModelMapped := false
|
||||
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
||||
chatRequest.Model = modelMap[chatRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 开始计算Tokens
|
||||
var promptTokens int
|
||||
promptTokens = common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
|
||||
|
||||
// 计算预付费配额
|
||||
quotaInfo := &QuotaInfo{
|
||||
modelName: chatRequest.Model,
|
||||
promptTokens: promptTokens,
|
||||
userId: userId,
|
||||
channelId: channelId,
|
||||
tokenId: tokenId,
|
||||
}
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return quota_err
|
||||
}
|
||||
|
||||
usage, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&chatRequest, isModelMapped, promptTokens)
|
||||
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
if quotaInfo.preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return openAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetChatProvider(channelType int, c *gin.Context) providers.ChatProviderAction {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
return providers.CreateOpenAIProvider(c, "")
|
||||
case common.ChannelTypeAzure:
|
||||
return providers.CreateAzureProvider(c)
|
||||
case common.ChannelTypeAli:
|
||||
return providers.CreateAliAIProvider(c)
|
||||
case common.ChannelTypeTencent:
|
||||
return providers.CreateTencentProvider(c)
|
||||
case common.ChannelTypeBaidu:
|
||||
return providers.CreateBaiduProvider(c)
|
||||
case common.ChannelTypeAnthropic:
|
||||
return providers.CreateClaudeProvider(c)
|
||||
case common.ChannelTypePaLM:
|
||||
return providers.CreatePalmProvider(c)
|
||||
case common.ChannelTypeZhipu:
|
||||
return providers.CreateZhipuProvider(c)
|
||||
case common.ChannelTypeXunfei:
|
||||
return providers.CreateXunfeiProvider(c)
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
if baseURL != "" {
|
||||
return providers.CreateOpenAIProvider(c, baseURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,113 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func relayCompletionHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
|
||||
|
||||
// 获取请求参数
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
// consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
|
||||
// 获取 Provider
|
||||
completionProvider := GetCompletionProvider(channelType, c)
|
||||
if completionProvider == nil {
|
||||
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// 获取请求体
|
||||
var completionRequest types.CompletionRequest
|
||||
err := common.UnmarshalBodyReusable(c, &completionRequest)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// 检查模型映射
|
||||
isModelMapped := false
|
||||
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
||||
completionRequest.Model = modelMap[completionRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 开始计算Tokens
|
||||
var promptTokens int
|
||||
promptTokens = common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
|
||||
|
||||
// 计算预付费配额
|
||||
quotaInfo := &QuotaInfo{
|
||||
modelName: completionRequest.Model,
|
||||
promptTokens: promptTokens,
|
||||
userId: userId,
|
||||
channelId: channelId,
|
||||
tokenId: tokenId,
|
||||
}
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return quota_err
|
||||
}
|
||||
|
||||
usage, openAIErrorWithStatusCode := completionProvider.CompleteResponse(&completionRequest, isModelMapped, promptTokens)
|
||||
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
if quotaInfo.preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return openAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetCompletionProvider(channelType int, c *gin.Context) providers.CompletionProviderAction {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
return providers.CreateOpenAIProvider(c, "")
|
||||
case common.ChannelTypeAzure:
|
||||
return providers.CreateAzureProvider(c)
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
if baseURL != "" {
|
||||
return providers.CreateOpenAIProvider(c, baseURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,117 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func relayEmbeddingsHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
|
||||
|
||||
// 获取请求参数
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
// consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
|
||||
// 获取 Provider
|
||||
embeddingsProvider := GetEmbeddingsProvider(channelType, c)
|
||||
if embeddingsProvider == nil {
|
||||
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// 获取请求体
|
||||
var embeddingsRequest types.EmbeddingRequest
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// 检查模型映射
|
||||
isModelMapped := false
|
||||
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
||||
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 开始计算Tokens
|
||||
var promptTokens int
|
||||
promptTokens = common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
|
||||
|
||||
// 计算预付费配额
|
||||
quotaInfo := &QuotaInfo{
|
||||
modelName: embeddingsRequest.Model,
|
||||
promptTokens: promptTokens,
|
||||
userId: userId,
|
||||
channelId: channelId,
|
||||
tokenId: tokenId,
|
||||
}
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return quota_err
|
||||
}
|
||||
|
||||
usage, openAIErrorWithStatusCode := embeddingsProvider.EmbeddingsResponse(&embeddingsRequest, isModelMapped, promptTokens)
|
||||
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
if quotaInfo.preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return openAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetEmbeddingsProvider(channelType int, c *gin.Context) providers.EmbeddingsProviderAction {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
return providers.CreateOpenAIProvider(c, "")
|
||||
case common.ChannelTypeAzure:
|
||||
return providers.CreateAzureProvider(c)
|
||||
case common.ChannelTypeAli:
|
||||
return providers.CreateAliAIProvider(c)
|
||||
case common.ChannelTypeBaidu:
|
||||
return providers.CreateBaiduProvider(c)
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
if baseURL != "" {
|
||||
return providers.CreateOpenAIProvider(c, baseURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
160
controller/relay-text.go
Normal file
160
controller/relay-text.go
Normal file
@ -0,0 +1,160 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
providers_base "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode {
|
||||
// 获取请求参数
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
|
||||
// 获取 Provider
|
||||
provider := providers.GetProvider(channelType, c)
|
||||
if provider == nil {
|
||||
return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var promptTokens int
|
||||
quotaInfo := &QuotaInfo{
|
||||
modelName: "",
|
||||
promptTokens: promptTokens,
|
||||
userId: userId,
|
||||
channelId: channelId,
|
||||
tokenId: tokenId,
|
||||
}
|
||||
|
||||
var usage *types.Usage
|
||||
var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
|
||||
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
|
||||
case RelayModeCompletions:
|
||||
usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
|
||||
case RelayModeEmbeddings:
|
||||
usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
|
||||
default:
|
||||
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
if quotaInfo.preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return openAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var chatRequest types.ChatCompletionRequest
|
||||
isModelMapped := false
|
||||
chatProvider, ok := provider.(providers_base.ChatInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &chatRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
||||
chatRequest.Model = modelMap[chatRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
|
||||
|
||||
quotaInfo.modelName = chatRequest.Model
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return nil, quota_err
|
||||
}
|
||||
return chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
||||
func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var completionRequest types.CompletionRequest
|
||||
isModelMapped := false
|
||||
completionProvider, ok := provider.(providers_base.CompletionInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &completionRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
||||
completionRequest.Model = modelMap[completionRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
|
||||
|
||||
quotaInfo.modelName = completionRequest.Model
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return nil, quota_err
|
||||
}
|
||||
return completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
||||
func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var embeddingsRequest types.EmbeddingRequest
|
||||
isModelMapped := false
|
||||
embeddingsProvider, ok := provider.(providers_base.EmbeddingsInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
||||
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
|
||||
|
||||
quotaInfo.modelName = embeddingsRequest.Model
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return nil, quota_err
|
||||
}
|
||||
return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
|
||||
}
|
@ -237,19 +237,19 @@ type CompletionsStreamResponse struct {
|
||||
func Relay(c *gin.Context) {
|
||||
var err *types.OpenAIErrorWithStatusCode
|
||||
|
||||
// relayMode := RelayModeUnknown
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
err = relayChatHelper(c)
|
||||
// relayMode = RelayModeChatCompletions
|
||||
// err = relayChatHelper(c)
|
||||
relayMode = RelayModeChatCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||
err = relayCompletionHelper(c)
|
||||
// relayMode = RelayModeCompletions
|
||||
// err = relayCompletionHelper(c)
|
||||
relayMode = RelayModeCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||
err = relayEmbeddingsHelper(c)
|
||||
// err = relayEmbeddingsHelper(c)
|
||||
relayMode = RelayModeEmbeddings
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
}
|
||||
// relayMode = RelayModeEmbeddings
|
||||
// } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
// relayMode = RelayModeEmbeddings
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
// relayMode = RelayModeModerations
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
@ -263,7 +263,7 @@ func Relay(c *gin.Context) {
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
// relayMode = RelayModeAudioTranslation
|
||||
// }
|
||||
// switch relayMode {
|
||||
switch relayMode {
|
||||
// case RelayModeImagesGenerations:
|
||||
// err = relayImageHelper(c, relayMode)
|
||||
// case RelayModeAudioSpeech:
|
||||
@ -272,9 +272,9 @@ func Relay(c *gin.Context) {
|
||||
// fallthrough
|
||||
// case RelayModeAudioTranscription:
|
||||
// err = relayAudioHelper(c, relayMode)
|
||||
// default:
|
||||
// err = relayTextHelper(c, relayMode)
|
||||
// }
|
||||
default:
|
||||
err = relayTextHelper(c, relayMode)
|
||||
}
|
||||
if err != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
retryTimesStr := c.Query("retry")
|
||||
|
35
providers/ali/base.go
Normal file
35
providers/ali/base.go
Normal file
@ -0,0 +1,35 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AliProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||
// 创建 AliAIProvider
|
||||
func CreateAliAIProvider(c *gin.Context) *AliProvider {
|
||||
return &AliProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
|
||||
return headers
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -10,43 +10,10 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
// 阿里云响应处理
|
||||
func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
@ -55,6 +22,8 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
@ -66,7 +35,7 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
|
||||
FinishReason: aliResponse.Output.FinishReason,
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
@ -78,10 +47,11 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
|
||||
},
|
||||
}
|
||||
|
||||
return fullTextResponse, nil
|
||||
return
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
// 获取聊天请求体
|
||||
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
@ -113,7 +83,8 @@ func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
// 聊天
|
||||
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
@ -130,8 +101,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
usage, errWithCode = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -145,8 +116,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
|
||||
|
||||
} else {
|
||||
aliResponse := &AliChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, aliResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, aliResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -159,7 +130,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
|
||||
return
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||
// 阿里云响应转OpenAI响应
|
||||
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
@ -177,16 +149,17 @@ func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
|
||||
// 发送流请求
|
||||
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), nil
|
||||
return nil, p.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -220,7 +193,7 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
lastResponseText := ""
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
@ -252,5 +225,5 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
|
||||
}
|
||||
})
|
||||
|
||||
return nil, usage
|
||||
return
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package ali
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@ -6,30 +6,8 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
// 嵌入请求处理
|
||||
func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -60,7 +38,8 @@ func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (Op
|
||||
return openAIEmbeddingResponse, nil
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
// 获取嵌入请求体
|
||||
func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
@ -71,7 +50,7 @@ func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
@ -84,8 +63,8 @@ func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo
|
||||
}
|
||||
|
||||
aliEmbeddingResponse := &AliEmbeddingResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, aliEmbeddingResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
|
70
providers/ali/type.go
Normal file
70
providers/ali/type.go
Normal file
@ -0,0 +1,70 @@
|
||||
package ali
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AliAIProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// 创建 AliAIProvider
|
||||
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||
func CreateAliAIProvider(c *gin.Context) *AliAIProvider {
|
||||
return &AliAIProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
18
providers/api2d/base.go
Normal file
18
providers/api2d/base.go
Normal file
@ -0,0 +1,18 @@
|
||||
package api2d
|
||||
|
||||
import (
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Api2dProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 Api2dProvider
|
||||
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
|
||||
return &Api2dProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
|
||||
}
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
package providers
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
type Api2dProvider struct {
|
||||
*OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
|
||||
return &Api2dProvider{
|
||||
OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"),
|
||||
}
|
||||
}
|
31
providers/azure/base.go
Normal file
31
providers/azure/base.go
Normal file
@ -0,0 +1,31 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AzureProvider struct {
|
||||
openai.OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func CreateAzureProvider(c *gin.Context) *AzureProvider {
|
||||
return &AzureProvider{
|
||||
OpenAIProvider: openai.OpenAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "",
|
||||
Completions: "/completions",
|
||||
ChatCompletions: "/chat/completions",
|
||||
Embeddings: "/embeddings",
|
||||
AudioSpeech: "/audio/speech",
|
||||
AudioTranscriptions: "/audio/transcriptions",
|
||||
AudioTranslations: "/audio/translations",
|
||||
Context: c,
|
||||
},
|
||||
IsAzure: true,
|
||||
},
|
||||
}
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AzureProvider struct {
|
||||
OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func CreateAzureProvider(c *gin.Context) *AzureProvider {
|
||||
return &AzureProvider{
|
||||
OpenAIProvider: OpenAIProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "",
|
||||
Completions: "/completions",
|
||||
ChatCompletions: "/chat/completions",
|
||||
Embeddings: "/embeddings",
|
||||
AudioSpeech: "/audio/speech",
|
||||
AudioTranscriptions: "/audio/transcriptions",
|
||||
AudioTranslations: "/audio/translations",
|
||||
Context: c,
|
||||
},
|
||||
isAzure: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// // 获取完整请求 URL
|
||||
// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
// apiVersion := p.Context.GetString("api_version")
|
||||
// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||
// }
|
||||
|
||||
// return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
// }
|
@ -1,10 +1,11 @@
|
||||
package providers
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -15,20 +16,12 @@ import (
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
type BaiduProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
|
||||
return &BaiduProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://aip.baidubce.com",
|
||||
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
|
||||
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
|
||||
@ -59,12 +52,7 @@ func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) s
|
||||
// 获取请求头
|
||||
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
p.CommonRequestHeaders(headers)
|
||||
|
||||
return headers
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -6,33 +6,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BaiduMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage *types.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -54,7 +33,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope
|
||||
FinishReason: "stop",
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
ID: baiduResponse.Id,
|
||||
Object: "chat.completion",
|
||||
Created: baiduResponse.Created,
|
||||
@ -62,18 +41,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope
|
||||
Usage: baiduResponse.Usage,
|
||||
}
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduError struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
return
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
|
||||
@ -101,7 +69,7 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
@ -120,15 +88,15 @@ func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionReques
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
usage, errWithCode = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
baiduChatRequest := &BaiduChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, baiduChatRequest)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -142,7 +110,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
@ -155,16 +123,16 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
|
||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), nil
|
||||
return nil, p.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -195,7 +163,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@ -224,5 +192,5 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
|
||||
}
|
||||
})
|
||||
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@ -6,32 +6,13 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage types.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
|
||||
return &BaiduEmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -62,7 +43,7 @@ func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response)
|
||||
return openAIEmbeddingResponse, nil
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
@ -78,8 +59,8 @@ func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo
|
||||
}
|
||||
|
||||
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, baiduEmbeddingResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
usage = &baiduEmbeddingResponse.Usage
|
66
providers/baidu/type.go
Normal file
66
providers/baidu/type.go
Normal file
@ -0,0 +1,66 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"one-api/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
type BaiduMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage *types.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage types.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduError struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package base
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -14,9 +13,9 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var stopFinishReason = "stop"
|
||||
var StopFinishReason = "stop"
|
||||
|
||||
type ProviderConfig struct {
|
||||
type BaseProvider struct {
|
||||
BaseURL string
|
||||
Completions string
|
||||
ChatCompletions string
|
||||
@ -28,32 +27,8 @@ type ProviderConfig struct {
|
||||
Context *gin.Context
|
||||
}
|
||||
|
||||
type BaseProviderAction interface {
|
||||
GetBaseURL() string
|
||||
GetFullRequestURL(requestURL string, modelName string) string
|
||||
GetRequestHeaders() (headers map[string]string)
|
||||
}
|
||||
|
||||
type CompletionProviderAction interface {
|
||||
BaseProviderAction
|
||||
CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type ChatProviderAction interface {
|
||||
BaseProviderAction
|
||||
ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type EmbeddingsProviderAction interface {
|
||||
BaseProviderAction
|
||||
EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type BalanceProviderAction interface {
|
||||
Balance(channel *model.Channel) (float64, error)
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetBaseURL() string {
|
||||
// 获取基础URL
|
||||
func (p *BaseProvider) GetBaseURL() string {
|
||||
if p.Context.GetString("base_url") != "" {
|
||||
return p.Context.GetString("base_url")
|
||||
}
|
||||
@ -61,21 +36,66 @@ func (p *ProviderConfig) GetBaseURL() string {
|
||||
return p.BaseURL
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
// 获取完整请求URL
|
||||
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
}
|
||||
|
||||
func setEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
// 获取请求头
|
||||
func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
err = common.DecodeResponse(resp.Body, response)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(openAIResponse)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = p.Context.Writer.Write(jsonResponse)
|
||||
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -105,46 +125,3 @@ func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithSt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 供应商响应处理函数
|
||||
type ProviderResponseHandler interface {
|
||||
// 请求处理函数
|
||||
requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
err = common.DecodeResponse(resp.Body, response)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(openAIResponse)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = p.Context.Writer.Write(jsonResponse)
|
||||
return nil
|
||||
}
|
42
providers/base/interface.go
Normal file
42
providers/base/interface.go
Normal file
@ -0,0 +1,42 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
// 基础接口
|
||||
type ProviderInterface interface {
|
||||
GetBaseURL() string
|
||||
GetFullRequestURL(requestURL string, modelName string) string
|
||||
GetRequestHeaders() (headers map[string]string)
|
||||
}
|
||||
|
||||
// 完成接口
|
||||
type CompletionInterface interface {
|
||||
ProviderInterface
|
||||
CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 聊天接口
|
||||
type ChatInterface interface {
|
||||
ProviderInterface
|
||||
ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 嵌入接口
|
||||
type EmbeddingsInterface interface {
|
||||
ProviderInterface
|
||||
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
BalanceAction(channel *model.Channel) (float64, error)
|
||||
}
|
||||
|
||||
type ProviderResponseHandler interface {
|
||||
// 响应处理函数
|
||||
ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
@ -1,21 +1,18 @@
|
||||
package providers
|
||||
package claude
|
||||
|
||||
import (
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ClaudeProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
ChatCompletions: "/v1/complete",
|
||||
Context: c,
|
||||
@ -26,14 +23,9 @@ func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
|
||||
// 获取请求头
|
||||
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
|
||||
headers["x-api-key"] = p.Context.GetString("api_key")
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -11,31 +11,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -101,7 +77,7 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
@ -117,8 +93,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode, responseText = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -132,8 +108,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque
|
||||
PromptTokens: promptTokens,
|
||||
},
|
||||
}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, claudeResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -165,7 +141,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
return p.HandleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -199,7 +175,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
32
providers/claude/type.go
Normal file
32
providers/claude/type.go
Normal file
@ -0,0 +1,32 @@
|
||||
package claude
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
}
|
@ -1,31 +1,11 @@
|
||||
package providers
|
||||
package closeai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CloseaiProxyProvider struct {
|
||||
*OpenAIProvider
|
||||
}
|
||||
|
||||
type OpenAICreditGrants struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
}
|
||||
|
||||
// 创建 CloseaiProxyProvider
|
||||
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
|
||||
return &CloseaiProxyProvider{
|
||||
OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
|
||||
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
18
providers/closeai/base.go
Normal file
18
providers/closeai/base.go
Normal file
@ -0,0 +1,18 @@
|
||||
package closeai
|
||||
|
||||
import (
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CloseaiProxyProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 CloseaiProxyProvider
|
||||
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
|
||||
return &CloseaiProxyProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
|
||||
}
|
||||
}
|
8
providers/closeai/type.go
Normal file
8
providers/closeai/type.go
Normal file
@ -0,0 +1,8 @@
|
||||
package closeai
|
||||
|
||||
type OpenAICreditGrants struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -11,32 +11,25 @@ import (
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
ProviderConfig
|
||||
isAzure bool
|
||||
}
|
||||
|
||||
type OpenAIProviderResponseHandler interface {
|
||||
// 请求处理函数
|
||||
requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type OpenAIProviderStreamResponseHandler interface {
|
||||
// 请求流处理函数
|
||||
requestStreamHandler() (responseText string)
|
||||
base.BaseProvider
|
||||
IsAzure bool
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
|
||||
return &OpenAIProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: baseURL,
|
||||
Completions: "/v1/completions",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
@ -46,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
Context: c,
|
||||
},
|
||||
isAzure: false,
|
||||
IsAzure: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,13 +47,13 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
if p.isAzure {
|
||||
if p.IsAzure {
|
||||
apiVersion := p.Context.GetString("api_version")
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
if p.isAzure {
|
||||
if p.IsAzure {
|
||||
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||
} else {
|
||||
requestURL = strings.TrimPrefix(requestURL, "/v1")
|
||||
@ -73,16 +66,12 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
||||
// 获取请求头
|
||||
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
if p.isAzure {
|
||||
p.CommonRequestHeaders(headers)
|
||||
if p.IsAzure {
|
||||
headers["api-key"] = p.Context.GetString("api_key")
|
||||
} else {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
}
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
@ -114,7 +103,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp)
|
||||
return p.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 创建一个 bytes.Buffer 来存储响应体
|
||||
@ -127,7 +116,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
|
||||
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIErrorWithStatusCode = response.requestHandler(resp)
|
||||
openAIErrorWithStatusCode = response.responseHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
@ -145,6 +134,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
|
||||
return nil
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
||||
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
@ -153,7 +143,7 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
return p.HandleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -190,12 +180,12 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue // just ignore the error
|
||||
}
|
||||
responseText += response.requestStreamHandler()
|
||||
responseText += response.responseStreamHandler()
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@ -6,19 +6,9 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderChatResponse struct {
|
||||
types.ChatCompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
type OpenAIProviderChatStreamResponse struct {
|
||||
types.ChatCompletionStreamResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
@ -27,7 +17,7 @@ func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAI
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) {
|
||||
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
}
|
||||
@ -35,7 +25,7 @@ func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
@ -56,8 +46,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque
|
||||
if request.Stream {
|
||||
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
||||
var textResponse string
|
||||
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -69,8 +59,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque
|
||||
|
||||
} else {
|
||||
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.sendRequest(req, openAIProviderChatResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@ -6,14 +6,9 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderCompletionResponse struct {
|
||||
types.CompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
@ -22,7 +17,7 @@ func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) {
|
||||
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Text
|
||||
}
|
||||
@ -30,7 +25,7 @@ func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
@ -52,8 +47,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo
|
||||
if request.Stream {
|
||||
// TODO
|
||||
var textResponse string
|
||||
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -64,8 +59,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo
|
||||
}
|
||||
|
||||
} else {
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.sendRequest(req, openAIProviderCompletionResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@ -6,14 +6,9 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderEmbeddingsResponse struct {
|
||||
types.EmbeddingResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
@ -22,7 +17,7 @@ func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
@ -39,8 +34,8 @@ func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isM
|
||||
}
|
||||
|
||||
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
16
providers/openai/interface.go
Normal file
16
providers/openai/interface.go
Normal file
@ -0,0 +1,16 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderResponseHandler interface {
|
||||
// 请求处理函数
|
||||
responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type OpenAIProviderStreamResponseHandler interface {
|
||||
// 请求流处理函数
|
||||
responseStreamHandler() (responseText string)
|
||||
}
|
23
providers/openai/type.go
Normal file
23
providers/openai/type.go
Normal file
@ -0,0 +1,23 @@
|
||||
package openai
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type OpenAIProviderChatResponse struct {
|
||||
types.ChatCompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
type OpenAIProviderChatStreamResponse struct {
|
||||
types.ChatCompletionStreamResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
type OpenAIProviderCompletionResponse struct {
|
||||
types.CompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
type OpenAIProviderEmbeddingsResponse struct {
|
||||
types.EmbeddingResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package openaisb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -6,28 +6,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenaiSBProvider struct {
|
||||
*OpenAIProvider
|
||||
}
|
||||
|
||||
type OpenAISBUsageResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
Credit string `json:"credit"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// 创建 OpenaiSBProvider
|
||||
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
|
||||
return &OpenaiSBProvider{
|
||||
OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
|
||||
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
18
providers/openaisb/base.go
Normal file
18
providers/openaisb/base.go
Normal file
@ -0,0 +1,18 @@
|
||||
package openaisb
|
||||
|
||||
import (
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenaiSBProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenaiSBProvider
|
||||
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
|
||||
return &OpenaiSBProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
|
||||
}
|
||||
}
|
8
providers/openaisb/type.go
Normal file
8
providers/openaisb/type.go
Normal file
@ -0,0 +1,8 @@
|
||||
package openaisb
|
||||
|
||||
type OpenAISBUsageResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
Credit string `json:"credit"`
|
||||
} `json:"data"`
|
||||
}
|
@ -1,20 +1,21 @@
|
||||
package providers
|
||||
package palm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/providers/base"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PalmProvider struct {
|
||||
ProviderConfig
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
// 创建 PalmProvider
|
||||
func CreatePalmProvider(c *gin.Context) *PalmProvider {
|
||||
return &PalmProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
|
||||
Context: c,
|
||||
@ -25,12 +26,7 @@ func CreatePalmProvider(c *gin.Context) *PalmProvider {
|
||||
// 获取请求头
|
||||
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
p.CommonRequestHeaders(headers)
|
||||
|
||||
return headers
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -6,47 +6,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []types.ChatCompletionMessage `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -107,7 +71,7 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
@ -123,8 +87,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode, responseText = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -139,8 +103,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest
|
||||
PromptTokens: promptTokens,
|
||||
},
|
||||
}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, palmChatResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -155,7 +119,7 @@ func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse)
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
var response types.ChatCompletionStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
@ -171,7 +135,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
return p.HandleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -216,7 +180,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
40
providers/palm/type.go
Normal file
40
providers/palm/type.go
Normal file
@ -0,0 +1,40 @@
|
||||
package palm
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []types.ChatCompletionMessage `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
50
providers/providers.go
Normal file
50
providers/providers.go
Normal file
@ -0,0 +1,50 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/providers/ali"
|
||||
"one-api/providers/azure"
|
||||
"one-api/providers/baidu"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/claude"
|
||||
"one-api/providers/openai"
|
||||
"one-api/providers/palm"
|
||||
"one-api/providers/tencent"
|
||||
"one-api/providers/xunfei"
|
||||
"one-api/providers/zhipu"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
return openai.CreateOpenAIProvider(c, "")
|
||||
case common.ChannelTypeAzure:
|
||||
return azure.CreateAzureProvider(c)
|
||||
case common.ChannelTypeAli:
|
||||
return ali.CreateAliAIProvider(c)
|
||||
case common.ChannelTypeTencent:
|
||||
return tencent.CreateTencentProvider(c)
|
||||
case common.ChannelTypeBaidu:
|
||||
return baidu.CreateBaiduProvider(c)
|
||||
case common.ChannelTypeAnthropic:
|
||||
return claude.CreateClaudeProvider(c)
|
||||
case common.ChannelTypePaLM:
|
||||
return palm.CreatePalmProvider(c)
|
||||
case common.ChannelTypeZhipu:
|
||||
return zhipu.CreateZhipuProvider(c)
|
||||
case common.ChannelTypeXunfei:
|
||||
return xunfei.CreateXunfeiProvider(c)
|
||||
default:
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
if baseURL != "" {
|
||||
return openai.CreateOpenAIProvider(c, baseURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/providers/base"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -14,18 +15,13 @@ import (
|
||||
)
|
||||
|
||||
type TencentProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
// 创建 TencentProvider
|
||||
func CreateTencentProvider(c *gin.Context) *TencentProvider {
|
||||
return &TencentProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://hunyuan.cloud.tencent.com",
|
||||
ChatCompletions: "/hyllm/v1/chat/completions",
|
||||
Context: c,
|
||||
@ -36,12 +32,7 @@ func CreateTencentProvider(c *gin.Context) *TencentProvider {
|
||||
// 获取请求头
|
||||
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
p.CommonRequestHeaders(headers)
|
||||
|
||||
return headers
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -7,64 +7,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage *types.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
|
||||
func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -130,7 +78,7 @@ func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionReques
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
sign := p.getTencentSign(*requestBody)
|
||||
if sign == "" {
|
||||
@ -152,8 +100,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode, responseText = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -163,8 +111,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ
|
||||
|
||||
} else {
|
||||
tencentResponse := &TencentChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, tencentResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -184,7 +132,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
}
|
||||
@ -199,7 +147,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
return p.HandleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -234,7 +182,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
61
providers/tencent/type.go
Normal file
61
providers/tencent/type.go
Normal file
@ -0,0 +1,61 @@
|
||||
package tencent
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage *types.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -15,7 +16,7 @@ import (
|
||||
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
type XunfeiProvider struct {
|
||||
ProviderConfig
|
||||
base.BaseProvider
|
||||
domain string
|
||||
apiId string
|
||||
}
|
||||
@ -23,7 +24,7 @@ type XunfeiProvider struct {
|
||||
// 创建 XunfeiProvider
|
||||
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
|
||||
return &XunfeiProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "wss://spark-api.xf-yun.com",
|
||||
ChatCompletions: "",
|
||||
Context: c,
|
@ -1,73 +1,18 @@
|
||||
package providers
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text types.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
|
||||
if request.Stream {
|
||||
@ -77,7 +22,7 @@ func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionReque
|
||||
}
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||
if err != nil {
|
||||
@ -113,13 +58,13 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
|
||||
}
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
@ -185,7 +130,7 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty
|
||||
Role: "assistant",
|
||||
Content: response.Payload.Choices.Text[0].Content,
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
FinishReason: base.StopFinishReason,
|
||||
}
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
Object: "chat.completion",
|
||||
@ -251,7 +196,7 @@ func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatR
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
}
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
59
providers/xunfei/type.go
Normal file
59
providers/xunfei/type.go
Normal file
@ -0,0 +1,59 @@
|
||||
package xunfei
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text types.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
@ -1,8 +1,9 @@
|
||||
package providers
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -15,18 +16,13 @@ var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
type ZhipuProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type zhipuTokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
// 创建 ZhipuProvider
|
||||
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
|
||||
return &ZhipuProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://open.bigmodel.cn",
|
||||
ChatCompletions: "/api/paas/v3/model-api",
|
||||
Context: c,
|
||||
@ -37,13 +33,8 @@ func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
|
||||
// 获取请求头
|
||||
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Authorization"] = p.getZhipuToken()
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -6,46 +6,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ZhipuMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ZhipuRequest struct {
|
||||
Prompt []ZhipuMessage `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ZhipuResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []ZhipuMessage `json:"choices"`
|
||||
types.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ZhipuResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ZhipuResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type ZhipuStreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
types.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if !zhipuResponse.Success {
|
||||
return &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
@ -110,7 +76,7 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
@ -128,15 +94,15 @@ func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionReques
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode, usage = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
zhipuResponse := &ZhipuResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
errWithCode = p.SendRequest(req, zhipuResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -161,7 +127,7 @@ func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.
|
||||
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
@ -180,7 +146,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), nil
|
||||
return p.HandleErrorResp(resp), nil
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -222,7 +188,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
46
providers/zhipu/type.go
Normal file
46
providers/zhipu/type.go
Normal file
@ -0,0 +1,46 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"one-api/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ZhipuMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ZhipuRequest struct {
|
||||
Prompt []ZhipuMessage `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ZhipuResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []ZhipuMessage `json:"choices"`
|
||||
types.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ZhipuResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ZhipuResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type ZhipuStreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
types.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type zhipuTokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
}
|
Loading…
Reference in New Issue
Block a user