🐛 fix: 修复余额的问题
This commit is contained in:
parent
58fc40a744
commit
c97c8a0f65
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,3 +8,4 @@ build
|
||||
logs
|
||||
data
|
||||
tmp/
|
||||
.env
|
@ -6,6 +6,8 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -23,6 +25,11 @@ func printHelp() {
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 加载.env文件
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
SysLog("failed to load .env file: " + err.Error())
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
|
@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
@ -46,7 +47,18 @@ type OpenAIUsageResponse struct {
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
provider := providers.GetProvider(channel.Type, nil)
|
||||
req, err := http.NewRequest("POST", "/balance", nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
setChannelToContext(c, channel)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
provider := providers.GetProvider(channel.Type, c)
|
||||
if provider == nil {
|
||||
return 0, errors.New("provider not found")
|
||||
}
|
||||
@ -56,7 +68,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, errors.New("provider not implemented")
|
||||
}
|
||||
|
||||
return balanceProvider.BalanceAction(channel)
|
||||
return balanceProvider.Balance(channel)
|
||||
|
||||
}
|
||||
|
||||
|
@ -30,32 +30,26 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
c.Request = req
|
||||
|
||||
setChannelToContext(c, channel)
|
||||
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
request.Model = "PaLM-2"
|
||||
case common.ChannelTypeAnthropic:
|
||||
request.Model = "claude-2"
|
||||
case common.ChannelTypeBaidu:
|
||||
request.Model = "ERNIE-Bot"
|
||||
case common.ChannelTypeZhipu:
|
||||
request.Model = "chatglm_lite"
|
||||
case common.ChannelTypeAli:
|
||||
request.Model = "qwen-turbo"
|
||||
case common.ChannelType360:
|
||||
request.Model = "360GPT_S2_V9"
|
||||
case common.ChannelTypeXunfei:
|
||||
request.Model = "SparkDesk"
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeTencent:
|
||||
request.Model = "hunyuan"
|
||||
case common.ChannelTypeAzure:
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
c.Set("api_version", channel.Other)
|
||||
default:
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
// 创建映射
|
||||
channelTypeToModel := map[int]string{
|
||||
common.ChannelTypePaLM: "PaLM-2",
|
||||
common.ChannelTypeAnthropic: "claude-2",
|
||||
common.ChannelTypeBaidu: "ERNIE-Bot",
|
||||
common.ChannelTypeZhipu: "chatglm_lite",
|
||||
common.ChannelTypeAli: "qwen-turbo",
|
||||
common.ChannelType360: "360GPT_S2_V9",
|
||||
common.ChannelTypeXunfei: "SparkDesk",
|
||||
common.ChannelTypeTencent: "hunyuan",
|
||||
common.ChannelTypeAzure: "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
// 从映射中获取模型名称
|
||||
model, ok := channelTypeToModel[channel.Type]
|
||||
if !ok {
|
||||
model = "gpt-3.5-turbo" // 默认值
|
||||
}
|
||||
request.Model = model
|
||||
|
||||
provider := providers.GetProvider(channel.Type, c)
|
||||
if provider == nil {
|
||||
return errors.New("channel not implemented"), nil
|
||||
@ -65,18 +59,16 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
return errors.New("channel not implemented"), nil
|
||||
}
|
||||
|
||||
isModelMapped := false
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if modelMap != nil && modelMap[request.Model] != "" {
|
||||
request.Model = modelMap[request.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return nil, &openAIErrorWithStatusCode.OpenAIError
|
||||
}
|
||||
|
1
go.mod
1
go.mod
@ -42,6 +42,7 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -80,6 +80,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
|
30
providers/aigc2d/balance.go
Normal file
30
providers/aigc2d/balance.go
Normal file
@ -0,0 +1,30 @@
|
||||
package aigc2d
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
)
|
||||
|
||||
func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response base.BalanceResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
|
||||
return response.TotalAvailable, nil
|
||||
}
|
20
providers/aigc2d/base.go
Normal file
20
providers/aigc2d/base.go
Normal file
@ -0,0 +1,20 @@
|
||||
package aigc2d
|
||||
|
||||
import (
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Aigc2dProviderFactory struct{}
|
||||
|
||||
func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
return &Aigc2dProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"),
|
||||
}
|
||||
}
|
||||
|
||||
type Aigc2dProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
@ -4,10 +4,11 @@ import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
)
|
||||
|
||||
func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
@ -17,7 +18,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response APGC2DGPTUsageResponse
|
||||
var response base.BalanceResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
|
30
providers/api2gpt/balance.go
Normal file
30
providers/api2gpt/balance.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api2gpt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
)
|
||||
|
||||
func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response base.BalanceResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
|
||||
return response.TotalRemaining, nil
|
||||
}
|
20
providers/api2gpt/base.go
Normal file
20
providers/api2gpt/base.go
Normal file
@ -0,0 +1,20 @@
|
||||
package api2gpt
|
||||
|
||||
import (
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Api2gptProviderFactory struct{}
|
||||
|
||||
func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
return &Api2gptProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"),
|
||||
}
|
||||
}
|
||||
|
||||
type Api2gptProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
@ -75,7 +75,7 @@ type ImageVariationsInterface interface {
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
BalanceAction(channel *model.Channel) (float64, error)
|
||||
Balance(channel *model.Channel) (float64, error)
|
||||
}
|
||||
|
||||
type ProviderResponseHandler interface {
|
||||
|
@ -1,9 +1,9 @@
|
||||
package api2d
|
||||
package base
|
||||
|
||||
type APGC2DGPTUsageResponse struct {
|
||||
//Grants interface{} `json:"grants"`
|
||||
type BalanceResponse struct {
|
||||
Object string `json:"object"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalRemaining float64 `json:"total_remaining"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
}
|
@ -2,9 +2,11 @@ package providers
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/providers/aigc2d"
|
||||
"one-api/providers/aiproxy"
|
||||
"one-api/providers/ali"
|
||||
"one-api/providers/api2d"
|
||||
"one-api/providers/api2gpt"
|
||||
"one-api/providers/azure"
|
||||
"one-api/providers/baidu"
|
||||
"one-api/providers/base"
|
||||
@ -43,6 +45,9 @@ func init() {
|
||||
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
|
||||
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
|
||||
providerFactories[common.ChannelTypeAIGC2D] = aigc2d.Aigc2dProviderFactory{}
|
||||
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
|
Loading…
Reference in New Issue
Block a user