🎨 Restructure the provider directory layout and merge the text output functions

parent 902c2faa2c
commit 544f20cc73
@@ -6,6 +6,8 @@ import (
     "io"
     "net/http"
     "time"
+
+    "github.com/gin-gonic/gin"
 )

 var HttpClient *http.Client
@@ -124,3 +126,11 @@ func DecodeString(body io.Reader, output *string) error {
     *output = string(b)
     return nil
 }
+
+func SetEventStreamHeaders(c *gin.Context) {
+    c.Writer.Header().Set("Content-Type", "text/event-stream")
+    c.Writer.Header().Set("Cache-Control", "no-cache")
+    c.Writer.Header().Set("Connection", "keep-alive")
+    c.Writer.Header().Set("Transfer-Encoding", "chunked")
+    c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
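The SetEventStreamHeaders helper added above replaces the per-provider setEventStreamHeaders copies that the provider files below previously carried. A minimal sketch of how a gin handler would use the merged helper (the handler name and payload are illustrative, not part of this commit):

    package main

    import (
        "io"
        "one-api/common"

        "github.com/gin-gonic/gin"
    )

    func demoStream(c *gin.Context) {
        // SSE headers must be written before the first streamed chunk.
        common.SetEventStreamHeaders(c)
        c.Stream(func(w io.Writer) bool {
            // Each call emits one "data: ..." event to the client.
            c.SSEvent("", "{\"content\": \"hello\"}")
            return false // returning false ends the stream
        })
    }

    func main() {
        r := gin.Default()
        r.GET("/demo", demoStream)
        r.Run(":8080")
    }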
@@ -70,7 +70,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
     }

     promptTokens := common.CountTokenMessages(request.Messages, request.Model)
-    _, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&request, isModelMapped, promptTokens)
+    _, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
     if openAIErrorWithStatusCode != nil {
         return nil, &openAIErrorWithStatusCode.OpenAIError
     }
@@ -2,6 +2,7 @@ package controller

 import (
     "fmt"
+    "one-api/types"

     "github.com/gin-gonic/gin"
 )
@@ -541,7 +542,7 @@ func RetrieveModel(c *gin.Context) {
     if model, ok := openAIModelsMap[modelId]; ok {
         c.JSON(200, model)
     } else {
-        openAIError := OpenAIError{
+        openAIError := types.OpenAIError{
             Message: fmt.Sprintf("The model '%s' does not exist", modelId),
             Type:    "invalid_request_error",
             Param:   "model",
@@ -1,127 +0,0 @@
-package controller
-
-import (
-    "context"
-    "errors"
-    "net/http"
-    "one-api/common"
-    "one-api/model"
-    "one-api/providers"
-    "one-api/types"
-
-    "github.com/gin-gonic/gin"
-)
-
-func relayChatHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
-
-    // Get request parameters
-    channelType := c.GetInt("channel")
-    channelId := c.GetInt("channel_id")
-    tokenId := c.GetInt("token_id")
-    userId := c.GetInt("id")
-    // consumeQuota := c.GetBool("consume_quota")
-    group := c.GetString("group")
-
-    // Get the provider
-    chatProvider := GetChatProvider(channelType, c)
-    if chatProvider == nil {
-        return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
-    }
-
-    // Get the request body
-    var chatRequest types.ChatCompletionRequest
-    err := common.UnmarshalBodyReusable(c, &chatRequest)
-    if err != nil {
-        return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-    }
-
-    // Check the model mapping
-    isModelMapped := false
-    modelMap, err := parseModelMapping(c.GetString("model_mapping"))
-    if err != nil {
-        return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-    }
-    if modelMap != nil && modelMap[chatRequest.Model] != "" {
-        chatRequest.Model = modelMap[chatRequest.Model]
-        isModelMapped = true
-    }
-
-    // Start counting tokens
-    var promptTokens int
-    promptTokens = common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
-
-    // Calculate the pre-consumed quota
-    quotaInfo := &QuotaInfo{
-        modelName:    chatRequest.Model,
-        promptTokens: promptTokens,
-        userId:       userId,
-        channelId:    channelId,
-        tokenId:      tokenId,
-    }
-    quotaInfo.initQuotaInfo(group)
-    quota_err := quotaInfo.preQuotaConsumption()
-    if quota_err != nil {
-        return quota_err
-    }
-
-    usage, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&chatRequest, isModelMapped, promptTokens)
-
-    if openAIErrorWithStatusCode != nil {
-        if quotaInfo.preConsumedQuota != 0 {
-            go func(ctx context.Context) {
-                // return pre-consumed quota
-                err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
-                if err != nil {
-                    common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-                }
-            }(c.Request.Context())
-        }
-        return openAIErrorWithStatusCode
-    }
-
-    tokenName := c.GetString("token_name")
-    defer func(ctx context.Context) {
-        go func() {
-            err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
-            if err != nil {
-                common.LogError(ctx, err.Error())
-            }
-        }()
-    }(c.Request.Context())
-
-    return nil
-}
-
-func GetChatProvider(channelType int, c *gin.Context) providers.ChatProviderAction {
-    switch channelType {
-    case common.ChannelTypeOpenAI:
-        return providers.CreateOpenAIProvider(c, "")
-    case common.ChannelTypeAzure:
-        return providers.CreateAzureProvider(c)
-    case common.ChannelTypeAli:
-        return providers.CreateAliAIProvider(c)
-    case common.ChannelTypeTencent:
-        return providers.CreateTencentProvider(c)
-    case common.ChannelTypeBaidu:
-        return providers.CreateBaiduProvider(c)
-    case common.ChannelTypeAnthropic:
-        return providers.CreateClaudeProvider(c)
-    case common.ChannelTypePaLM:
-        return providers.CreatePalmProvider(c)
-    case common.ChannelTypeZhipu:
-        return providers.CreateZhipuProvider(c)
-    case common.ChannelTypeXunfei:
-        return providers.CreateXunfeiProvider(c)
-    }
-
-    baseURL := common.ChannelBaseURLs[channelType]
-    if c.GetString("base_url") != "" {
-        baseURL = c.GetString("base_url")
-    }
-
-    if baseURL != "" {
-        return providers.CreateOpenAIProvider(c, baseURL)
-    }
-
-    return nil
-}
@@ -1,113 +0,0 @@
-package controller
-
-import (
-    "context"
-    "errors"
-    "net/http"
-    "one-api/common"
-    "one-api/model"
-    "one-api/providers"
-    "one-api/types"
-
-    "github.com/gin-gonic/gin"
-)
-
-func relayCompletionHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
-
-    // Get request parameters
-    channelType := c.GetInt("channel")
-    channelId := c.GetInt("channel_id")
-    tokenId := c.GetInt("token_id")
-    userId := c.GetInt("id")
-    // consumeQuota := c.GetBool("consume_quota")
-    group := c.GetString("group")
-
-    // Get the provider
-    completionProvider := GetCompletionProvider(channelType, c)
-    if completionProvider == nil {
-        return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
-    }
-
-    // Get the request body
-    var completionRequest types.CompletionRequest
-    err := common.UnmarshalBodyReusable(c, &completionRequest)
-    if err != nil {
-        return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-    }
-
-    // Check the model mapping
-    isModelMapped := false
-    modelMap, err := parseModelMapping(c.GetString("model_mapping"))
-    if err != nil {
-        return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-    }
-    if modelMap != nil && modelMap[completionRequest.Model] != "" {
-        completionRequest.Model = modelMap[completionRequest.Model]
-        isModelMapped = true
-    }
-
-    // Start counting tokens
-    var promptTokens int
-    promptTokens = common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
-
-    // Calculate the pre-consumed quota
-    quotaInfo := &QuotaInfo{
-        modelName:    completionRequest.Model,
-        promptTokens: promptTokens,
-        userId:       userId,
-        channelId:    channelId,
-        tokenId:      tokenId,
-    }
-    quotaInfo.initQuotaInfo(group)
-    quota_err := quotaInfo.preQuotaConsumption()
-    if quota_err != nil {
-        return quota_err
-    }
-
-    usage, openAIErrorWithStatusCode := completionProvider.CompleteResponse(&completionRequest, isModelMapped, promptTokens)
-
-    if openAIErrorWithStatusCode != nil {
-        if quotaInfo.preConsumedQuota != 0 {
-            go func(ctx context.Context) {
-                // return pre-consumed quota
-                err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
-                if err != nil {
-                    common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-                }
-            }(c.Request.Context())
-        }
-        return openAIErrorWithStatusCode
-    }
-
-    tokenName := c.GetString("token_name")
-    defer func(ctx context.Context) {
-        go func() {
-            err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
-            if err != nil {
-                common.LogError(ctx, err.Error())
-            }
-        }()
-    }(c.Request.Context())
-
-    return nil
-}
-
-func GetCompletionProvider(channelType int, c *gin.Context) providers.CompletionProviderAction {
-    switch channelType {
-    case common.ChannelTypeOpenAI:
-        return providers.CreateOpenAIProvider(c, "")
-    case common.ChannelTypeAzure:
-        return providers.CreateAzureProvider(c)
-    }
-
-    baseURL := common.ChannelBaseURLs[channelType]
-    if c.GetString("base_url") != "" {
-        baseURL = c.GetString("base_url")
-    }
-
-    if baseURL != "" {
-        return providers.CreateOpenAIProvider(c, baseURL)
-    }
-
-    return nil
-}
@@ -1,117 +0,0 @@
-package controller
-
-import (
-    "context"
-    "errors"
-    "net/http"
-    "one-api/common"
-    "one-api/model"
-    "one-api/providers"
-    "one-api/types"
-
-    "github.com/gin-gonic/gin"
-)
-
-func relayEmbeddingsHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
-
-    // Get request parameters
-    channelType := c.GetInt("channel")
-    channelId := c.GetInt("channel_id")
-    tokenId := c.GetInt("token_id")
-    userId := c.GetInt("id")
-    // consumeQuota := c.GetBool("consume_quota")
-    group := c.GetString("group")
-
-    // Get the provider
-    embeddingsProvider := GetEmbeddingsProvider(channelType, c)
-    if embeddingsProvider == nil {
-        return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
-    }
-
-    // Get the request body
-    var embeddingsRequest types.EmbeddingRequest
-    err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
-    if err != nil {
-        return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-    }
-
-    // Check the model mapping
-    isModelMapped := false
-    modelMap, err := parseModelMapping(c.GetString("model_mapping"))
-    if err != nil {
-        return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-    }
-    if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
-        embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
-        isModelMapped = true
-    }
-
-    // Start counting tokens
-    var promptTokens int
-    promptTokens = common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
-
-    // Calculate the pre-consumed quota
-    quotaInfo := &QuotaInfo{
-        modelName:    embeddingsRequest.Model,
-        promptTokens: promptTokens,
-        userId:       userId,
-        channelId:    channelId,
-        tokenId:      tokenId,
-    }
-    quotaInfo.initQuotaInfo(group)
-    quota_err := quotaInfo.preQuotaConsumption()
-    if quota_err != nil {
-        return quota_err
-    }
-
-    usage, openAIErrorWithStatusCode := embeddingsProvider.EmbeddingsResponse(&embeddingsRequest, isModelMapped, promptTokens)
-
-    if openAIErrorWithStatusCode != nil {
-        if quotaInfo.preConsumedQuota != 0 {
-            go func(ctx context.Context) {
-                // return pre-consumed quota
-                err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
-                if err != nil {
-                    common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-                }
-            }(c.Request.Context())
-        }
-        return openAIErrorWithStatusCode
-    }
-
-    tokenName := c.GetString("token_name")
-    defer func(ctx context.Context) {
-        go func() {
-            err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
-            if err != nil {
-                common.LogError(ctx, err.Error())
-            }
-        }()
-    }(c.Request.Context())
-
-    return nil
-}
-
-func GetEmbeddingsProvider(channelType int, c *gin.Context) providers.EmbeddingsProviderAction {
-    switch channelType {
-    case common.ChannelTypeOpenAI:
-        return providers.CreateOpenAIProvider(c, "")
-    case common.ChannelTypeAzure:
-        return providers.CreateAzureProvider(c)
-    case common.ChannelTypeAli:
-        return providers.CreateAliAIProvider(c)
-    case common.ChannelTypeBaidu:
-        return providers.CreateBaiduProvider(c)
-    }
-
-    baseURL := common.ChannelBaseURLs[channelType]
-    if c.GetString("base_url") != "" {
-        baseURL = c.GetString("base_url")
-    }
-
-    if baseURL != "" {
-        return providers.CreateOpenAIProvider(c, baseURL)
-    }
-
-    return nil
-}
controller/relay-text.go (new file, 160 lines)
@@ -0,0 +1,160 @@
+package controller
+
+import (
+    "context"
+    "errors"
+    "net/http"
+    "one-api/common"
+    "one-api/model"
+    "one-api/providers"
+    providers_base "one-api/providers/base"
+    "one-api/types"
+
+    "github.com/gin-gonic/gin"
+)
+
+func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode {
+    // Get request parameters
+    channelType := c.GetInt("channel")
+    channelId := c.GetInt("channel_id")
+    tokenId := c.GetInt("token_id")
+    userId := c.GetInt("id")
+    group := c.GetString("group")
+
+    // Get the provider
+    provider := providers.GetProvider(channelType, c)
+    if provider == nil {
+        return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
+    }
+
+    modelMap, err := parseModelMapping(c.GetString("model_mapping"))
+    if err != nil {
+        return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+    }
+
+    var promptTokens int
+    quotaInfo := &QuotaInfo{
+        modelName:    "",
+        promptTokens: promptTokens,
+        userId:       userId,
+        channelId:    channelId,
+        tokenId:      tokenId,
+    }
+
+    var usage *types.Usage
+    var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
+
+    switch relayMode {
+    case RelayModeChatCompletions:
+        usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
+    case RelayModeCompletions:
+        usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
+    case RelayModeEmbeddings:
+        usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
+    default:
+        return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
+    }
+
+    if openAIErrorWithStatusCode != nil {
+        if quotaInfo.preConsumedQuota != 0 {
+            go func(ctx context.Context) {
+                // return pre-consumed quota
+                err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
+                if err != nil {
+                    common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
+                }
+            }(c.Request.Context())
+        }
+        return openAIErrorWithStatusCode
+    }
+
+    tokenName := c.GetString("token_name")
+    defer func(ctx context.Context) {
+        go func() {
+            err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
+            if err != nil {
+                common.LogError(ctx, err.Error())
+            }
+        }()
+    }(c.Request.Context())
+
+    return nil
+}
+
+func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
+    var chatRequest types.ChatCompletionRequest
+    isModelMapped := false
+    chatProvider, ok := provider.(providers_base.ChatInterface)
+    if !ok {
+        return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
+    }
+    err := common.UnmarshalBodyReusable(c, &chatRequest)
+    if err != nil {
+        return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+    }
+    if modelMap != nil && modelMap[chatRequest.Model] != "" {
+        chatRequest.Model = modelMap[chatRequest.Model]
+        isModelMapped = true
+    }
+    promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
+
+    quotaInfo.modelName = chatRequest.Model
+    quotaInfo.initQuotaInfo(group)
+    quota_err := quotaInfo.preQuotaConsumption()
+    if quota_err != nil {
+        return nil, quota_err
+    }
+    return chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
+}
+
+func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
+    var completionRequest types.CompletionRequest
+    isModelMapped := false
+    completionProvider, ok := provider.(providers_base.CompletionInterface)
+    if !ok {
+        return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
+    }
+    err := common.UnmarshalBodyReusable(c, &completionRequest)
+    if err != nil {
+        return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+    }
+    if modelMap != nil && modelMap[completionRequest.Model] != "" {
+        completionRequest.Model = modelMap[completionRequest.Model]
+        isModelMapped = true
+    }
+    promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
+
+    quotaInfo.modelName = completionRequest.Model
+    quotaInfo.initQuotaInfo(group)
+    quota_err := quotaInfo.preQuotaConsumption()
+    if quota_err != nil {
+        return nil, quota_err
+    }
+    return completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
+}
+
+func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
+    var embeddingsRequest types.EmbeddingRequest
+    isModelMapped := false
+    embeddingsProvider, ok := provider.(providers_base.EmbeddingsInterface)
+    if !ok {
+        return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
+    }
+    err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
+    if err != nil {
+        return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+    }
+    if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
+        embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
+        isModelMapped = true
+    }
+    promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
+
+    quotaInfo.modelName = embeddingsRequest.Model
+    quotaInfo.initQuotaInfo(group)
+    quota_err := quotaInfo.preQuotaConsumption()
+    if quota_err != nil {
+        return nil, quota_err
+    }
+    return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
+}
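The new relay-text.go collapses the three relay helpers into a single dispatcher: the provider is created once and then narrowed to the capability each relay mode needs via a type assertion. A self-contained sketch of that pattern (the interface and struct names here are illustrative stand-ins for providers_base.ProviderInterface and providers_base.ChatInterface, not the project's actual definitions):

    package main

    import (
        "errors"
        "fmt"
    )

    // Base interface every provider implements.
    type Provider interface {
        GetRequestHeaders() map[string]string
    }

    // Optional capability: only some providers can serve chat completions.
    type ChatCapable interface {
        Provider
        ChatAction(model string) (string, error)
    }

    type demoProvider struct{}

    func (demoProvider) GetRequestHeaders() map[string]string { return map[string]string{} }
    func (demoProvider) ChatAction(model string) (string, error) {
        return "ok: " + model, nil
    }

    func relayChat(p Provider, model string) (string, error) {
        chat, ok := p.(ChatCapable) // narrow to the needed capability
        if !ok {
            return "", errors.New("channel not implemented")
        }
        return chat.ChatAction(model)
    }

    func main() {
        out, err := relayChat(demoProvider{}, "gpt-3.5-turbo")
        fmt.Println(out, err)
    }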
@@ -237,19 +237,19 @@ type CompletionsStreamResponse struct {
 func Relay(c *gin.Context) {
     var err *types.OpenAIErrorWithStatusCode

-    // relayMode := RelayModeUnknown
+    relayMode := RelayModeUnknown
     if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
-        err = relayChatHelper(c)
-        // relayMode = RelayModeChatCompletions
+        // err = relayChatHelper(c)
+        relayMode = RelayModeChatCompletions
     } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
-        err = relayCompletionHelper(c)
-        // relayMode = RelayModeCompletions
+        // err = relayCompletionHelper(c)
+        relayMode = RelayModeCompletions
     } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
-        err = relayEmbeddingsHelper(c)
+        // err = relayEmbeddingsHelper(c)
+        relayMode = RelayModeEmbeddings
+    } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+        relayMode = RelayModeEmbeddings
     }
-    // relayMode = RelayModeEmbeddings
-    // } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-    // 	relayMode = RelayModeEmbeddings
     // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
     // 	relayMode = RelayModeModerations
     // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
@@ -263,7 +263,7 @@ func Relay(c *gin.Context) {
     // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
     // 	relayMode = RelayModeAudioTranslation
     // }
-    // switch relayMode {
+    switch relayMode {
     // case RelayModeImagesGenerations:
     // 	err = relayImageHelper(c, relayMode)
     // case RelayModeAudioSpeech:
@@ -272,9 +272,9 @@ func Relay(c *gin.Context) {
     // 	fallthrough
     // case RelayModeAudioTranscription:
     // 	err = relayAudioHelper(c, relayMode)
-    // default:
-    // 	err = relayTextHelper(c, relayMode)
-    // }
+    default:
+        err = relayTextHelper(c, relayMode)
+    }
     if err != nil {
         requestId := c.GetString(common.RequestIdKey)
         retryTimesStr := c.Query("retry")
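Relay now derives a relayMode from the request path and dispatches through one switch instead of calling a per-endpoint helper directly. A small sketch of the prefix-to-mode mapping (the constant names mirror the diff; their numeric values are illustrative):

    package main

    import (
        "fmt"
        "strings"
    )

    const (
        RelayModeUnknown = iota
        RelayModeChatCompletions
        RelayModeCompletions
        RelayModeEmbeddings
    )

    func relayModeFor(path string) int {
        switch {
        case strings.HasPrefix(path, "/v1/chat/completions"):
            return RelayModeChatCompletions
        case strings.HasPrefix(path, "/v1/completions"):
            return RelayModeCompletions
        // Both the canonical route and model-scoped routes ending in
        // "embeddings" map to the same relay mode.
        case strings.HasPrefix(path, "/v1/embeddings"),
            strings.HasSuffix(path, "embeddings"):
            return RelayModeEmbeddings
        default:
            return RelayModeUnknown
        }
    }

    func main() {
        fmt.Println(relayModeFor("/v1/chat/completions")) // 1
    }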
providers/ali/base.go (new file, 35 lines)
@@ -0,0 +1,35 @@
+package ali
+
+import (
+    "fmt"
+
+    "one-api/providers/base"
+
+    "github.com/gin-gonic/gin"
+)
+
+type AliProvider struct {
+    base.BaseProvider
+}
+
+// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
+// Create the AliAIProvider
+func CreateAliAIProvider(c *gin.Context) *AliProvider {
+    return &AliProvider{
+        BaseProvider: base.BaseProvider{
+            BaseURL:         "https://dashscope.aliyuncs.com",
+            ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
+            Embeddings:      "/api/v1/services/embeddings/text-embedding/text-embedding",
+            Context:         c,
+        },
+    }
+}
+
+// Get request headers
+func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
+    headers = make(map[string]string)
+    p.CommonRequestHeaders(headers)
+    headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
+
+    return headers
+}
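Each relocated provider package now embeds a shared base.BaseProvider that carries the endpoint configuration, as AliProvider does above. A minimal sketch of the embedding pattern (the BaseProvider here is a simplified stand-in for one-api/providers/base.BaseProvider):

    package main

    import "fmt"

    // Simplified stand-in for the shared provider base.
    type BaseProvider struct {
        BaseURL         string
        ChatCompletions string
    }

    // Shared behaviour lives on the base and is promoted to every
    // provider that embeds it.
    func (b *BaseProvider) GetFullRequestURL(requestURL string) string {
        return b.BaseURL + requestURL
    }

    type AliProvider struct {
        BaseProvider
    }

    func main() {
        p := AliProvider{BaseProvider{
            BaseURL:         "https://dashscope.aliyuncs.com",
            ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
        }}
        // The embedded method is callable directly on the provider.
        fmt.Println(p.GetFullRequestURL(p.ChatCompletions))
    }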
@@ -1,4 +1,4 @@
-package providers
+package ali

 import (
     "bufio"
@@ -10,43 +10,10 @@ import (
     "strings"
 )

-type AliMessage struct {
-    User string `json:"user"`
-    Bot  string `json:"bot"`
-}
-
-type AliInput struct {
-    Prompt  string       `json:"prompt"`
-    History []AliMessage `json:"history"`
-}
-
-type AliParameters struct {
-    TopP         float64 `json:"top_p,omitempty"`
-    TopK         int     `json:"top_k,omitempty"`
-    Seed         uint64  `json:"seed,omitempty"`
-    EnableSearch bool    `json:"enable_search,omitempty"`
-}
-
-type AliChatRequest struct {
-    Model      string        `json:"model"`
-    Input      AliInput      `json:"input"`
-    Parameters AliParameters `json:"parameters,omitempty"`
-}
-
-type AliOutput struct {
-    Text         string `json:"text"`
-    FinishReason string `json:"finish_reason"`
-}
-
-type AliChatResponse struct {
-    Output AliOutput `json:"output"`
-    Usage  AliUsage  `json:"usage"`
-    AliError
-}
-
-func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+// Handle the Alibaba Cloud response
+func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
     if aliResponse.Code != "" {
-        return nil, &types.OpenAIErrorWithStatusCode{
+        errWithCode = &types.OpenAIErrorWithStatusCode{
             OpenAIError: types.OpenAIError{
                 Message: aliResponse.Message,
                 Type:    aliResponse.Code,
@@ -55,6 +22,8 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
             },
             StatusCode: resp.StatusCode,
         }
+
+        return
     }

     choice := types.ChatCompletionChoice{
@@ -66,7 +35,7 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
         FinishReason: aliResponse.Output.FinishReason,
     }

-    fullTextResponse := types.ChatCompletionResponse{
+    OpenAIResponse = types.ChatCompletionResponse{
         ID:      aliResponse.RequestId,
         Object:  "chat.completion",
         Created: common.GetTimestamp(),
@@ -78,10 +47,11 @@ func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIR
         },
     }

-    return fullTextResponse, nil
+    return
 }

-func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
+// Get the chat request body
+func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
     messages := make([]AliMessage, 0, len(request.Messages))
     prompt := ""
     for i := 0; i < len(request.Messages); i++ {
@@ -113,7 +83,8 @@ func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest)
         }
     }

-func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+// Chat
+func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {

     requestBody := p.getChatRequestBody(request)
     fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
@@ -130,8 +101,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
     }

     if request.Stream {
-        openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
-        if openAIErrorWithStatusCode != nil {
+        usage, errWithCode = p.sendStreamRequest(req)
+        if errWithCode != nil {
             return
         }

@@ -145,8 +116,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques

     } else {
         aliResponse := &AliChatResponse{}
-        openAIErrorWithStatusCode = p.sendRequest(req, aliResponse)
-        if openAIErrorWithStatusCode != nil {
+        errWithCode = p.SendRequest(req, aliResponse)
+        if errWithCode != nil {
             return
         }

@@ -159,7 +130,8 @@ func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionReques
         return
     }

-func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
+// Convert the Alibaba Cloud response to an OpenAI response
+func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
     var choice types.ChatCompletionStreamChoice
     choice.Delta.Content = aliResponse.Output.Text
     if aliResponse.Output.FinishReason != "null" {
@@ -177,16 +149,17 @@ func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *
     return &response
 }

-func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
+// Send a streaming request
+func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
     usage = &types.Usage{}
     // Send the request
     resp, err := common.HttpClient.Do(req)
     if err != nil {
-        return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
+        return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
     }

     if common.IsFailureStatusCode(resp) {
-        return p.handleErrorResp(resp), nil
+        return nil, p.HandleErrorResp(resp)
     }

     defer resp.Body.Close()
@@ -220,7 +193,7 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
         }
         stopChan <- true
     }()
-    setEventStreamHeaders(p.Context)
+    common.SetEventStreamHeaders(p.Context)
     lastResponseText := ""
     p.Context.Stream(func(w io.Writer) bool {
         select {
@@ -252,5 +225,5 @@ func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
         }
     })

-    return nil, usage
+    return
 }
@@ -1,4 +1,4 @@
-package providers
+package ali

 import (
     "net/http"
@@ -6,30 +6,8 @@ import (
     "one-api/types"
 )

-type AliEmbeddingRequest struct {
-    Model string `json:"model"`
-    Input struct {
-        Texts []string `json:"texts"`
-    } `json:"input"`
-    Parameters *struct {
-        TextType string `json:"text_type,omitempty"`
-    } `json:"parameters,omitempty"`
-}
-
-type AliEmbedding struct {
-    Embedding []float64 `json:"embedding"`
-    TextIndex int       `json:"text_index"`
-}
-
-type AliEmbeddingResponse struct {
-    Output struct {
-        Embeddings []AliEmbedding `json:"embeddings"`
-    } `json:"output"`
-    Usage AliUsage `json:"usage"`
-    AliError
-}
-
-func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+// Handle the embeddings request
+func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
     if aliResponse.Code != "" {
         return nil, &types.OpenAIErrorWithStatusCode{
             OpenAIError: types.OpenAIError{
@@ -60,7 +38,8 @@ func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (Op
     return openAIEmbeddingResponse, nil
 }

-func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
+// Get the embeddings request body
+func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
     return &AliEmbeddingRequest{
         Model: "text-embedding-v1",
         Input: struct {
@@ -71,7 +50,7 @@ func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest
     }
 }

-func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {

     requestBody := p.getEmbeddingsRequestBody(request)
     fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
@@ -84,8 +63,8 @@ func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo
     }

     aliEmbeddingResponse := &AliEmbeddingResponse{}
-    openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse)
-    if openAIErrorWithStatusCode != nil {
+    errWithCode = p.SendRequest(req, aliEmbeddingResponse)
+    if errWithCode != nil {
         return
     }
     usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
providers/ali/type.go (new file, 70 lines)
@@ -0,0 +1,70 @@
+package ali
+
+type AliError struct {
+    Code      string `json:"code"`
+    Message   string `json:"message"`
+    RequestId string `json:"request_id"`
+}
+
+type AliUsage struct {
+    InputTokens  int `json:"input_tokens"`
+    OutputTokens int `json:"output_tokens"`
+    TotalTokens  int `json:"total_tokens"`
+}
+
+type AliMessage struct {
+    User string `json:"user"`
+    Bot  string `json:"bot"`
+}
+
+type AliInput struct {
+    Prompt  string       `json:"prompt"`
+    History []AliMessage `json:"history"`
+}
+
+type AliParameters struct {
+    TopP         float64 `json:"top_p,omitempty"`
+    TopK         int     `json:"top_k,omitempty"`
+    Seed         uint64  `json:"seed,omitempty"`
+    EnableSearch bool    `json:"enable_search,omitempty"`
+}
+
+type AliChatRequest struct {
+    Model      string        `json:"model"`
+    Input      AliInput      `json:"input"`
+    Parameters AliParameters `json:"parameters,omitempty"`
+}
+
+type AliOutput struct {
+    Text         string `json:"text"`
+    FinishReason string `json:"finish_reason"`
+}
+
+type AliChatResponse struct {
+    Output AliOutput `json:"output"`
+    Usage  AliUsage  `json:"usage"`
+    AliError
+}
+
+type AliEmbeddingRequest struct {
+    Model string `json:"model"`
+    Input struct {
+        Texts []string `json:"texts"`
+    } `json:"input"`
+    Parameters *struct {
+        TextType string `json:"text_type,omitempty"`
+    } `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+    Embedding []float64 `json:"embedding"`
+    TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+    Output struct {
+        Embeddings []AliEmbedding `json:"embeddings"`
+    } `json:"output"`
+    Usage AliUsage `json:"usage"`
+    AliError
+}
@@ -1,50 +0,0 @@
-package providers
-
-import (
-    "fmt"
-
-    "github.com/gin-gonic/gin"
-)
-
-type AliAIProvider struct {
-    ProviderConfig
-}
-
-type AliError struct {
-    Code      string `json:"code"`
-    Message   string `json:"message"`
-    RequestId string `json:"request_id"`
-}
-
-type AliUsage struct {
-    InputTokens  int `json:"input_tokens"`
-    OutputTokens int `json:"output_tokens"`
-    TotalTokens  int `json:"total_tokens"`
-}
-
-// Create the AliAIProvider
-// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
-func CreateAliAIProvider(c *gin.Context) *AliAIProvider {
-    return &AliAIProvider{
-        ProviderConfig: ProviderConfig{
-            BaseURL:         "https://dashscope.aliyuncs.com",
-            ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
-            Embeddings:      "/api/v1/services/embeddings/text-embedding/text-embedding",
-            Context:         c,
-        },
-    }
-}
-
-// Get request headers
-func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) {
-    headers = make(map[string]string)
-    headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
-
-    headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
-    headers["Accept"] = p.Context.Request.Header.Get("Accept")
-    if headers["Content-Type"] == "" {
-        headers["Content-Type"] = "application/json"
-    }
-
-    return headers
-}
providers/api2d/base.go (new file, 18 lines)
@@ -0,0 +1,18 @@
+package api2d
+
+import (
+    "one-api/providers/openai"
+
+    "github.com/gin-gonic/gin"
+)
+
+type Api2dProvider struct {
+    *openai.OpenAIProvider
+}
+
+// Create the Api2dProvider
+func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
+    return &Api2dProvider{
+        OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
+    }
+}
@@ -1,14 +0,0 @@
-package providers
-
-import "github.com/gin-gonic/gin"
-
-type Api2dProvider struct {
-    *OpenAIProvider
-}
-
-// Create the OpenAIProvider
-func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
-    return &Api2dProvider{
-        OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"),
-    }
-}
providers/azure/base.go (new file, 31 lines)
@@ -0,0 +1,31 @@
+package azure
+
+import (
+    "one-api/providers/base"
+    "one-api/providers/openai"
+
+    "github.com/gin-gonic/gin"
+)
+
+type AzureProvider struct {
+    openai.OpenAIProvider
+}
+
+// Create the OpenAIProvider
+func CreateAzureProvider(c *gin.Context) *AzureProvider {
+    return &AzureProvider{
+        OpenAIProvider: openai.OpenAIProvider{
+            BaseProvider: base.BaseProvider{
+                BaseURL:             "",
+                Completions:         "/completions",
+                ChatCompletions:     "/chat/completions",
+                Embeddings:          "/embeddings",
+                AudioSpeech:         "/audio/speech",
+                AudioTranscriptions: "/audio/transcriptions",
+                AudioTranslations:   "/audio/translations",
+                Context:             c,
+            },
+            IsAzure: true,
+        },
+    }
+}
@@ -1,41 +0,0 @@
-package providers
-
-import (
-    "github.com/gin-gonic/gin"
-)
-
-type AzureProvider struct {
-    OpenAIProvider
-}
-
-// Create the OpenAIProvider
-func CreateAzureProvider(c *gin.Context) *AzureProvider {
-    return &AzureProvider{
-        OpenAIProvider: OpenAIProvider{
-            ProviderConfig: ProviderConfig{
-                BaseURL:             "",
-                Completions:         "/completions",
-                ChatCompletions:     "/chat/completions",
-                Embeddings:          "/embeddings",
-                AudioSpeech:         "/audio/speech",
-                AudioTranscriptions: "/audio/transcriptions",
-                AudioTranslations:   "/audio/translations",
-                Context:             c,
-            },
-            isAzure: true,
-        },
-    }
-}
-
-// // Get the full request URL
-// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string {
-// 	apiVersion := p.Context.GetString("api_version")
-// 	requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion)
-// 	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
-
-// 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-// 		requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
-// 	}
-
-// 	return fmt.Sprintf("%s%s", baseURL, requestURL)
-// }
@@ -1,10 +1,11 @@
-package providers
+package baidu

 import (
     "encoding/json"
     "errors"
     "fmt"
     "one-api/common"
+    "one-api/providers/base"
     "strings"
     "sync"
     "time"
@@ -15,20 +16,12 @@ import (
 var baiduTokenStore sync.Map

 type BaiduProvider struct {
-    ProviderConfig
-}
-
-type BaiduAccessToken struct {
-    AccessToken      string    `json:"access_token"`
-    Error            string    `json:"error,omitempty"`
-    ErrorDescription string    `json:"error_description,omitempty"`
-    ExpiresIn        int64     `json:"expires_in,omitempty"`
-    ExpiresAt        time.Time `json:"-"`
+    base.BaseProvider
 }

 func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
     return &BaiduProvider{
-        ProviderConfig: ProviderConfig{
+        BaseProvider: base.BaseProvider{
             BaseURL:         "https://aip.baidubce.com",
             ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
             Embeddings:      "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
@@ -59,12 +52,7 @@ func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) s
 // Get request headers
 func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
     headers = make(map[string]string)
+    p.CommonRequestHeaders(headers)
-    headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
-    headers["Accept"] = p.Context.Request.Header.Get("Accept")
-    if headers["Content-Type"] == "" {
-        headers["Content-Type"] = "application/json"
-    }
-
     return headers
 }
@@ -1,4 +1,4 @@
-package providers
+package baidu

 import (
     "bufio"
@@ -6,33 +6,12 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/providers/base"
     "one-api/types"
     "strings"
 )

-type BaiduMessage struct {
-    Role    string `json:"role"`
-    Content string `json:"content"`
-}
-
-type BaiduChatRequest struct {
-    Messages []BaiduMessage `json:"messages"`
-    Stream   bool           `json:"stream"`
-    UserId   string         `json:"user_id,omitempty"`
-}
-
-type BaiduChatResponse struct {
-    Id               string       `json:"id"`
-    Object           string       `json:"object"`
-    Created          int64        `json:"created"`
-    Result           string       `json:"result"`
-    IsTruncated      bool         `json:"is_truncated"`
-    NeedClearHistory bool         `json:"need_clear_history"`
-    Usage            *types.Usage `json:"usage"`
-    BaiduError
-}
-
-func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
     if baiduResponse.ErrorMsg != "" {
         return nil, &types.OpenAIErrorWithStatusCode{
             OpenAIError: types.OpenAIError{
@@ -54,7 +33,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope
         FinishReason: "stop",
     }

-    fullTextResponse := types.ChatCompletionResponse{
+    OpenAIResponse = types.ChatCompletionResponse{
         ID:      baiduResponse.Id,
         Object:  "chat.completion",
         Created: baiduResponse.Created,
@@ -62,18 +41,7 @@ func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (Ope
         Usage: baiduResponse.Usage,
     }

-    return fullTextResponse, nil
-}
-
-type BaiduChatStreamResponse struct {
-    BaiduChatResponse
-    SentenceId int  `json:"sentence_id"`
-    IsEnd      bool `json:"is_end"`
-}
-
-type BaiduError struct {
-    ErrorCode int    `json:"error_code"`
-    ErrorMsg  string `json:"error_msg"`
+    return
 }

 func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
@@ -101,7 +69,7 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
     }
 }

-func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
     requestBody := p.getChatRequestBody(request)
     fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
     if fullRequestURL == "" {
@@ -120,15 +88,15 @@ func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionReques
     }

     if request.Stream {
-        openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
-        if openAIErrorWithStatusCode != nil {
+        usage, errWithCode = p.sendStreamRequest(req)
+        if errWithCode != nil {
             return
         }

     } else {
         baiduChatRequest := &BaiduChatResponse{}
-        openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest)
-        if openAIErrorWithStatusCode != nil {
+        errWithCode = p.SendRequest(req, baiduChatRequest)
+        if errWithCode != nil {
             return
         }

@@ -142,7 +110,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
     var choice types.ChatCompletionStreamChoice
     choice.Delta.Content = baiduResponse.Result
     if baiduResponse.IsEnd {
-        choice.FinishReason = &stopFinishReason
+        choice.FinishReason = &base.StopFinishReason
     }

     response := types.ChatCompletionStreamResponse{
@@ -155,16 +123,16 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
     return &response
 }

-func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
+func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
     usage = &types.Usage{}
     // Send the request
     resp, err := common.HttpClient.Do(req)
     if err != nil {
-        return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
+        return nil, types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
     }

     if common.IsFailureStatusCode(resp) {
-        return p.handleErrorResp(resp), nil
+        return nil, p.HandleErrorResp(resp)
     }

     defer resp.Body.Close()
@@ -195,7 +163,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
     }
         stopChan <- true
     }()
-    setEventStreamHeaders(p.Context)
+    common.SetEventStreamHeaders(p.Context)
     p.Context.Stream(func(w io.Writer) bool {
         select {
         case data := <-dataChan:
@@ -224,5 +192,5 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithSta
     }
     })

-    return nil, usage
+    return usage, nil
 }
@@ -1,4 +1,4 @@
-package providers
+package baidu

 import (
     "net/http"
@@ -6,32 +6,13 @@ import (
     "one-api/types"
 )

-type BaiduEmbeddingRequest struct {
-    Input []string `json:"input"`
-}
-
-type BaiduEmbeddingData struct {
-    Object    string    `json:"object"`
-    Embedding []float64 `json:"embedding"`
-    Index     int       `json:"index"`
-}
-
-type BaiduEmbeddingResponse struct {
-    Id      string               `json:"id"`
-    Object  string               `json:"object"`
-    Created int64                `json:"created"`
-    Data    []BaiduEmbeddingData `json:"data"`
-    Usage   types.Usage          `json:"usage"`
-    BaiduError
-}
-
 func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
     return &BaiduEmbeddingRequest{
         Input: request.ParseInput(),
     }
 }

-func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
     if baiduResponse.ErrorMsg != "" {
         return nil, &types.OpenAIErrorWithStatusCode{
             OpenAIError: types.OpenAIError{
@@ -62,7 +43,7 @@ func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response)
     return openAIEmbeddingResponse, nil
 }

-func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {

     requestBody := p.getEmbeddingsRequestBody(request)
     fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
@@ -78,8 +59,8 @@ func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isMo
     }

     baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
-    openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse)
-    if openAIErrorWithStatusCode != nil {
+    errWithCode = p.SendRequest(req, baiduEmbeddingResponse)
+    if errWithCode != nil {
         return
     }
     usage = &baiduEmbeddingResponse.Usage
providers/baidu/type.go (new file, 66 lines)
@@ -0,0 +1,66 @@
+package baidu
+
+import (
+"one-api/types"
+"time"
+)
+
+type BaiduAccessToken struct {
+AccessToken string `json:"access_token"`
+Error string `json:"error,omitempty"`
+ErrorDescription string `json:"error_description,omitempty"`
+ExpiresIn int64 `json:"expires_in,omitempty"`
+ExpiresAt time.Time `json:"-"`
+}
+
+type BaiduMessage struct {
+Role string `json:"role"`
+Content string `json:"content"`
+}
+
+type BaiduChatRequest struct {
+Messages []BaiduMessage `json:"messages"`
+Stream bool `json:"stream"`
+UserId string `json:"user_id,omitempty"`
+}
+
+type BaiduChatResponse struct {
+Id string `json:"id"`
+Object string `json:"object"`
+Created int64 `json:"created"`
+Result string `json:"result"`
+IsTruncated bool `json:"is_truncated"`
+NeedClearHistory bool `json:"need_clear_history"`
+Usage *types.Usage `json:"usage"`
+BaiduError
+}
+
+type BaiduEmbeddingRequest struct {
+Input []string `json:"input"`
+}
+
+type BaiduEmbeddingData struct {
+Object string `json:"object"`
+Embedding []float64 `json:"embedding"`
+Index int `json:"index"`
+}
+
+type BaiduEmbeddingResponse struct {
+Id string `json:"id"`
+Object string `json:"object"`
+Created int64 `json:"created"`
+Data []BaiduEmbeddingData `json:"data"`
+Usage types.Usage `json:"usage"`
+BaiduError
+}
+
+type BaiduChatStreamResponse struct {
+BaiduChatResponse
+SentenceId int `json:"sentence_id"`
+IsEnd bool `json:"is_end"`
+}
+
+type BaiduError struct {
+ErrorCode int `json:"error_code"`
+ErrorMsg string `json:"error_msg"`
+}
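
Reviewer note: BaiduError is embedded in both response structs above, so its fields are promoted and ResponseHandler can test baiduResponse.ErrorMsg directly after a single decode. A standalone illustration of that promotion, with hypothetical values:

package main

import "fmt"

type BaiduError struct {
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}

// Embeds BaiduError the same way BaiduEmbeddingResponse does in this commit.
type embeddingResponse struct {
	Id string `json:"id"`
	BaiduError
}

func main() {
	resp := embeddingResponse{
		Id:         "emb-1",
		BaiduError: BaiduError{ErrorCode: 110, ErrorMsg: "access token invalid"},
	}
	// Promoted field access, exactly the check ResponseHandler performs.
	if resp.ErrorMsg != "" {
		fmt.Println("upstream error:", resp.ErrorCode, resp.ErrorMsg)
	}
}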
@@ -1,4 +1,4 @@
-package providers
+package base

 import (
 "encoding/json"
@@ -6,7 +6,6 @@ import (
 "io"
 "net/http"
 "one-api/common"
-"one-api/model"
 "one-api/types"
 "strconv"
 "strings"
@@ -14,9 +13,9 @@ import (
 "github.com/gin-gonic/gin"
 )

-var stopFinishReason = "stop"
+var StopFinishReason = "stop"

-type ProviderConfig struct {
+type BaseProvider struct {
 BaseURL string
 Completions string
 ChatCompletions string
@@ -28,32 +27,8 @@ type ProviderConfig struct {
 Context *gin.Context
 }

-type BaseProviderAction interface {
-GetBaseURL() string
-GetFullRequestURL(requestURL string, modelName string) string
-GetRequestHeaders() (headers map[string]string)
-}
-
-type CompletionProviderAction interface {
-BaseProviderAction
-CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
-}
-
-type ChatProviderAction interface {
-BaseProviderAction
-ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
-}
-
-type EmbeddingsProviderAction interface {
-BaseProviderAction
-EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
-}
-
-type BalanceProviderAction interface {
-Balance(channel *model.Channel) (float64, error)
-}
-
-func (p *ProviderConfig) GetBaseURL() string {
+// 获取基础URL
+func (p *BaseProvider) GetBaseURL() string {
 if p.Context.GetString("base_url") != "" {
 return p.Context.GetString("base_url")
 }
@@ -61,21 +36,66 @@ func (p *ProviderConfig) GetBaseURL() string {
 return p.BaseURL
 }

-func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string {
+// 获取完整请求URL
+func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
 baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")

 return fmt.Sprintf("%s%s", baseURL, requestURL)
 }

-func setEventStreamHeaders(c *gin.Context) {
-c.Writer.Header().Set("Content-Type", "text/event-stream")
-c.Writer.Header().Set("Cache-Control", "no-cache")
-c.Writer.Header().Set("Connection", "keep-alive")
-c.Writer.Header().Set("Transfer-Encoding", "chunked")
-c.Writer.Header().Set("X-Accel-Buffering", "no")
+// 获取请求头
+func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
+headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
+headers["Accept"] = p.Context.Request.Header.Get("Accept")
+if headers["Content-Type"] == "" {
+headers["Content-Type"] = "application/json"
+}
 }

-func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+// 发送请求
+func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+
+// 发送请求
+resp, err := common.HttpClient.Do(req)
+if err != nil {
+return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
+}
+
+defer resp.Body.Close()
+
+// 处理响应
+if common.IsFailureStatusCode(resp) {
+return p.HandleErrorResp(resp)
+}
+
+// 解析响应
+err = common.DecodeResponse(resp.Body, response)
+if err != nil {
+return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
+}
+
+openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
+if openAIErrorWithStatusCode != nil {
+return
+}
+
+jsonResponse, err := json.Marshal(openAIResponse)
+if err != nil {
+return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+}
+p.Context.Writer.Header().Set("Content-Type", "application/json")
+p.Context.Writer.WriteHeader(resp.StatusCode)
+_, err = p.Context.Writer.Write(jsonResponse)
+
+if err != nil {
+return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
+}
+
+return nil
+}
+
+// 处理错误响应
+func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
 openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
 StatusCode: resp.StatusCode,
 OpenAIError: types.OpenAIError{
@@ -105,46 +125,3 @@ func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithSt
 }
 return
 }
-
-// 供应商响应处理函数
-type ProviderResponseHandler interface {
-// 请求处理函数
-requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
-}
-
-// 发送请求
-func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
-
-// 发送请求
-resp, err := common.HttpClient.Do(req)
-if err != nil {
-return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
-}
-
-defer resp.Body.Close()
-
-// 处理响应
-if common.IsFailureStatusCode(resp) {
-return p.handleErrorResp(resp)
-}
-
-// 解析响应
-err = common.DecodeResponse(resp.Body, response)
-if err != nil {
-return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
-}
-
-openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp)
-if openAIErrorWithStatusCode != nil {
-return
-}
-
-jsonResponse, err := json.Marshal(openAIResponse)
-if err != nil {
-return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
-}
-p.Context.Writer.Header().Set("Content-Type", "application/json")
-p.Context.Writer.WriteHeader(resp.StatusCode)
-_, err = p.Context.Writer.Write(jsonResponse)
-return nil
-}
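
Reviewer note: the exported SendRequest now owns the whole non-stream round trip: do the HTTP call, map failure statuses through HandleErrorResp, decode the body into the handler, let ResponseHandler translate it to an OpenAI-shaped value, then marshal and write. A minimal sketch of a conforming handler — the type is hypothetical and the error return is simplified from *types.OpenAIErrorWithStatusCode to a plain error:

package main

import (
	"fmt"
	"net/http"
)

// Mirrors base.ProviderResponseHandler, with the error type simplified.
type providerResponseHandler interface {
	ResponseHandler(resp *http.Response) (any, error)
}

// echoResponse plays the role of a vendor response struct that SendRequest
// would decode into before invoking the handler.
type echoResponse struct {
	Result string `json:"result"`
}

// ResponseHandler converts the decoded vendor body into the value that
// SendRequest marshals back to the client.
func (e *echoResponse) ResponseHandler(resp *http.Response) (any, error) {
	if e.Result == "" {
		return nil, fmt.Errorf("empty result (status %d)", resp.StatusCode)
	}
	return map[string]string{"content": e.Result}, nil
}

func main() {
	var h providerResponseHandler = &echoResponse{Result: "hello"}
	out, err := h.ResponseHandler(&http.Response{StatusCode: http.StatusOK})
	fmt.Println(out, err)
}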
providers/base/interface.go (new file, 42 lines)
@@ -0,0 +1,42 @@
+package base
+
+import (
+"net/http"
+"one-api/model"
+"one-api/types"
+)
+
+// 基础接口
+type ProviderInterface interface {
+GetBaseURL() string
+GetFullRequestURL(requestURL string, modelName string) string
+GetRequestHeaders() (headers map[string]string)
+}
+
+// 完成接口
+type CompletionInterface interface {
+ProviderInterface
+CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+}
+
+// 聊天接口
+type ChatInterface interface {
+ProviderInterface
+ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+}
+
+// 嵌入接口
+type EmbeddingsInterface interface {
+ProviderInterface
+EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+}
+
+// 余额接口
+type BalanceInterface interface {
+BalanceAction(channel *model.Channel) (float64, error)
+}
+
+type ProviderResponseHandler interface {
+// 响应处理函数
+ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
+}
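
Reviewer note: the new interface file splits capabilities, so a caller holds base.ProviderInterface and upgrades to ChatInterface, EmbeddingsInterface, etc. per request. A self-contained sketch of that composition with simplified stand-in interfaces and a hypothetical provider:

package main

import "fmt"

// Simplified stand-ins for base.ProviderInterface and base.ChatInterface.
type providerInterface interface {
	GetBaseURL() string
}

type chatInterface interface {
	providerInterface
	ChatAction(model string) (string, error)
}

// demoProvider plays the role of a concrete provider that embeds
// base.BaseProvider and adds its own chat action.
type demoProvider struct{ baseURL string }

func (p *demoProvider) GetBaseURL() string { return p.baseURL }

func (p *demoProvider) ChatAction(model string) (string, error) {
	return "reply from " + model + " via " + p.baseURL, nil
}

func main() {
	var provider providerInterface = &demoProvider{baseURL: "https://example.invalid"}
	// Upgrade to the capability the request needs, as the relay code does.
	if chat, ok := provider.(chatInterface); ok {
		out, _ := chat.ChatAction("demo-model")
		fmt.Println(out)
	}
}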
@@ -1,21 +1,18 @@
-package providers
+package claude

 import (
+"one-api/providers/base"
+
 "github.com/gin-gonic/gin"
 )

 type ClaudeProvider struct {
-ProviderConfig
-}
-
-type ClaudeError struct {
-Type string `json:"type"`
-Message string `json:"message"`
+base.BaseProvider
 }

 func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
 return &ClaudeProvider{
-ProviderConfig: ProviderConfig{
+BaseProvider: base.BaseProvider{
 BaseURL: "https://api.anthropic.com",
 ChatCompletions: "/v1/complete",
 Context: c,
@@ -26,14 +23,9 @@ func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
 // 获取请求头
 func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
 headers = make(map[string]string)
+p.CommonRequestHeaders(headers)

 headers["x-api-key"] = p.Context.GetString("api_key")
-headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
-headers["Accept"] = p.Context.Request.Header.Get("Accept")
-if headers["Content-Type"] == "" {
-headers["Content-Type"] = "application/json"
-}

 anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
 if anthropicVersion == "" {
 anthropicVersion = "2023-06-01"
@@ -1,4 +1,4 @@
-package providers
+package claude

 import (
 "bufio"
@@ -11,31 +11,7 @@ import (
 "strings"
 )

-type ClaudeMetadata struct {
-UserId string `json:"user_id"`
-}
-
-type ClaudeRequest struct {
-Model string `json:"model"`
-Prompt string `json:"prompt"`
-MaxTokensToSample int `json:"max_tokens_to_sample"`
-StopSequences []string `json:"stop_sequences,omitempty"`
-Temperature float64 `json:"temperature,omitempty"`
-TopP float64 `json:"top_p,omitempty"`
-TopK int `json:"top_k,omitempty"`
-//ClaudeMetadata `json:"metadata,omitempty"`
-Stream bool `json:"stream,omitempty"`
-}
-
-type ClaudeResponse struct {
-Completion string `json:"completion"`
-StopReason string `json:"stop_reason"`
-Model string `json:"model"`
-Error ClaudeError `json:"error"`
-Usage *types.Usage `json:"usage,omitempty"`
-}
-
-func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
 if claudeResponse.Error.Type != "" {
 return nil, &types.OpenAIErrorWithStatusCode{
 OpenAIError: types.OpenAIError{
@@ -101,7 +77,7 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest
 return &claudeRequest
 }

-func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
 requestBody := p.getChatRequestBody(request)
 fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
 headers := p.GetRequestHeaders()
@@ -117,8 +93,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque

 if request.Stream {
 var responseText string
-openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
-if openAIErrorWithStatusCode != nil {
+errWithCode, responseText = p.sendStreamRequest(req)
+if errWithCode != nil {
 return
 }

@@ -132,8 +108,8 @@ func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionReque
 PromptTokens: promptTokens,
 },
 }
-openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode = p.SendRequest(req, claudeResponse)
+if errWithCode != nil {
 return
 }

@@ -165,7 +141,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
 }

 if common.IsFailureStatusCode(resp) {
-return p.handleErrorResp(resp), ""
+return p.HandleErrorResp(resp), ""
 }

 defer resp.Body.Close()
@@ -199,7 +175,7 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
 }
 stopChan <- true
 }()
-setEventStreamHeaders(p.Context)
+common.SetEventStreamHeaders(p.Context)
 p.Context.Stream(func(w io.Writer) bool {
 select {
 case data := <-dataChan:
providers/claude/type.go (new file, 32 lines)
@@ -0,0 +1,32 @@
+package claude
+
+import "one-api/types"
+
+type ClaudeError struct {
+Type string `json:"type"`
+Message string `json:"message"`
+}
+
+type ClaudeMetadata struct {
+UserId string `json:"user_id"`
+}
+
+type ClaudeRequest struct {
+Model string `json:"model"`
+Prompt string `json:"prompt"`
+MaxTokensToSample int `json:"max_tokens_to_sample"`
+StopSequences []string `json:"stop_sequences,omitempty"`
+Temperature float64 `json:"temperature,omitempty"`
+TopP float64 `json:"top_p,omitempty"`
+TopK int `json:"top_k,omitempty"`
+//ClaudeMetadata `json:"metadata,omitempty"`
+Stream bool `json:"stream,omitempty"`
+}
+
+type ClaudeResponse struct {
+Completion string `json:"completion"`
+StopReason string `json:"stop_reason"`
+Model string `json:"model"`
+Error ClaudeError `json:"error"`
+Usage *types.Usage `json:"usage,omitempty"`
+}
@@ -1,31 +1,11 @@
-package providers
+package closeai

 import (
 "fmt"
 "one-api/common"
 "one-api/model"
-
-"github.com/gin-gonic/gin"
 )

-type CloseaiProxyProvider struct {
-*OpenAIProvider
-}
-
-type OpenAICreditGrants struct {
-Object string `json:"object"`
-TotalGranted float64 `json:"total_granted"`
-TotalUsed float64 `json:"total_used"`
-TotalAvailable float64 `json:"total_available"`
-}
-
-// 创建 CloseaiProxyProvider
-func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
-return &CloseaiProxyProvider{
-OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
-}
-}
-
 func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
 fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
 fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
providers/closeai/base.go (new file, 18 lines)
@@ -0,0 +1,18 @@
+package closeai
+
+import (
+"one-api/providers/openai"
+
+"github.com/gin-gonic/gin"
+)
+
+type CloseaiProxyProvider struct {
+*openai.OpenAIProvider
+}
+
+// 创建 CloseaiProxyProvider
+func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
+return &CloseaiProxyProvider{
+OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
+}
+}
providers/closeai/type.go (new file, 8 lines)
@@ -0,0 +1,8 @@
+package closeai
+
+type OpenAICreditGrants struct {
+Object string `json:"object"`
+TotalGranted float64 `json:"total_granted"`
+TotalUsed float64 `json:"total_used"`
+TotalAvailable float64 `json:"total_available"`
+}
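
Reviewer note: the closeai package is now a thin wrapper — it embeds *openai.OpenAIProvider by pointer, inherits every Action method, and only supplies its own Balance. A standalone sketch of that embedding pattern with hypothetical stand-in types:

package main

import "fmt"

// openAIProvider stands in for openai.OpenAIProvider.
type openAIProvider struct{ baseURL string }

func (p *openAIProvider) ChatAction(model string) string {
	return "chat with " + model + " via " + p.baseURL
}

// closeaiProvider mirrors CloseaiProxyProvider: embed the OpenAI provider,
// inherit its behavior, override only what differs.
type closeaiProvider struct {
	*openAIProvider
}

// Balance is the vendor-specific addition; the value here is made up.
func (p *closeaiProvider) Balance() float64 { return 42.0 }

func main() {
	p := &closeaiProvider{&openAIProvider{baseURL: "https://api.closeai-proxy.xyz"}}
	fmt.Println(p.ChatAction("gpt-3.5-turbo")) // promoted from the embedded provider
	fmt.Println(p.Balance())                   // defined locally
}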
@@ -1,4 +1,4 @@
-package providers
+package openai

 import (
 "bufio"
@@ -11,32 +11,25 @@ import (
 "one-api/types"
 "strings"

+"one-api/providers/base"
+
 "github.com/gin-gonic/gin"
 )

 type OpenAIProvider struct {
-ProviderConfig
-isAzure bool
-}
-
-type OpenAIProviderResponseHandler interface {
-// 请求处理函数
-requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
-}
-
-type OpenAIProviderStreamResponseHandler interface {
-// 请求流处理函数
-requestStreamHandler() (responseText string)
+base.BaseProvider
+IsAzure bool
 }

 // 创建 OpenAIProvider
+// https://platform.openai.com/docs/api-reference/introduction
 func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
 if baseURL == "" {
 baseURL = "https://api.openai.com"
 }

 return &OpenAIProvider{
-ProviderConfig: ProviderConfig{
+BaseProvider: base.BaseProvider{
 BaseURL: baseURL,
 Completions: "/v1/completions",
 ChatCompletions: "/v1/chat/completions",
@@ -46,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
 AudioTranslations: "/v1/audio/translations",
 Context: c,
 },
-isAzure: false,
+IsAzure: false,
 }
 }

@@ -54,13 +47,13 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
 func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
 baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")

-if p.isAzure {
+if p.IsAzure {
 apiVersion := p.Context.GetString("api_version")
 requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
 }

 if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-if p.isAzure {
+if p.IsAzure {
 requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
 } else {
 requestURL = strings.TrimPrefix(requestURL, "/v1")
@@ -73,16 +66,12 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
 // 获取请求头
 func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
 headers = make(map[string]string)
-if p.isAzure {
+p.CommonRequestHeaders(headers)
+if p.IsAzure {
 headers["api-key"] = p.Context.GetString("api_key")
 } else {
 headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
 }
-headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
-headers["Accept"] = p.Context.Request.Header.Get("Accept")
-if headers["Content-Type"] == "" {
-headers["Content-Type"] = "application/json; charset=utf-8"
-}

 return headers
 }
@@ -114,7 +103,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR

 // 处理响应
 if common.IsFailureStatusCode(resp) {
-return p.handleErrorResp(resp)
+return p.HandleErrorResp(resp)
 }

 // 创建一个 bytes.Buffer 来存储响应体
@@ -127,7 +116,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
 return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
 }

-openAIErrorWithStatusCode = response.requestHandler(resp)
+openAIErrorWithStatusCode = response.responseHandler(resp)
 if openAIErrorWithStatusCode != nil {
 return
 }
@@ -145,6 +134,7 @@ func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderR
 return nil
 }

+// 发送流式请求
 func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {

 resp, err := common.HttpClient.Do(req)
@@ -153,7 +143,7 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
 }

 if common.IsFailureStatusCode(resp) {
-return p.handleErrorResp(resp), ""
+return p.HandleErrorResp(resp), ""
 }

 defer resp.Body.Close()
@@ -190,12 +180,12 @@ func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIPro
 common.SysError("error unmarshalling stream response: " + err.Error())
 continue // just ignore the error
 }
-responseText += response.requestStreamHandler()
+responseText += response.responseStreamHandler()
 }
 }
 stopChan <- true
 }()
-setEventStreamHeaders(p.Context)
+common.SetEventStreamHeaders(p.Context)
 p.Context.Stream(func(w io.Writer) bool {
 select {
 case data := <-dataChan:
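
Reviewer note: the renamed IsAzure flag selects between two URL shapes in GetFullRequestURL. A standalone sketch of the two outputs — the deployment name, resource host, and api-version below are hypothetical, and baseURL handling is reduced to a plain string:

package main

import (
	"fmt"
	"strings"
)

// fullRequestURL reproduces just the Azure branch of GetFullRequestURL.
func fullRequestURL(baseURL, requestURL, modelName, apiVersion string, isAzure bool) string {
	baseURL = strings.TrimSuffix(baseURL, "/")
	if isAzure {
		// Azure routes by deployment and requires an api-version query.
		requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
	}
	return baseURL + requestURL
}

func main() {
	// Plain OpenAI: the path is used as-is.
	fmt.Println(fullRequestURL("https://api.openai.com", "/v1/chat/completions", "gpt-4", "", false))
	// Azure: the model becomes a deployment segment.
	fmt.Println(fullRequestURL("https://my-resource.openai.azure.com", "/chat/completions", "gpt-4", "2023-05-15", true))
}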
@@ -1,4 +1,4 @@
-package providers
+package openai

 import (
 "net/http"
@@ -6,19 +6,9 @@ import (
 "one-api/types"
 )

-type OpenAIProviderChatResponse struct {
-types.ChatCompletionResponse
-types.OpenAIErrorResponse
-}
-
-type OpenAIProviderChatStreamResponse struct {
-types.ChatCompletionStreamResponse
-types.OpenAIErrorResponse
-}
-
-func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
 if c.Error.Type != "" {
-openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
+errWithCode = &types.OpenAIErrorWithStatusCode{
 OpenAIError: c.Error,
 StatusCode: resp.StatusCode,
 }
@@ -27,7 +17,7 @@ func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAI
 return nil
 }

-func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) {
+func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
 for _, choice := range c.Choices {
 responseText += choice.Delta.Content
 }
@@ -35,7 +25,7 @@ func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText
 return
 }

-func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
 requestBody, err := p.getRequestBody(&request, isModelMapped)
 if err != nil {
 return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
@@ -56,8 +46,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque
 if request.Stream {
 openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
 var textResponse string
-openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
+if errWithCode != nil {
 return
 }

@@ -69,8 +59,8 @@ func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionReque

 } else {
 openAIProviderChatResponse := &OpenAIProviderChatResponse{}
-openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode = p.sendRequest(req, openAIProviderChatResponse)
+if errWithCode != nil {
 return
 }

@@ -1,4 +1,4 @@
-package providers
+package openai

 import (
 "net/http"
@@ -6,14 +6,9 @@ import (
 "one-api/types"
 )

-type OpenAIProviderCompletionResponse struct {
-types.CompletionResponse
-types.OpenAIErrorResponse
-}
-
-func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
 if c.Error.Type != "" {
-openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
+errWithCode = &types.OpenAIErrorWithStatusCode{
 OpenAIError: c.Error,
 StatusCode: resp.StatusCode,
 }
@@ -22,7 +17,7 @@ func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (
 return nil
 }

-func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) {
+func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
 for _, choice := range c.Choices {
 responseText += choice.Text
 }
@@ -30,7 +25,7 @@ func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText
 return
 }

-func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
 requestBody, err := p.getRequestBody(&request, isModelMapped)
 if err != nil {
 return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
@@ -52,8 +47,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo
 if request.Stream {
 // TODO
 var textResponse string
-openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
+if errWithCode != nil {
 return
 }

@@ -64,8 +59,8 @@ func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isMo
 }

 } else {
-openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode = p.sendRequest(req, openAIProviderCompletionResponse)
+if errWithCode != nil {
 return
 }

@@ -1,4 +1,4 @@
-package providers
+package openai

 import (
 "net/http"
@@ -6,14 +6,9 @@ import (
 "one-api/types"
 )

-type OpenAIProviderEmbeddingsResponse struct {
-types.EmbeddingResponse
-types.OpenAIErrorResponse
-}
-
-func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
 if c.Error.Type != "" {
-openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
+errWithCode = &types.OpenAIErrorWithStatusCode{
 OpenAIError: c.Error,
 StatusCode: resp.StatusCode,
 }
@@ -22,7 +17,7 @@ func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (
 return nil
 }

-func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {

 requestBody, err := p.getRequestBody(&request, isModelMapped)
 if err != nil {
@@ -39,8 +34,8 @@ func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isM
 }

 openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
-openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
+if errWithCode != nil {
 return
 }

providers/openai/interface.go (new file, 16 lines)
@@ -0,0 +1,16 @@
+package openai
+
+import (
+"net/http"
+"one-api/types"
+)
+
+type OpenAIProviderResponseHandler interface {
+// 请求处理函数
+responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode)
+}
+
+type OpenAIProviderStreamResponseHandler interface {
+// 请求流处理函数
+responseStreamHandler() (responseText string)
+}
providers/openai/type.go (new file, 23 lines)
@@ -0,0 +1,23 @@
+package openai
+
+import "one-api/types"
+
+type OpenAIProviderChatResponse struct {
+types.ChatCompletionResponse
+types.OpenAIErrorResponse
+}
+
+type OpenAIProviderChatStreamResponse struct {
+types.ChatCompletionStreamResponse
+types.OpenAIErrorResponse
+}
+
+type OpenAIProviderCompletionResponse struct {
+types.CompletionResponse
+types.OpenAIErrorResponse
+}
+
+type OpenAIProviderEmbeddingsResponse struct {
+types.EmbeddingResponse
+types.OpenAIErrorResponse
+}
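
Reviewer note: each OpenAIProvider*Response embeds both its success payload and types.OpenAIErrorResponse, so one decode fills whichever half the upstream actually returned, and responseHandler only has to inspect Error.Type. A standalone sketch of that decode-both-shapes trick with hypothetical field names:

package main

import (
	"encoding/json"
	"fmt"
)

type errorResponse struct {
	Error struct {
		Type    string `json:"type"`
		Message string `json:"message"`
	} `json:"error"`
}

type chatResponse struct {
	Content string `json:"content"`
}

// combined embeds both shapes, as the OpenAIProvider*Response types do,
// so a single unmarshal handles success and failure bodies alike.
type combined struct {
	chatResponse
	errorResponse
}

func main() {
	for _, body := range []string{
		`{"content":"hi"}`,
		`{"error":{"type":"invalid_request_error","message":"bad model"}}`,
	} {
		var c combined
		_ = json.Unmarshal([]byte(body), &c)
		if c.Error.Type != "" {
			fmt.Println("upstream error:", c.Error.Message)
		} else {
			fmt.Println("ok:", c.Content)
		}
	}
}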
@@ -1,4 +1,4 @@
-package providers
+package openaisb

 import (
 "errors"
@@ -6,28 +6,8 @@ import (
 "one-api/common"
 "one-api/model"
 "strconv"
-
-"github.com/gin-gonic/gin"
 )

-type OpenaiSBProvider struct {
-*OpenAIProvider
-}
-
-type OpenAISBUsageResponse struct {
-Msg string `json:"msg"`
-Data *struct {
-Credit string `json:"credit"`
-} `json:"data"`
-}
-
-// 创建 OpenaiSBProvider
-func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
-return &OpenaiSBProvider{
-OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"),
-}
-}
-
 func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
 fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
 fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
providers/openaisb/base.go (new file, 18 lines)
@@ -0,0 +1,18 @@
+package openaisb
+
+import (
+"one-api/providers/openai"
+
+"github.com/gin-gonic/gin"
+)
+
+type OpenaiSBProvider struct {
+*openai.OpenAIProvider
+}
+
+// 创建 OpenaiSBProvider
+func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
+return &OpenaiSBProvider{
+OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
+}
+}
providers/openaisb/type.go (new file, 8 lines)
@@ -0,0 +1,8 @@
+package openaisb
+
+type OpenAISBUsageResponse struct {
+Msg string `json:"msg"`
+Data *struct {
+Credit string `json:"credit"`
+} `json:"data"`
+}
@@ -1,20 +1,21 @@
-package providers
+package palm

 import (
 "fmt"
+"one-api/providers/base"
 "strings"

 "github.com/gin-gonic/gin"
 )

 type PalmProvider struct {
-ProviderConfig
+base.BaseProvider
 }

 // 创建 PalmProvider
 func CreatePalmProvider(c *gin.Context) *PalmProvider {
 return &PalmProvider{
-ProviderConfig: ProviderConfig{
+BaseProvider: base.BaseProvider{
 BaseURL: "https://generativelanguage.googleapis.com",
 ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
 Context: c,
@@ -25,12 +26,7 @@ func CreatePalmProvider(c *gin.Context) *PalmProvider {
 // 获取请求头
 func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
 headers = make(map[string]string)
+p.CommonRequestHeaders(headers)
-headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
-headers["Accept"] = p.Context.Request.Header.Get("Accept")
-if headers["Content-Type"] == "" {
-headers["Content-Type"] = "application/json"
-}

 return headers
 }
@@ -1,4 +1,4 @@
-package providers
+package palm

 import (
 "encoding/json"
@@ -6,47 +6,11 @@ import (
 "io"
 "net/http"
 "one-api/common"
+"one-api/providers/base"
 "one-api/types"
 )

-type PaLMChatMessage struct {
-Author string `json:"author"`
-Content string `json:"content"`
-}
-
-type PaLMFilter struct {
-Reason string `json:"reason"`
-Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
-Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
-Prompt PaLMPrompt `json:"prompt"`
-Temperature float64 `json:"temperature,omitempty"`
-CandidateCount int `json:"candidateCount,omitempty"`
-TopP float64 `json:"topP,omitempty"`
-TopK int `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
-Code int `json:"code"`
-Message string `json:"message"`
-Status string `json:"status"`
-}
-
-type PaLMChatResponse struct {
-Candidates []PaLMChatMessage `json:"candidates"`
-Messages []types.ChatCompletionMessage `json:"messages"`
-Filters []PaLMFilter `json:"filters"`
-Error PaLMError `json:"error"`
-Usage *types.Usage `json:"usage,omitempty"`
-Model string `json:"model,omitempty"`
-}
-
-func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
 if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
 return nil, &types.OpenAIErrorWithStatusCode{
 OpenAIError: types.OpenAIError{
@@ -107,7 +71,7 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest)
 return &palmRequest
 }

-func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
+func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
 requestBody := p.getChatRequestBody(request)
 fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
 headers := p.GetRequestHeaders()
@@ -123,8 +87,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest

 if request.Stream {
 var responseText string
-openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
-if openAIErrorWithStatusCode != nil {
+errWithCode, responseText = p.sendStreamRequest(req)
+if errWithCode != nil {
 return
 }

@@ -139,8 +103,8 @@ func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest
 PromptTokens: promptTokens,
 },
 }
-openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
-if openAIErrorWithStatusCode != nil {
+errWithCode = p.SendRequest(req, palmChatResponse)
+if errWithCode != nil {
 return
 }

@@ -155,7 +119,7 @@ func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse)
 if len(palmResponse.Candidates) > 0 {
 choice.Delta.Content = palmResponse.Candidates[0].Content
 }
-choice.FinishReason = &stopFinishReason
+choice.FinishReason = &base.StopFinishReason
 var response types.ChatCompletionStreamResponse
 response.Object = "chat.completion.chunk"
 response.Model = "palm2"
@@ -171,7 +135,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
 }

 if common.IsFailureStatusCode(resp) {
-return p.handleErrorResp(resp), ""
+return p.HandleErrorResp(resp), ""
 }

 defer resp.Body.Close()
@@ -216,7 +180,7 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
 dataChan <- string(jsonResponse)
 stopChan <- true
 }()
-setEventStreamHeaders(p.Context)
+common.SetEventStreamHeaders(p.Context)
 p.Context.Stream(func(w io.Writer) bool {
 select {
 case data := <-dataChan:
providers/palm/type.go (new file, 40 lines)
@@ -0,0 +1,40 @@
+package palm
+
+import "one-api/types"
+
+type PaLMChatMessage struct {
+Author string `json:"author"`
+Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+Reason string `json:"reason"`
+Message string `json:"message"`
+}
+
+type PaLMPrompt struct {
+Messages []PaLMChatMessage `json:"messages"`
+}
+
+type PaLMChatRequest struct {
+Prompt PaLMPrompt `json:"prompt"`
+Temperature float64 `json:"temperature,omitempty"`
+CandidateCount int `json:"candidateCount,omitempty"`
+TopP float64 `json:"topP,omitempty"`
+TopK int `json:"topK,omitempty"`
+}
+
+type PaLMError struct {
+Code int `json:"code"`
+Message string `json:"message"`
+Status string `json:"status"`
+}
+
+type PaLMChatResponse struct {
+Candidates []PaLMChatMessage `json:"candidates"`
+Messages []types.ChatCompletionMessage `json:"messages"`
+Filters []PaLMFilter `json:"filters"`
+Error PaLMError `json:"error"`
+Usage *types.Usage `json:"usage,omitempty"`
+Model string `json:"model,omitempty"`
+}
providers/providers.go (new file, 50 lines)
@@ -0,0 +1,50 @@
+package providers
+
+import (
+"one-api/common"
+"one-api/providers/ali"
+"one-api/providers/azure"
+"one-api/providers/baidu"
+"one-api/providers/base"
+"one-api/providers/claude"
+"one-api/providers/openai"
+"one-api/providers/palm"
+"one-api/providers/tencent"
+"one-api/providers/xunfei"
+"one-api/providers/zhipu"
+
+"github.com/gin-gonic/gin"
+)
+
+func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
+switch channelType {
+case common.ChannelTypeOpenAI:
+return openai.CreateOpenAIProvider(c, "")
+case common.ChannelTypeAzure:
+return azure.CreateAzureProvider(c)
+case common.ChannelTypeAli:
+return ali.CreateAliAIProvider(c)
+case common.ChannelTypeTencent:
+return tencent.CreateTencentProvider(c)
+case common.ChannelTypeBaidu:
+return baidu.CreateBaiduProvider(c)
+case common.ChannelTypeAnthropic:
+return claude.CreateClaudeProvider(c)
+case common.ChannelTypePaLM:
+return palm.CreatePalmProvider(c)
+case common.ChannelTypeZhipu:
+return zhipu.CreateZhipuProvider(c)
+case common.ChannelTypeXunfei:
+return xunfei.CreateXunfeiProvider(c)
+default:
+baseURL := common.ChannelBaseURLs[channelType]
+if c.GetString("base_url") != "" {
+baseURL = c.GetString("base_url")
+}
+if baseURL != "" {
+return openai.CreateOpenAIProvider(c, baseURL)
+}
+
+return nil
+}
+}
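
Reviewer note: GetProvider returns the base interface, so call sites assert the capability they need before dispatching. A minimal sketch of that call pattern — the stand-in interfaces and the nil-returning factory are hypothetical:

package main

import (
	"errors"
	"fmt"
)

// Simplified stand-ins for base.ProviderInterface and base.ChatInterface.
type providerInterface interface{ GetBaseURL() string }

type chatInterface interface {
	providerInterface
	ChatAction(model string) (string, error)
}

// getProvider stands in for providers.GetProvider(channelType, c); nil
// means the channel type is unknown.
func getProvider(channelType int) providerInterface { return nil }

func relayChat(channelType int, model string) (string, error) {
	provider := getProvider(channelType)
	// A failed type assertion doubles as the capability check.
	chatProvider, ok := provider.(chatInterface)
	if !ok {
		return "", errors.New("channel does not support chat")
	}
	return chatProvider.ChatAction(model)
}

func main() {
	_, err := relayChat(99, "demo-model")
	fmt.Println(err)
}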
@ -1,4 +1,4 @@
package providers
package tencent

import (
	"crypto/hmac"
@ -6,6 +6,7 @@ import (
	"encoding/base64"
	"errors"
	"fmt"
	"one-api/providers/base"
	"sort"
	"strconv"
	"strings"
@ -14,18 +15,13 @@ import (
)

type TencentProvider struct {
	ProviderConfig
	base.BaseProvider
}

type TencentError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// Create a TencentProvider
func CreateTencentProvider(c *gin.Context) *TencentProvider {
	return &TencentProvider{
		ProviderConfig: ProviderConfig{
		BaseProvider: base.BaseProvider{
			BaseURL:         "https://hunyuan.cloud.tencent.com",
			ChatCompletions: "/hyllm/v1/chat/completions",
			Context:         c,
@ -36,12 +32,7 @@ func CreateTencentProvider(c *gin.Context) *TencentProvider {
// Get request headers
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
	headers = make(map[string]string)
	p.CommonRequestHeaders(headers)
	headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
	headers["Accept"] = p.Context.Request.Header.Get("Accept")
	if headers["Content-Type"] == "" {
		headers["Content-Type"] = "application/json"
	}

	return headers
}
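The shared plumbing referenced above (base.BaseProvider, CommonRequestHeaders, and the base.StopFinishReason used further down) is not part of this hunk, so the following is only an inferred sketch of providers/base, reconstructed from how the vendor files use it:

// Speculative sketch of providers/base, which this commit references
// but does not show; fields and method body are inferred from usage.
package base

import "github.com/gin-gonic/gin"

// StopFinishReason must be a variable: providers take its address.
var StopFinishReason = "stop"

type BaseProvider struct {
	BaseURL         string
	ChatCompletions string
	Context         *gin.Context
}

// CommonRequestHeaders presumably centralizes the Content-Type/Accept
// fallback that each provider used to set by hand (see the removed
// lines above); the exact header set is an assumption.
func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
	headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
	headers["Accept"] = p.Context.Request.Header.Get("Accept")
	if headers["Content-Type"] == "" {
		headers["Content-Type"] = "application/json"
	}
}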
@ -1,4 +1,4 @@
package providers
package tencent

import (
	"bufio"
@ -7,64 +7,12 @@ import (
	"io"
	"net/http"
	"one-api/common"
	"one-api/providers/base"
	"one-api/types"
	"strings"
)

type TencentMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type TencentChatRequest struct {
	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
	SecretId string `json:"secret_id"` // SecretId from the console
	// Timestamp is the current UNIX timestamp in seconds, recording when the API request is made.
	// e.g. 1529223702; a value too far from the current time causes a signature-expired error
	Timestamp int64 `json:"timestamp"`
	// Expired is the validity period of the signature, a UNIX-epoch value in seconds;
	// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days
	Expired int64  `json:"expired"`
	QueryID string `json:"query_id"` // request Id, used for troubleshooting
	// Temperature: higher values make the output more random, lower values more focused and deterministic
	// default 1.0, range [0.0, 2.0]; avoid setting it unless necessary, unreasonable values degrade results
	// set only one of this and top_p; do not change both at once
	Temperature float64 `json:"temperature"`
	// TopP controls the diversity of the output; larger values produce more diverse text
	// default 1.0, range [0.0, 1.0]; avoid setting it unless necessary, unreasonable values degrade results
	// set only one of this and temperature; do not change both at once
	TopP float64 `json:"top_p"`
	// Stream 0: synchronous, 1: streaming (default; protocol: SSE)
	// synchronous requests time out after 60s; prefer streaming for long content
	Stream int `json:"stream"`
	// Messages holds the conversation, at most 40 entries, ordered oldest to newest
	// input content may total at most 3000 tokens
	Messages []TencentMessage `json:"messages"`
}

type TencentUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

type TencentResponseChoices struct {
	FinishReason string         `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null when streaming; output is capped at 1024 tokens
	Delta        TencentMessage `json:"delta,omitempty"`         // content in streaming mode, null in synchronous mode; output is capped at 1024 tokens
}

type TencentChatResponse struct {
	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
	Created string                   `json:"created,omitempty"` // unix timestamp as a string
	Id      string                   `json:"id,omitempty"`      // session id
	Usage   *types.Usage             `json:"usage,omitempty"`   // token counts
	Error   TencentError             `json:"error,omitempty"`   // error info; note: this field may be null
	Note    string                   `json:"note,omitempty"`    // note
	ReqID   string                   `json:"req_id,omitempty"`  // unique request Id returned with every request, for reporting issues
}

func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
	if TencentResponse.Error.Code != 0 {
		return &types.OpenAIErrorWithStatusCode{
			OpenAIError: types.OpenAIError{
@ -130,7 +78,7 @@ func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionReques
	}
}

func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	requestBody := p.getChatRequestBody(request)
	sign := p.getTencentSign(*requestBody)
	if sign == "" {
@ -152,8 +100,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ

	if request.Stream {
		var responseText string
		openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
		errWithCode, responseText = p.sendStreamRequest(req)
		if openAIErrorWithStatusCode != nil {
		if errWithCode != nil {
			return
		}

@ -163,8 +111,8 @@ func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequ

	} else {
		tencentResponse := &TencentChatResponse{}
		openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
		errWithCode = p.SendRequest(req, tencentResponse)
		if openAIErrorWithStatusCode != nil {
		if errWithCode != nil {
			return
		}

@ -184,7 +132,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
	var choice types.ChatCompletionStreamChoice
	choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
	if TencentResponse.Choices[0].FinishReason == "stop" {
		choice.FinishReason = &stopFinishReason
		choice.FinishReason = &base.StopFinishReason
	}
	response.Choices = append(response.Choices, choice)
}
@ -199,7 +147,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
	}

	if common.IsFailureStatusCode(resp) {
		return p.handleErrorResp(resp), ""
		return p.HandleErrorResp(resp), ""
	}

	defer resp.Body.Close()
@ -234,7 +182,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
		}
		stopChan <- true
	}()
	setEventStreamHeaders(p.Context)
	common.SetEventStreamHeaders(p.Context)
	p.Context.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
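getTencentSign is not shown in this diff; as a rough sketch, such signers typically end in an HMAC-SHA1-then-base64 step like the one below. The exact string-to-sign layout is Tencent's and is not reproduced here; crypto/sha1 is assumed alongside the crypto/hmac and encoding/base64 imports visible in this package:

// Assumed shape of the final signing step only, not the repository's
// implementation.
func hmacSha1Base64(secretKey, stringToSign string) string {
	mac := hmac.New(sha1.New, []byte(secretKey))
	mac.Write([]byte(stringToSign))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}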
61 providers/tencent/type.go Normal file
@ -0,0 +1,61 @@
package tencent

import "one-api/types"

type TencentError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

type TencentMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type TencentChatRequest struct {
	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
	SecretId string `json:"secret_id"` // SecretId from the console
	// Timestamp is the current UNIX timestamp in seconds, recording when the API request is made.
	// e.g. 1529223702; a value too far from the current time causes a signature-expired error
	Timestamp int64 `json:"timestamp"`
	// Expired is the validity period of the signature, a UNIX-epoch value in seconds;
	// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days
	Expired int64  `json:"expired"`
	QueryID string `json:"query_id"` // request Id, used for troubleshooting
	// Temperature: higher values make the output more random, lower values more focused and deterministic
	// default 1.0, range [0.0, 2.0]; avoid setting it unless necessary, unreasonable values degrade results
	// set only one of this and top_p; do not change both at once
	Temperature float64 `json:"temperature"`
	// TopP controls the diversity of the output; larger values produce more diverse text
	// default 1.0, range [0.0, 1.0]; avoid setting it unless necessary, unreasonable values degrade results
	// set only one of this and temperature; do not change both at once
	TopP float64 `json:"top_p"`
	// Stream 0: synchronous, 1: streaming (default; protocol: SSE)
	// synchronous requests time out after 60s; prefer streaming for long content
	Stream int `json:"stream"`
	// Messages holds the conversation, at most 40 entries, ordered oldest to newest
	// input content may total at most 3000 tokens
	Messages []TencentMessage `json:"messages"`
}

type TencentUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

type TencentResponseChoices struct {
	FinishReason string         `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null when streaming; output is capped at 1024 tokens
	Delta        TencentMessage `json:"delta,omitempty"`         // content in streaming mode, null in synchronous mode; output is capped at 1024 tokens
}

type TencentChatResponse struct {
	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
	Created string                   `json:"created,omitempty"` // unix timestamp as a string
	Id      string                   `json:"id,omitempty"`      // session id
	Usage   *types.Usage             `json:"usage,omitempty"`   // token counts
	Error   TencentError             `json:"error,omitempty"`   // error info; note: this field may be null
	Note    string                   `json:"note,omitempty"`    // note
	ReqID   string                   `json:"req_id,omitempty"`  // unique request Id returned with every request, for reporting issues
}
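To make the Timestamp/Expired contract above concrete, a small sketch of filling the signing-related fields; the APPID, SecretId, and the 300-second window are placeholders (time must be imported):

now := time.Now().Unix()
request := TencentChatRequest{
	AppId:     1234567890,       // placeholder APPID
	SecretId:  "your-secret-id", // placeholder SecretId
	Timestamp: now,
	Expired:   now + 300, // must exceed Timestamp, by less than 90 days
	Stream:    1,         // 1 = SSE streaming, 0 = synchronous
	Messages: []TencentMessage{
		{Role: "user", Content: "hello"},
	},
}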
@ -1,4 +1,4 @@
package providers
package xunfei

import (
	"crypto/hmac"
@ -7,6 +7,7 @@ import (
	"fmt"
	"net/url"
	"one-api/common"
	"one-api/providers/base"
	"strings"
	"time"

@ -15,7 +16,7 @@ import (

// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiProvider struct {
	ProviderConfig
	base.BaseProvider
	domain string
	apiId  string
}
@ -23,7 +24,7 @@ type XunfeiProvider struct {
// Create an XunfeiProvider
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
	return &XunfeiProvider{
		ProviderConfig: ProviderConfig{
		BaseProvider: base.BaseProvider{
			BaseURL:         "wss://spark-api.xf-yun.com",
			ChatCompletions: "",
			Context:         c,
@ -1,73 +1,18 @@
package providers
package xunfei

import (
	"encoding/json"
	"io"
	"net/http"
	"one-api/common"
	"one-api/providers/base"
	"one-api/types"
	"time"

	"github.com/gorilla/websocket"
)

type XunfeiMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type XunfeiChatRequest struct {
	Header struct {
		AppId string `json:"app_id"`
	} `json:"header"`
	Parameter struct {
		Chat struct {
			Domain      string  `json:"domain,omitempty"`
			Temperature float64 `json:"temperature,omitempty"`
			TopK        int     `json:"top_k,omitempty"`
			MaxTokens   int     `json:"max_tokens,omitempty"`
			Auditing    bool    `json:"auditing,omitempty"`
		} `json:"chat"`
	} `json:"parameter"`
	Payload struct {
		Message struct {
			Text []XunfeiMessage `json:"text"`
		} `json:"message"`
	} `json:"payload"`
}

type XunfeiChatResponseTextItem struct {
	Content string `json:"content"`
	Role    string `json:"role"`
	Index   int    `json:"index"`
}

type XunfeiChatResponse struct {
	Header struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Sid     string `json:"sid"`
		Status  int    `json:"status"`
	} `json:"header"`
	Payload struct {
		Choices struct {
			Status int                          `json:"status"`
			Seq    int                          `json:"seq"`
			Text   []XunfeiChatResponseTextItem `json:"text"`
		} `json:"choices"`
		Usage struct {
			//Text struct {
			//	QuestionTokens   string `json:"question_tokens"`
			//	PromptTokens     string `json:"prompt_tokens"`
			//	CompletionTokens string `json:"completion_tokens"`
			//	TotalTokens      string `json:"total_tokens"`
			//} `json:"text"`
			Text types.Usage `json:"text"`
		} `json:"usage"`
	} `json:"payload"`
}

func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)

	if request.Stream {
@ -77,7 +22,7 @@ func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionReque
	}
}

func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	usage = &types.Usage{}
	dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
	if err != nil {
@ -113,13 +58,13 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU
	return usage, nil
}

func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	usage = &types.Usage{}
	dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
	if err != nil {
		return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
	}
	setEventStreamHeaders(p.Context)
	common.SetEventStreamHeaders(p.Context)
	p.Context.Stream(func(w io.Writer) bool {
		select {
		case xunfeiResponse := <-dataChan:
@ -185,7 +130,7 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty
			Role:    "assistant",
			Content: response.Payload.Choices.Text[0].Content,
		},
		FinishReason: stopFinishReason,
		FinishReason: base.StopFinishReason,
	}
	fullTextResponse := types.ChatCompletionResponse{
		Object: "chat.completion",
@ -251,7 +196,7 @@ func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatR
	var choice types.ChatCompletionStreamChoice
	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
	if xunfeiResponse.Payload.Choices.Status == 2 {
		choice.FinishReason = &stopFinishReason
		choice.FinishReason = &base.StopFinishReason
	}
	response := types.ChatCompletionStreamResponse{
		Object: "chat.completion.chunk",
59 providers/xunfei/type.go Normal file
@ -0,0 +1,59 @@
package xunfei

import "one-api/types"

type XunfeiMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type XunfeiChatRequest struct {
	Header struct {
		AppId string `json:"app_id"`
	} `json:"header"`
	Parameter struct {
		Chat struct {
			Domain      string  `json:"domain,omitempty"`
			Temperature float64 `json:"temperature,omitempty"`
			TopK        int     `json:"top_k,omitempty"`
			MaxTokens   int     `json:"max_tokens,omitempty"`
			Auditing    bool    `json:"auditing,omitempty"`
		} `json:"chat"`
	} `json:"parameter"`
	Payload struct {
		Message struct {
			Text []XunfeiMessage `json:"text"`
		} `json:"message"`
	} `json:"payload"`
}

type XunfeiChatResponseTextItem struct {
	Content string `json:"content"`
	Role    string `json:"role"`
	Index   int    `json:"index"`
}

type XunfeiChatResponse struct {
	Header struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Sid     string `json:"sid"`
		Status  int    `json:"status"`
	} `json:"header"`
	Payload struct {
		Choices struct {
			Status int                          `json:"status"`
			Seq    int                          `json:"seq"`
			Text   []XunfeiChatResponseTextItem `json:"text"`
		} `json:"choices"`
		Usage struct {
			//Text struct {
			//	QuestionTokens   string `json:"question_tokens"`
			//	PromptTokens     string `json:"prompt_tokens"`
			//	CompletionTokens string `json:"completion_tokens"`
			//	TotalTokens      string `json:"total_tokens"`
			//} `json:"text"`
			Text types.Usage `json:"text"`
		} `json:"usage"`
	} `json:"payload"`
}
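xunfeiMakeRequest itself is outside this diff; a hedged sketch of the consumer contract the two send* functions rely on, assuming dataChan carries XunfeiChatResponse values and stopChan signals completion:

// Sketch of the read loop implied by sendRequest/sendStreamRequest;
// the channel element types are inferred, not shown in this commit.
loop:
	for {
		select {
		case xunfeiResponse := <-dataChan:
			// Payload.Choices.Status == 2 marks the final frame
			if xunfeiResponse.Payload.Choices.Status == 2 {
				break loop
			}
		case <-stopChan:
			break loop
		}
	}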
@ -1,8 +1,9 @@
package providers
package zhipu

import (
	"fmt"
	"one-api/common"
	"one-api/providers/base"
	"strings"
	"sync"
	"time"
@ -15,18 +16,13 @@ var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600

type ZhipuProvider struct {
	ProviderConfig
	base.BaseProvider
}

type zhipuTokenData struct {
	Token      string
	ExpiryTime time.Time
}

// Create a ZhipuProvider
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
	return &ZhipuProvider{
		ProviderConfig: ProviderConfig{
		BaseProvider: base.BaseProvider{
			BaseURL:         "https://open.bigmodel.cn",
			ChatCompletions: "/api/paas/v3/model-api",
			Context:         c,
@ -37,13 +33,8 @@ func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
// Get request headers
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
	headers = make(map[string]string)
	p.CommonRequestHeaders(headers)
	headers["Authorization"] = p.getZhipuToken()
	headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
	headers["Accept"] = p.Context.Request.Header.Get("Accept")
	if headers["Content-Type"] == "" {
		headers["Content-Type"] = "application/json"
	}

	return headers
}
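getZhipuToken is not part of this hunk; given zhipuTokens, expSeconds, and the zhipuTokenData type (now in type.go below), its caching presumably works along these lines, where generateZhipuToken is a hypothetical helper:

// Speculative sketch, not the repository's implementation.
func cachedZhipuToken(apikey string) string {
	if v, ok := zhipuTokens.Load(apikey); ok {
		data := v.(zhipuTokenData)
		if time.Now().Before(data.ExpiryTime) {
			return data.Token // reuse the cached JWT until it expires
		}
	}
	token := generateZhipuToken(apikey) // hypothetical minting helper
	zhipuTokens.Store(apikey, zhipuTokenData{
		Token:      token,
		ExpiryTime: time.Now().Add(time.Duration(expSeconds) * time.Second),
	})
	return token
}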
@ -1,4 +1,4 @@
package providers
package zhipu

import (
	"bufio"
@ -6,46 +6,12 @@ import (
	"io"
	"net/http"
	"one-api/common"
	"one-api/providers/base"
	"one-api/types"
	"strings"
)

type ZhipuMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type ZhipuRequest struct {
	Prompt      []ZhipuMessage `json:"prompt"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	RequestId   string         `json:"request_id,omitempty"`
	Incremental bool           `json:"incremental,omitempty"`
}

type ZhipuResponseData struct {
	TaskId      string         `json:"task_id"`
	RequestId   string         `json:"request_id"`
	TaskStatus  string         `json:"task_status"`
	Choices     []ZhipuMessage `json:"choices"`
	types.Usage `json:"usage"`
}

type ZhipuResponse struct {
	Code    int               `json:"code"`
	Msg     string            `json:"msg"`
	Success bool              `json:"success"`
	Data    ZhipuResponseData `json:"data"`
}

type ZhipuStreamMetaResponse struct {
	RequestId   string `json:"request_id"`
	TaskId      string `json:"task_id"`
	TaskStatus  string `json:"task_status"`
	types.Usage `json:"usage"`
}

func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
	if !zhipuResponse.Success {
		return &types.OpenAIErrorWithStatusCode{
			OpenAIError: types.OpenAIError{
@ -110,7 +76,7 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest)
	}
}

func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	requestBody := p.getChatRequestBody(request)
	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
	headers := p.GetRequestHeaders()
@ -128,15 +94,15 @@ func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionReques
	}

	if request.Stream {
		openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
		errWithCode, usage = p.sendStreamRequest(req)
		if openAIErrorWithStatusCode != nil {
		if errWithCode != nil {
			return
		}

	} else {
		zhipuResponse := &ZhipuResponse{}
		openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse)
		errWithCode = p.SendRequest(req, zhipuResponse)
		if openAIErrorWithStatusCode != nil {
		if errWithCode != nil {
			return
		}

@ -161,7 +127,7 @@ func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
	var choice types.ChatCompletionStreamChoice
	choice.Delta.Content = ""
	choice.FinishReason = &stopFinishReason
	choice.FinishReason = &base.StopFinishReason
	response := types.ChatCompletionStreamResponse{
		ID:     zhipuResponse.RequestId,
		Object: "chat.completion.chunk",
@ -180,7 +146,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
	}

	if common.IsFailureStatusCode(resp) {
		return p.handleErrorResp(resp), nil
		return p.HandleErrorResp(resp), nil
	}

	defer resp.Body.Close()
@ -222,7 +188,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
		}
		stopChan <- true
	}()
	setEventStreamHeaders(p.Context)
	common.SetEventStreamHeaders(p.Context)
	p.Context.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
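Every streaming path in this commit ends in the same common.SetEventStreamHeaders plus Context.Stream pump. A hedged sketch of the full loop; the raw "data:" lines and "[DONE]" terminator are assumptions about the wire format, and fmt would need to be imported:

// Sketch of the shared SSE pump; payload formatting is an assumption,
// not the code elided by this hunk.
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
	select {
	case data := <-dataChan:
		fmt.Fprintf(w, "data: %s\n\n", data) // assumes a pre-serialized JSON chunk
		return true // keep the connection open for the next event
	case <-stopChan:
		fmt.Fprint(w, "data: [DONE]\n\n")
		return false // returning false ends the stream
	}
})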
46 providers/zhipu/type.go Normal file
@ -0,0 +1,46 @@
package zhipu

import (
	"one-api/types"
	"time"
)

type ZhipuMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type ZhipuRequest struct {
	Prompt      []ZhipuMessage `json:"prompt"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	RequestId   string         `json:"request_id,omitempty"`
	Incremental bool           `json:"incremental,omitempty"`
}

type ZhipuResponseData struct {
	TaskId      string         `json:"task_id"`
	RequestId   string         `json:"request_id"`
	TaskStatus  string         `json:"task_status"`
	Choices     []ZhipuMessage `json:"choices"`
	types.Usage `json:"usage"`
}

type ZhipuResponse struct {
	Code    int               `json:"code"`
	Msg     string            `json:"msg"`
	Success bool              `json:"success"`
	Data    ZhipuResponseData `json:"data"`
}

type ZhipuStreamMetaResponse struct {
	RequestId   string `json:"request_id"`
	TaskId      string `json:"task_id"`
	TaskStatus  string `json:"task_status"`
	types.Usage `json:"usage"`
}

type zhipuTokenData struct {
	Token      string
	ExpiryTime time.Time
}