From c97c8a0f65d5b4519a52d5467c11b0c563e486af Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sat, 2 Dec 2023 19:54:21 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=BD=99=E9=A2=9D=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- common/init.go | 7 +++++ controller/channel-billing.go | 16 +++++++++-- controller/channel-test.go | 46 +++++++++++++------------------ go.mod | 1 + go.sum | 2 ++ providers/aigc2d/balance.go | 30 ++++++++++++++++++++ providers/aigc2d/base.go | 20 ++++++++++++++ providers/api2d/balance.go | 5 ++-- providers/api2gpt/balance.go | 30 ++++++++++++++++++++ providers/api2gpt/base.go | 20 ++++++++++++++ providers/base/interface.go | 2 +- providers/{api2d => base}/type.go | 8 +++--- providers/providers.go | 5 ++++ 14 files changed, 158 insertions(+), 37 deletions(-) create mode 100644 providers/aigc2d/balance.go create mode 100644 providers/aigc2d/base.go create mode 100644 providers/api2gpt/balance.go create mode 100644 providers/api2gpt/base.go rename providers/{api2d => base}/type.go (65%) diff --git a/.gitignore b/.gitignore index 4eaf9868..af989b72 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ build *.db-journal logs data -tmp/ \ No newline at end of file +tmp/ +.env \ No newline at end of file diff --git a/common/init.go b/common/init.go index 1e9c85ce..1e8b4dcd 100644 --- a/common/init.go +++ b/common/init.go @@ -6,6 +6,8 @@ import ( "log" "os" "path/filepath" + + "github.com/joho/godotenv" ) var ( @@ -23,6 +25,11 @@ func printHelp() { } func init() { + // 加载.env文件 + err := godotenv.Load() + if err != nil { + SysLog("failed to load .env file: " + err.Error()) + } flag.Parse() if *PrintVersion { diff --git a/controller/channel-billing.go b/controller/channel-billing.go index a9807417..da80a26d 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -3,6 +3,7 @@ package controller import ( "errors" "net/http" + "net/http/httptest" "one-api/common" "one-api/model" "one-api/providers" @@ -46,7 +47,18 @@ type OpenAIUsageResponse struct { } func updateChannelBalance(channel *model.Channel) (float64, error) { - provider := providers.GetProvider(channel.Type, nil) + req, err := http.NewRequest("POST", "/balance", nil) + if err != nil { + return 0, err + } + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = req + + setChannelToContext(c, channel) + req.Header.Set("Content-Type", "application/json") + + provider := providers.GetProvider(channel.Type, c) if provider == nil { return 0, errors.New("provider not found") } @@ -56,7 +68,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return 0, errors.New("provider not implemented") } - return balanceProvider.BalanceAction(channel) + return balanceProvider.Balance(channel) } diff --git a/controller/channel-test.go b/controller/channel-test.go index 62ada94e..1f3d5b71 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -30,32 +30,26 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e c.Request = req setChannelToContext(c, channel) - - switch channel.Type { - case common.ChannelTypePaLM: - request.Model = "PaLM-2" - case common.ChannelTypeAnthropic: - request.Model = "claude-2" - case common.ChannelTypeBaidu: - request.Model = "ERNIE-Bot" - case common.ChannelTypeZhipu: - request.Model = "chatglm_lite" - case common.ChannelTypeAli: - request.Model = "qwen-turbo" - case common.ChannelType360: - request.Model = "360GPT_S2_V9" - case common.ChannelTypeXunfei: - request.Model = "SparkDesk" - c.Set("api_version", channel.Other) - case common.ChannelTypeTencent: - request.Model = "hunyuan" - case common.ChannelTypeAzure: - request.Model = "gpt-3.5-turbo" - c.Set("api_version", channel.Other) - default: - request.Model = "gpt-3.5-turbo" + // 创建映射 + channelTypeToModel := map[int]string{ + common.ChannelTypePaLM: "PaLM-2", + common.ChannelTypeAnthropic: "claude-2", + common.ChannelTypeBaidu: "ERNIE-Bot", + common.ChannelTypeZhipu: "chatglm_lite", + common.ChannelTypeAli: "qwen-turbo", + common.ChannelType360: "360GPT_S2_V9", + common.ChannelTypeXunfei: "SparkDesk", + common.ChannelTypeTencent: "hunyuan", + common.ChannelTypeAzure: "gpt-3.5-turbo", } + // 从映射中获取模型名称 + model, ok := channelTypeToModel[channel.Type] + if !ok { + model = "gpt-3.5-turbo" // 默认值 + } + request.Model = model + provider := providers.GetProvider(channel.Type, c) if provider == nil { return errors.New("channel not implemented"), nil @@ -65,18 +59,16 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e return errors.New("channel not implemented"), nil } - isModelMapped := false modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { return err, nil } if modelMap != nil && modelMap[request.Model] != "" { request.Model = modelMap[request.Model] - isModelMapped = true } promptTokens := common.CountTokenMessages(request.Messages, request.Model) - _, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens) + _, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens) if openAIErrorWithStatusCode != nil { return nil, &openAIErrorWithStatusCode.OpenAIError } diff --git a/go.mod b/go.mod index 10b78d68..632e2a6a 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/joho/godotenv v1.5.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect diff --git a/go.sum b/go.sum index 4865bcaa..9d3407fe 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= diff --git a/providers/aigc2d/balance.go b/providers/aigc2d/balance.go new file mode 100644 index 00000000..cb1613bb --- /dev/null +++ b/providers/aigc2d/balance.go @@ -0,0 +1,30 @@ +package aigc2d + +import ( + "errors" + "one-api/common" + "one-api/model" + "one-api/providers/base" +) + +func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 发送请求 + var response base.BalanceResponse + _, errWithCode := common.SendRequest(req, &response, false) + if errWithCode != nil { + return 0, errors.New(errWithCode.OpenAIError.Message) + } + + channel.UpdateBalance(response.TotalAvailable) + + return response.TotalAvailable, nil +} diff --git a/providers/aigc2d/base.go b/providers/aigc2d/base.go new file mode 100644 index 00000000..b4654b3b --- /dev/null +++ b/providers/aigc2d/base.go @@ -0,0 +1,20 @@ +package aigc2d + +import ( + "one-api/providers/base" + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type Aigc2dProviderFactory struct{} + +func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { + return &Aigc2dProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"), + } +} + +type Aigc2dProvider struct { + *openai.OpenAIProvider +} diff --git a/providers/api2d/balance.go b/providers/api2d/balance.go index 080a728f..520c04c3 100644 --- a/providers/api2d/balance.go +++ b/providers/api2d/balance.go @@ -4,10 +4,11 @@ import ( "errors" "one-api/common" "one-api/model" + "one-api/providers/base" ) func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { - fullRequestURL := "https://api.aigc2d.com/dashboard/billing/credit_grants" + fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") headers := p.GetRequestHeaders() client := common.NewClient() @@ -17,7 +18,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { } // 发送请求 - var response APGC2DGPTUsageResponse + var response base.BalanceResponse _, errWithCode := common.SendRequest(req, &response, false) if errWithCode != nil { return 0, errors.New(errWithCode.OpenAIError.Message) diff --git a/providers/api2gpt/balance.go b/providers/api2gpt/balance.go new file mode 100644 index 00000000..a8872b40 --- /dev/null +++ b/providers/api2gpt/balance.go @@ -0,0 +1,30 @@ +package api2gpt + +import ( + "errors" + "one-api/common" + "one-api/model" + "one-api/providers/base" +) + +func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 发送请求 + var response base.BalanceResponse + _, errWithCode := common.SendRequest(req, &response, false) + if errWithCode != nil { + return 0, errors.New(errWithCode.OpenAIError.Message) + } + + channel.UpdateBalance(response.TotalAvailable) + + return response.TotalRemaining, nil +} diff --git a/providers/api2gpt/base.go b/providers/api2gpt/base.go new file mode 100644 index 00000000..c502108a --- /dev/null +++ b/providers/api2gpt/base.go @@ -0,0 +1,20 @@ +package api2gpt + +import ( + "one-api/providers/base" + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type Api2gptProviderFactory struct{} + +func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface { + return &Api2gptProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"), + } +} + +type Api2gptProvider struct { + *openai.OpenAIProvider +} diff --git a/providers/base/interface.go b/providers/base/interface.go index 714bd5c1..5c05b404 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -75,7 +75,7 @@ type ImageVariationsInterface interface { // 余额接口 type BalanceInterface interface { - BalanceAction(channel *model.Channel) (float64, error) + Balance(channel *model.Channel) (float64, error) } type ProviderResponseHandler interface { diff --git a/providers/api2d/type.go b/providers/base/type.go similarity index 65% rename from providers/api2d/type.go rename to providers/base/type.go index 988bb7a6..239261d5 100644 --- a/providers/api2d/type.go +++ b/providers/base/type.go @@ -1,9 +1,9 @@ -package api2d +package base -type APGC2DGPTUsageResponse struct { - //Grants interface{} `json:"grants"` +type BalanceResponse struct { Object string `json:"object"` - TotalAvailable float64 `json:"total_available"` TotalGranted float64 `json:"total_granted"` TotalUsed float64 `json:"total_used"` + TotalRemaining float64 `json:"total_remaining"` + TotalAvailable float64 `json:"total_available"` } diff --git a/providers/providers.go b/providers/providers.go index 74a3c385..ceccaeaf 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -2,9 +2,11 @@ package providers import ( "one-api/common" + "one-api/providers/aigc2d" "one-api/providers/aiproxy" "one-api/providers/ali" "one-api/providers/api2d" + "one-api/providers/api2gpt" "one-api/providers/azure" "one-api/providers/baidu" "one-api/providers/base" @@ -43,6 +45,9 @@ func init() { providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{} providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{} providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{} + providerFactories[common.ChannelTypeAIGC2D] = aigc2d.Aigc2dProviderFactory{} + providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{} + } // 获取供应商