🐛 fix: 修复余额的问题

This commit is contained in:
MartialBE 2023-12-02 19:54:21 +08:00
parent 58fc40a744
commit c97c8a0f65
14 changed files with 158 additions and 37 deletions

3
.gitignore vendored
View File

@ -7,4 +7,5 @@ build
*.db-journal
logs
data
tmp/
tmp/
.env

View File

@ -6,6 +6,8 @@ import (
"log"
"os"
"path/filepath"
"github.com/joho/godotenv"
)
var (
@ -23,6 +25,11 @@ func printHelp() {
}
func init() {
// 加载.env文件
err := godotenv.Load()
if err != nil {
SysLog("failed to load .env file: " + err.Error())
}
flag.Parse()
if *PrintVersion {

View File

@ -3,6 +3,7 @@ package controller
import (
"errors"
"net/http"
"net/http/httptest"
"one-api/common"
"one-api/model"
"one-api/providers"
@ -46,7 +47,18 @@ type OpenAIUsageResponse struct {
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
provider := providers.GetProvider(channel.Type, nil)
req, err := http.NewRequest("POST", "/balance", nil)
if err != nil {
return 0, err
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
setChannelToContext(c, channel)
req.Header.Set("Content-Type", "application/json")
provider := providers.GetProvider(channel.Type, c)
if provider == nil {
return 0, errors.New("provider not found")
}
@ -56,7 +68,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("provider not implemented")
}
return balanceProvider.BalanceAction(channel)
return balanceProvider.Balance(channel)
}

View File

@ -30,32 +30,26 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
c.Request = req
setChannelToContext(c, channel)
switch channel.Type {
case common.ChannelTypePaLM:
request.Model = "PaLM-2"
case common.ChannelTypeAnthropic:
request.Model = "claude-2"
case common.ChannelTypeBaidu:
request.Model = "ERNIE-Bot"
case common.ChannelTypeZhipu:
request.Model = "chatglm_lite"
case common.ChannelTypeAli:
request.Model = "qwen-turbo"
case common.ChannelType360:
request.Model = "360GPT_S2_V9"
case common.ChannelTypeXunfei:
request.Model = "SparkDesk"
c.Set("api_version", channel.Other)
case common.ChannelTypeTencent:
request.Model = "hunyuan"
case common.ChannelTypeAzure:
request.Model = "gpt-3.5-turbo"
c.Set("api_version", channel.Other)
default:
request.Model = "gpt-3.5-turbo"
// 创建映射
channelTypeToModel := map[int]string{
common.ChannelTypePaLM: "PaLM-2",
common.ChannelTypeAnthropic: "claude-2",
common.ChannelTypeBaidu: "ERNIE-Bot",
common.ChannelTypeZhipu: "chatglm_lite",
common.ChannelTypeAli: "qwen-turbo",
common.ChannelType360: "360GPT_S2_V9",
common.ChannelTypeXunfei: "SparkDesk",
common.ChannelTypeTencent: "hunyuan",
common.ChannelTypeAzure: "gpt-3.5-turbo",
}
// 从映射中获取模型名称
model, ok := channelTypeToModel[channel.Type]
if !ok {
model = "gpt-3.5-turbo" // 默认值
}
request.Model = model
provider := providers.GetProvider(channel.Type, c)
if provider == nil {
return errors.New("channel not implemented"), nil
@ -65,18 +59,16 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
return errors.New("channel not implemented"), nil
}
isModelMapped := false
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
return err, nil
}
if modelMap != nil && modelMap[request.Model] != "" {
request.Model = modelMap[request.Model]
isModelMapped = true
}
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
if openAIErrorWithStatusCode != nil {
return nil, &openAIErrorWithStatusCode.OpenAIError
}

1
go.mod
View File

@ -42,6 +42,7 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect

2
go.sum
View File

@ -80,6 +80,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=

View File

@ -0,0 +1,30 @@
package aigc2d
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}

20
providers/aigc2d/base.go Normal file
View File

@ -0,0 +1,20 @@
package aigc2d
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Aigc2dProviderFactory struct{}
func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Aigc2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"),
}
}
type Aigc2dProvider struct {
*openai.OpenAIProvider
}

View File

@ -4,10 +4,11 @@ import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := "https://api.aigc2d.com/dashboard/billing/credit_grants"
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
@ -17,7 +18,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
}
// 发送请求
var response APGC2DGPTUsageResponse
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)

View File

@ -0,0 +1,30 @@
package api2gpt
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalRemaining, nil
}

20
providers/api2gpt/base.go Normal file
View File

@ -0,0 +1,20 @@
package api2gpt
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Api2gptProviderFactory struct{}
func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Api2gptProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"),
}
}
type Api2gptProvider struct {
*openai.OpenAIProvider
}

View File

@ -75,7 +75,7 @@ type ImageVariationsInterface interface {
// 余额接口
type BalanceInterface interface {
BalanceAction(channel *model.Channel) (float64, error)
Balance(channel *model.Channel) (float64, error)
}
type ProviderResponseHandler interface {

View File

@ -1,9 +1,9 @@
package api2d
package base
type APGC2DGPTUsageResponse struct {
//Grants interface{} `json:"grants"`
type BalanceResponse struct {
Object string `json:"object"`
TotalAvailable float64 `json:"total_available"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
TotalAvailable float64 `json:"total_available"`
}

View File

@ -2,9 +2,11 @@ package providers
import (
"one-api/common"
"one-api/providers/aigc2d"
"one-api/providers/aiproxy"
"one-api/providers/ali"
"one-api/providers/api2d"
"one-api/providers/api2gpt"
"one-api/providers/azure"
"one-api/providers/baidu"
"one-api/providers/base"
@ -43,6 +45,9 @@ func init() {
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
providerFactories[common.ChannelTypeAIGC2D] = aigc2d.Aigc2dProviderFactory{}
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
}
// 获取供应商