账单作为可选开关

This commit is contained in:
yu.deng 2024-06-13 09:27:35 +08:00
parent ed717211aa
commit f719606948
15 changed files with 484 additions and 201 deletions

3
.gitignore vendored
View File

@ -8,4 +8,5 @@ build
logs
data
/web/node_modules
cmd.md
cmd.md
vendor/*

View File

@ -1,13 +1,14 @@
package config
import (
"github.com/songquanpeng/one-api/common/env"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/songquanpeng/one-api/common/env"
"github.com/google/uuid"
)
@ -135,6 +136,7 @@ var (
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableBilling = env.Bool("ENABLE_BILLING", true)
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)

View File

@ -23,24 +23,49 @@ import (
// https://platform.openai.com/docs/api-reference/chat
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
// Options holds the feature switches used to build a RelayController.
type Options struct {
	// Debug is not read in the visible part of this file.
	// NOTE(review): confirm whether it is meant to replace config.DebugEnabled.
	Debug bool
	// EnableMonitor makes NewRelayController attach a monitor.MonitorInstance.
	EnableMonitor bool
	// EnableBilling is forwarded to controller.NewRelayInstance.
	EnableBilling bool
}
// RelayController combines the relay helpers and the optional channel monitor
// behind one receiver by embedding both interfaces.
type RelayController struct {
	// opts is the configuration the controller was built with.
	opts Options
	// RelayInstance provides RelayTextHelper, RelayImageHelper and RelayAudioHelper.
	controller.RelayInstance
	// MonitorInstance is left nil unless Options.EnableMonitor is set; the
	// controller's methods check it for nil before using it.
	monitor.MonitorInstance
}
// NewRelayController builds a RelayController from opts.
//
// The relay instance is always created, with opts.EnableBilling forwarded to
// it. The monitor is only attached when opts.EnableMonitor is true; otherwise
// the embedded MonitorInstance stays nil.
func NewRelayController(opts Options) *RelayController {
	relayInstance := controller.NewRelayInstance(controller.Options{
		EnableBilling: opts.EnableBilling,
	})
	rc := &RelayController{
		opts:          opts,
		RelayInstance: relayInstance,
	}
	if !opts.EnableMonitor {
		return rc
	}
	rc.MonitorInstance = monitor.NewMonitorInstance()
	return rc
}
func (ctrl *RelayController) relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
err = ctrl.RelayImageHelper(c, relayMode)
case relaymode.AudioSpeech:
fallthrough
case relaymode.AudioTranslation:
fallthrough
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
err = ctrl.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
err = ctrl.RelayTextHelper(c)
}
return err
}
func Relay(c *gin.Context) {
func (ctrl *RelayController) Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled {
@ -48,17 +73,19 @@ func Relay(c *gin.Context) {
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
channelId := c.GetInt(ctxkey.ChannelId)
userId := c.GetInt("id")
bizErr := relayHelper(c, relayMode)
bizErr := ctrl.relayHelper(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
if ctrl.MonitorInstance != nil {
ctrl.Emit(channelId, true)
}
return
}
lastFailedChannelId := channelId
channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
userId := c.GetInt(ctxkey.Id)
go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
@ -77,15 +104,19 @@ func Relay(c *gin.Context) {
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
if err != nil {
logger.Errorf(ctx, "GetRequestBody failed: %+v", err)
break
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relayHelper(c, relayMode)
bizErr = ctrl.relayHelper(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
@ -117,13 +148,16 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
return true
}
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
func (ctrl *RelayController) processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
if ctrl.MonitorInstance == nil {
return
}
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
if ctrl.ShouldDisableChannel(&err.Error, err.StatusCode) {
ctrl.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
ctrl.Emit(channelId, false)
}
}

49
go.sum
View File

@ -1,40 +1,25 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k=
github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg=
github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
@ -51,26 +36,16 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs=
github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps=
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk=
github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE=
github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE=
github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4=
github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y=
github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI=
github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI=
github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k=
github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74=
github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4=
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
@ -78,8 +53,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
@ -87,8 +60,6 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
@ -147,14 +118,10 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@ -181,37 +148,23 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@ -228,8 +181,6 @@ gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4c
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8=
gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=

32
monitor/monitor.go Normal file
View File

@ -0,0 +1,32 @@
package monitor
import "github.com/songquanpeng/one-api/relay/model"
// MonitorInstance is the channel-monitoring behaviour a relay controller can
// optionally depend on.
type MonitorInstance interface {
	// Emit reports the outcome of one relay attempt on the given channel;
	// success is true when the attempt succeeded.
	Emit(channelId int, success bool)
	// ShouldDisableChannel reports whether the given upstream error and HTTP
	// status code should lead to the channel being disabled.
	ShouldDisableChannel(err *model.Error, statusCode int) bool
	// DisableChannel disables the channel, passing along the reason.
	DisableChannel(channelId int, channelName string, reason string)
}
// defaultMonitor is the stateless MonitorInstance implementation returned by
// NewMonitorInstance.
type defaultMonitor struct {
}
// NewMonitorInstance returns a MonitorInstance backed by defaultMonitor.
func NewMonitorInstance() MonitorInstance {
	return &defaultMonitor{}
}
// Emit sends channelId on metricSuccessChan when success is true and on
// metricFailChan otherwise.
//
// Both sends are plain channel sends.
// NOTE(review): the channels and their consumers are declared elsewhere in
// the package — confirm they are always drained so this call cannot block.
func (m *defaultMonitor) Emit(channelId int, success bool) {
	if !success {
		metricFailChan <- channelId
		return
	}
	metricSuccessChan <- channelId
}
// ShouldDisableChannel delegates to the package-level ShouldDisableChannel
// function and returns its result unchanged.
func (m *defaultMonitor) ShouldDisableChannel(err *model.Error, statusCode int) bool {
	return ShouldDisableChannel(err, statusCode)
}
// DisableChannel delegates to the package-level DisableChannel function with
// the same arguments.
func (m *defaultMonitor) DisableChannel(channelId int, channelName string, reason string) {
	DisableChannel(channelId, channelName, reason)
}

View File

@ -0,0 +1,160 @@
package billing
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// Bookkeeper is the accounting logic that handles users' quota consumption.
// Pre-deduction check logic:
//
// Before a request starts, the quota it will need is estimated according to the request type, and the remaining quota of the requesting user and token decides whether the request can be served.
// If the remaining quota cannot cover the request, an error is returned directly. If it can, the estimated quota is consumed in advance and the request proceeds. If the remaining quota far exceeds the request, nothing needs to be pre-consumed.
// Because the estimated quota is not the quota actually consumed, the user and token quota must be updated after the request finishes (refund or extra charge) based on the actual consumption.
type Bookkeeper interface {
	// ModelRatio returns the billing ratio of the model.
	ModelRatio(model string) float64
	// GroupRation returns the billing ratio of the group.
	GroupRation(group string) float64
	// ModelCompletionRatio returns the completion billing ratio of the model.
	ModelCompletionRatio(model string) float64
	// Consume deducts the user's / token's quota according to the consumption record.
	Consume(ctx context.Context, consumeLog *ConsumeLog)
	// PreConsumeQuota pre-consumes quota when the user's quota is not plentiful. On success it returns the
	// pre-consumed quota, on failure it returns an error. A returned quota of 0 means the user has ample quota.
	PreConsumeQuota(ctx context.Context, preConsumedQuota int64, userId, tokenId int) (int64, *relaymodel.ErrorWithStatusCode)
	// RefundQuota gives back the pre-consumed quota; this usually runs when the call to the upstream API fails.
	RefundQuota(ctx context.Context, preConsumedQuota int64, tokenId int)
	// Checks whether the user has enough quota.
	// UserHasEnoughQuota(ctx context.Context, userID int, quota int64) bool
	// Checks whether the user has far more quota than required; if so, no quota needs to be pre-consumed.
	// UserHasMuchMoreQuota(ctx context.Context, userID int, quota int64) bool
}
// defaultBookkeeper is the stateless Bookkeeper implementation returned by
// NewBookkeeper; its methods call into the model and billingratio packages.
type defaultBookkeeper struct {
}
// NewBookkeeper returns a Bookkeeper backed by defaultBookkeeper.
func NewBookkeeper() Bookkeeper {
	return &defaultBookkeeper{}
}
// ModelRatio returns billingratio.GetModelRatio(model).
func (b *defaultBookkeeper) ModelRatio(model string) float64 {
	return billingratio.GetModelRatio(model)
}
// GroupRation returns billingratio.GetGroupRatio(group).
// NOTE(review): the name looks like a typo for "GroupRatio", but it is part of
// the Bookkeeper interface, so renaming requires updating all callers.
func (b *defaultBookkeeper) GroupRation(group string) float64 {
	return billingratio.GetGroupRatio(group)
}
// ModelCompletionRatio returns billingratio.GetCompletionRatio(model).
func (b *defaultBookkeeper) ModelCompletionRatio(model string) float64 {
	return billingratio.GetCompletionRatio(model)
}
// Ratio returns the combined billing multiplier for a request: the model
// ratio of model multiplied by the group ratio of group.
func (b *defaultBookkeeper) Ratio(group, model string) float64 {
	return billingratio.GetModelRatio(model) * billingratio.GetGroupRatio(group)
}
// ConsumeLog is the consumption record entity passed to Bookkeeper.Consume.
type ConsumeLog struct {
	UserId           int
	ChannelId        int
	PromptTokens     int
	CompletionTokens int
	ModelName        string
	TokenId          int
	TokenName        string
	// Quota is the total quota charged for the request.
	Quota int64
	// Content is the free-form text stored with the consume log entry.
	Content string
	// PreConsumedQuota is the amount already deducted before the request;
	// Consume only settles the difference Quota - PreConsumedQuota on the token.
	PreConsumedQuota int64
}
// UserHasEnoughQuota reports whether the user's cached quota is at least
// quota. It returns false when the cached quota cannot be read.
func (b *defaultBookkeeper) UserHasEnoughQuota(ctx context.Context, userID int, quota int64) bool {
	available, lookupErr := model.CacheGetUserQuota(ctx, userID)
	return lookupErr == nil && available >= quota
}
// UserHasMuchMoreQuota reports whether the user's cached quota is strictly
// greater than 100 times quota. It returns false when the cached quota cannot
// be read.
func (b *defaultBookkeeper) UserHasMuchMoreQuota(ctx context.Context, userID int, quota int64) bool {
	available, lookupErr := model.CacheGetUserQuota(ctx, userID)
	return lookupErr == nil && available > 100*quota
}
// Consume settles a finished request described by consumeLog.
//
// It charges the token the difference between the final quota and what was
// pre-consumed, refreshes the user's quota cache, adds the full quota to the
// user's and the channel's usage counters, and writes a consume log entry.
// Failures of the first two steps are logged and do not stop the later steps.
func (b *defaultBookkeeper) Consume(ctx context.Context, consumeLog *ConsumeLog) {
	rec := consumeLog

	// Settle the token: only the part not already pre-consumed is charged
	// (the delta is negative when too much was pre-consumed).
	outstanding := rec.Quota - rec.PreConsumedQuota
	if tokenErr := model.PostConsumeTokenQuota(rec.TokenId, outstanding); tokenErr != nil {
		logger.SysError("error consuming token remain quota: " + tokenErr.Error())
	}
	if cacheErr := model.CacheUpdateUserQuota(ctx, rec.UserId); cacheErr != nil {
		logger.SysError("error update user quota cache: " + cacheErr.Error())
	}

	// Usage counters for the user and the channel take the full quota.
	model.UpdateUserUsedQuotaAndRequestCount(rec.UserId, rec.Quota)
	model.UpdateChannelUsedQuota(rec.ChannelId, rec.Quota)

	// Persist the consumption record.
	model.RecordConsumeLog(
		ctx,
		rec.UserId,
		rec.ChannelId,
		rec.PromptTokens,
		rec.CompletionTokens,
		rec.ModelName,
		rec.TokenName,
		rec.Quota,
		rec.Content,
	)
}
// PreConsumeQuota reserves preConsumedQuota for a request before it is sent
// upstream.
//
// Steps, in order:
//  1. read the user's cached quota; fail with 500 if that is not possible;
//  2. fail with 403 if the quota is smaller than the requested amount;
//  3. decrease the user's cached quota by the requested amount;
//  4. if the user holds more than 100 times the requested amount, skip the
//     token pre-consumption and return 0;
//  5. otherwise, when the amount is positive, pre-consume it on the token.
//
// On success it returns the amount that was actually pre-consumed on the
// token (possibly 0) and a nil error. On failure it returns the requested
// amount together with the error.
func (b *defaultBookkeeper) PreConsumeQuota(ctx context.Context, preConsumedQuota int64, userId, tokenId int) (int64, *relaymodel.ErrorWithStatusCode) {
	balance, err := model.CacheGetUserQuota(ctx, userId)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if balance-preConsumedQuota < 0 {
		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	if err = model.CacheDecreaseUserQuota(userId, preConsumedQuota); err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}

	// A user with plenty of quota is trusted: nothing is reserved on the token.
	if balance > 100*preConsumedQuota {
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, balance))
		return 0, nil
	}

	if preConsumedQuota <= 0 {
		return preConsumedQuota, nil
	}
	if tokenErr := model.PreConsumeTokenQuota(tokenId, preConsumedQuota); tokenErr != nil {
		return preConsumedQuota, openai.ErrorWrapper(tokenErr, "pre_consume_token_quota_failed", http.StatusForbidden)
	}
	return preConsumedQuota, nil
}
// RefundQuota gives preConsumedQuota back to the token by posting the negated
// amount. It does nothing when preConsumedQuota is 0, and it only logs the
// error if the refund fails.
func (b *defaultBookkeeper) RefundQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
	if preConsumedQuota == 0 {
		return
	}
	if refundErr := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota); refundErr != nil {
		logger.Error(ctx, "error return pre-consumed quota: "+refundErr.Error())
	}
}

View File

@ -3,10 +3,12 @@ package billing
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
// ReturnPreConsumedQuota 在请求失败的时候,退回预消费的配额
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {

View File

@ -7,36 +7,33 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
func (rl *defaultRelay) RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
meta := meta.GetByContext(c)
audioModel := "whisper-1"
tokenId := c.GetInt(ctxkey.TokenId)
channelType := c.GetInt(ctxkey.Channel)
channelId := c.GetInt(ctxkey.ChannelId)
// channelId := c.GetInt(ctxkey.ChannelId)
userId := c.GetInt(ctxkey.Id)
group := c.GetString(ctxkey.Group)
tokenName := c.GetString(ctxkey.TokenName)
// tokenName := c.GetString(ctxkey.TokenName)
var ttsRequest openai.TextToSpeechRequest
if relayMode == relaymode.AudioSpeech {
@ -53,58 +50,45 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := billingratio.GetModelRatio(audioModel)
groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio
var quota int64
var preConsumedQuota int64
switch relayMode {
case relaymode.AudioSpeech:
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(ctx, userId)
if err != nil {
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
var (
modelRatio float64
groupRatio float64
ratio float64
quota int64
preConsumeQuota int64
preConsumedQuota int64
bizErr *relaymodel.ErrorWithStatusCode
)
if rl.Bookkeeper != nil {
modelRatio = rl.ModelRatio(audioModel)
groupRatio = rl.GroupRation(group)
ratio = modelRatio * groupRatio
}
// Check if user quota is enough
if userQuota-preConsumedQuota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
switch relayMode {
// speech 类型,消费的配额直接根据输入的文本长度计算
case relaymode.AudioSpeech:
preConsumeQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumeQuota
// 其他类型,假设消费的配额是预设的配额的 ratio 倍
default:
preConsumeQuota = int64(float64(config.PreConsumedQuota) * ratio)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
if rl.Bookkeeper != nil {
preConsumedQuota, bizErr = rl.PreConsumeQuota(c, preConsumeQuota, userId, tokenId)
if bizErr != nil {
return bizErr
}
}
succeed := false
defer func() {
if succeed {
return
}
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func(ctx context.Context) {
go func() {
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
if rl.Bookkeeper != nil {
rl.Bookkeeper.RefundQuota(c.Request.Context(), preConsumedQuota, tokenId)
}
}()
@ -140,8 +124,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
requestBody := &bytes.Buffer{}
_, err = io.Copy(requestBody, c.Request.Body)
if err != nil {
if _, err := io.Copy(requestBody, c.Request.Body); err != nil {
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
@ -220,9 +203,28 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return RelayErrorHandler(resp)
}
succeed = true
quotaDelta := quota - preConsumedQuota
// quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
// go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
// post-consume quota
if rl.Bookkeeper != nil {
// go postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
consumeLog := &billing.ConsumeLog{
UserId: meta.UserId,
ChannelId: meta.ChannelId,
ModelName: audioModel,
TokenName: c.GetString(ctxkey.TokenName),
TokenId: meta.TokenId,
Quota: quota,
Content: logContent,
PromptTokens: int(preConsumeQuota),
CompletionTokens: 0,
PreConsumedQuota: preConsumedQuota,
}
rl.Bookkeeper.Consume(c, consumeLog)
}
}(c.Request.Context())
for k, v := range resp.Header {

View File

@ -0,0 +1,30 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/billing"
"github.com/songquanpeng/one-api/relay/model"
)
// Options configures the relay instance built by NewRelayInstance.
type Options struct {
	// EnableBilling makes NewRelayInstance attach a billing.Bookkeeper.
	EnableBilling bool
}
// RelayInstance is the interface for relay controller.
// Each helper handles one request and returns a non-nil error with an HTTP
// status code when the relay fails.
type RelayInstance interface {
	// RelayTextHelper relays a text request.
	RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode
	// RelayImageHelper relays an image request for the given relay mode.
	RelayImageHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode
	// RelayAudioHelper relays an audio request for the given relay mode.
	RelayAudioHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode
}
// defaultRelay is the RelayInstance implementation returned by
// NewRelayInstance.
type defaultRelay struct {
	// Bookkeeper is nil when billing is disabled; NewRelayInstance only sets
	// it when Options.EnableBilling is true.
	billing.Bookkeeper
}
// NewRelayInstance builds the default RelayInstance. A bookkeeper is attached
// only when opts.EnableBilling is true; otherwise the embedded Bookkeeper is
// left nil.
func NewRelayInstance(opts Options) RelayInstance {
	var keeper billing.Bookkeeper
	if opts.EnableBilling {
		keeper = billing.NewBookkeeper()
	}
	return &defaultRelay{Bookkeeper: keeper}
}

View File

@ -4,6 +4,10 @@ import (
"context"
"errors"
"fmt"
"math"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
@ -16,9 +20,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
"strings"
)
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
@ -208,10 +209,7 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
if resp == nil {
if meta.ChannelType == channeltype.AwsClaude {
return false
}
return true
return meta.ChannelType != channeltype.Azure
}
if resp.StatusCode != http.StatusOK {
return true

View File

@ -4,20 +4,20 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func isWithinRange(element string, value int) bool {
@ -29,33 +29,40 @@ func isWithinRange(element string, value int) bool {
return value >= min && value <= max
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
func (rl *defaultRelay) RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
meta := meta.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
logger.Errorf(c, "getImageRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
// map model name
var isModelMapped bool
var (
isModelMapped bool
preConsumeQuota int64
preConsumedQuota int64
imageCostRatio float64
bizErr *relaymodel.ErrorWithStatusCode
)
meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model
// model validation
bizErr := validateImageRequest(imageRequest, meta)
bizErr = validateImageRequest(imageRequest, meta)
if bizErr != nil {
return bizErr
}
imageCostRatio, err := getImageCostRatio(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
if rl.Bookkeeper != nil {
imageCostRatio, err = getImageCostRatio(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
}
imageModel := imageRequest.Model
originModel := imageRequest.Model
// Convert the original image model
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
c.Set("response_format", imageRequest.ResponseFormat)
@ -94,21 +101,28 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billingratio.GetModelRatio(imageModel)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
if userQuota-quota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
if rl.Bookkeeper != nil {
modelRatio := rl.ModelRatio(originModel)
groupRatio := rl.GroupRation(meta.Group)
ratio := modelRatio * groupRatio
preConsumeQuota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
preConsumedQuota, bizErr = rl.PreConsumeQuota(c, preConsumeQuota, meta.UserId, meta.TokenId)
if bizErr != nil {
logger.Warnf(c, "preConsumeQuota failed: %+v", *bizErr)
return bizErr
}
}
refund := func() {
if rl.Bookkeeper != nil && preConsumedQuota > 0 {
rl.RefundQuota(c, preConsumedQuota, meta.TokenId)
}
}
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
logger.Errorf(c, "DoRequest failed: %s", err.Error())
refund()
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
@ -116,29 +130,37 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if resp != nil && resp.StatusCode != http.StatusOK {
return
}
if rl.Bookkeeper == nil {
return
}
modelRatio := rl.ModelRatio(originModel)
groupRatio := rl.GroupRation(meta.Group)
ratio := modelRatio * groupRatio
consumedQuota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
err := model.PostConsumeTokenQuota(meta.TokenId, quota)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
if consumedQuota != 0 {
tokenName := c.GetString(ctxkey.TokenName)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
channelId := c.GetInt(ctxkey.ChannelId)
model.UpdateChannelUsedQuota(channelId, quota)
consumeLog := &billing.ConsumeLog{
UserId: meta.UserId,
ChannelId: meta.ChannelId,
ModelName: imageRequest.Model,
TokenName: tokenName,
TokenId: meta.TokenId,
Quota: consumedQuota,
Content: logContent,
PromptTokens: 0,
CompletionTokens: 0,
PreConsumedQuota: preConsumedQuota,
}
rl.Bookkeeper.Consume(c, consumeLog)
}
}(c.Request.Context())
}(c)
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
logger.Errorf(c, "respErr is not nil: %+v", respErr)
return respErr
}

View File

@ -4,27 +4,27 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
func (rl *defaultRelay) RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta := meta.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
logger.Errorf(c, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
meta.IsStream = textRequest.Stream
@ -35,18 +35,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
meta.PromptTokens = promptTokens
preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
if bizErr != nil {
logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
return bizErr
var (
preConsumedQuota int64
modelRatio float64
groupRatio float64
ratio float64
)
if rl.Bookkeeper != nil {
modelRatio = rl.ModelRatio(textRequest.Model)
groupRatio = rl.GroupRation(meta.Group)
ratio = modelRatio * groupRatio
// pre-consume quota
meta.PromptTokens = getPromptTokens(textRequest, meta.Mode)
preConsumeQuota := getPreConsumedQuota(textRequest, meta.PromptTokens, ratio)
consumedQuota, bizErr := rl.PreConsumeQuota(c, preConsumeQuota, meta.UserId, meta.TokenId)
if bizErr != nil {
logger.Warnf(c, "preConsumeQuota failed: %+v", *bizErr)
return bizErr
}
preConsumedQuota = consumedQuota
}
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
@ -76,29 +84,51 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
logger.Debugf(c, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
}
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
logger.Errorf(c, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
refund := func() {
if rl.Bookkeeper != nil && preConsumedQuota > 0 {
rl.RefundQuota(c, preConsumedQuota, meta.TokenId)
}
}
if isErrorHappened(meta, resp) {
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
refund()
return RelayErrorHandler(resp)
}
// do response
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
logger.Errorf(c, "respErr is not nil: %+v", respErr)
refund()
return respErr
}
// post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
if rl.Bookkeeper != nil {
// go postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
completionRatio := rl.ModelCompletionRatio(textRequest.Model)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
consumeLog := &billing.ConsumeLog{
UserId: meta.UserId,
ChannelId: meta.ChannelId,
ModelName: textRequest.Model,
TokenName: c.GetString(ctxkey.TokenName),
TokenId: meta.TokenId,
Quota: usage.Quota(completionRatio, ratio),
Content: logContent,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
PreConsumedQuota: preConsumedQuota,
}
rl.Bookkeeper.Consume(c, consumeLog)
}
return nil
}

View File

@ -1,12 +1,13 @@
package meta
import (
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/relaymode"
"strings"
)
type Meta struct {

View File

@ -1,11 +1,21 @@
package model
import "math"
// Usage holds the token counts for a single request, serialized with the
// prompt_tokens / completion_tokens / total_tokens JSON keys.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Quota converts the token counts into a quota amount.
//
// The result is ceil((PromptTokens + CompletionTokens*completionRatio) * finalRatio).
// When finalRatio is non-zero the result is never below 1, so a billable
// request is always charged at least one unit. A nil receiver carries no
// usage information and yields 0 instead of panicking.
func (u *Usage) Quota(completionRatio, finalRatio float64) int64 {
	if u == nil {
		return 0
	}
	weightedTokens := float64(u.PromptTokens) + float64(u.CompletionTokens)*completionRatio
	quota := int64(math.Ceil(weightedTokens * finalRatio))
	if finalRatio != 0 && quota <= 0 {
		// Enforce the minimum charge for any request with a non-zero ratio.
		quota = 1
	}
	return quota
}
type Error struct {
Message string `json:"message"`
Type string `json:"type"`

View File

@ -1,6 +1,7 @@
package router
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"
@ -17,19 +18,26 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")
opt := controller.Options{
EnableMonitor: config.EnableMetric,
EnableBilling: config.EnableBilling,
Debug: config.DebugEnabled,
}
ctrl := controller.NewRelayController(opt)
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay)
relayV1Router.POST("/images/generations", controller.Relay)
relayV1Router.POST("/completions", ctrl.Relay)
relayV1Router.POST("/chat/completions", ctrl.Relay)
relayV1Router.POST("/edits", ctrl.Relay)
relayV1Router.POST("/images/generations", ctrl.Relay)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.POST("/embeddings", ctrl.Relay)
relayV1Router.POST("/engines/:model/embeddings", ctrl.Relay)
relayV1Router.POST("/audio/transcriptions", ctrl.Relay)
relayV1Router.POST("/audio/translations", ctrl.Relay)
relayV1Router.POST("/audio/speech", ctrl.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
@ -41,7 +49,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
relayV1Router.POST("/moderations", ctrl.Relay)
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)