diff --git a/common/constants.go b/common/constants.go
index 7a1694c5..643c2b34 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -15,6 +15,8 @@ var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
+var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
+var DisplayInCurrencyEnabled = false
var UsingSQLite = false
diff --git a/common/logger.go b/common/logger.go
index 0b8b2cfb..3658dbdb 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -42,3 +42,11 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
+
+func LogQuota(quota int) string {
+ if DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d 点额度", quota)
+ }
+}
diff --git a/common/redis.go b/common/redis.go
index 56db2b40..8b34a083 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -37,3 +37,18 @@ func ParseRedisOption() *redis.Options {
}
return opt
}
+
+func RedisSet(key string, value string, expiration time.Duration) error {
+ ctx := context.Background()
+ return RDB.Set(ctx, key, value, expiration).Err()
+}
+
+func RedisGet(key string) (string, error) {
+ ctx := context.Background()
+ return RDB.Get(ctx, key).Result()
+}
+
+func RedisDel(key string) error {
+ ctx := context.Background()
+ return RDB.Del(ctx, key).Err()
+}
diff --git a/controller/billing.go b/controller/billing.go
index 2f0d90fe..18a34dc9 100644
--- a/controller/billing.go
+++ b/controller/billing.go
@@ -2,6 +2,7 @@ package controller
import (
"github.com/gin-gonic/gin"
+ "one-api/common"
"one-api/model"
)
@@ -18,23 +19,38 @@ func GetSubscription(c *gin.Context) {
})
return
}
+ amount := float64(quota)
+ if common.DisplayInCurrencyEnabled {
+ amount /= common.QuotaPerUnit
+ }
subscription := OpenAISubscriptionResponse{
Object: "billing_subscription",
HasPaymentMethod: true,
- SoftLimitUSD: float64(quota),
- HardLimitUSD: float64(quota),
- SystemHardLimitUSD: float64(quota),
+ SoftLimitUSD: amount,
+ HardLimitUSD: amount,
+ SystemHardLimitUSD: amount,
}
c.JSON(200, subscription)
return
}
func GetUsage(c *gin.Context) {
- //userId := c.GetInt("id")
- // TODO: get usage from database
+ userId := c.GetInt("id")
+ quota, err := model.GetUserUsedQuota(userId)
+ if err != nil {
+ openAIError := OpenAIError{
+ Message: err.Error(),
+ Type: "one_api_error",
+ }
+ c.JSON(200, gin.H{
+ "error": openAIError,
+ })
+ return
+ }
+ amount := float64(quota)
usage := OpenAIUsageResponse{
Object: "list",
- TotalUsage: 0,
+ TotalUsage: amount,
}
c.JSON(200, usage)
return
diff --git a/controller/misc.go b/controller/misc.go
index 10f1f99e..755ccbd4 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -14,21 +14,23 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": common.ServerAddress,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "chat_link": common.ChatLink,
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": common.ServerAddress,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "chat_link": common.ChatLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
},
})
return
diff --git a/controller/user.go b/controller/user.go
index c861651a..3060edd4 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -384,7 +384,7 @@ func UpdateUser(c *gin.Context) {
return
}
if originUser.Quota != updatedUser.Quota {
- model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %d 点修改为 %d 点", originUser.Quota, updatedUser.Quota))
+ model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
diff --git a/main.go b/main.go
index c8656c7a..fb6dc1b5 100644
--- a/main.go
+++ b/main.go
@@ -47,12 +47,18 @@ func main() {
// Initialize options
model.InitOptionMap()
+ if common.RedisEnabled {
+ model.InitChannelCache()
+ }
if os.Getenv("SYNC_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
if err != nil {
common.FatalLog(err)
}
go model.SyncOptions(frequency)
+ if common.RedisEnabled {
+ go model.SyncChannelCache(frequency)
+ }
}
// Initialize HTTP server
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 14658758..5a07683c 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -17,7 +17,7 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
- userGroup, _ := model.GetUserGroup(userId)
+ userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
channelId, ok := c.Get("channelId")
@@ -76,7 +76,7 @@ func Distribute() func(c *gin.Context) {
if strings.HasPrefix(modelRequest.Model, "gpt-35-turbo") {
modelRequest.Model = strings.Replace(modelRequest.Model, "gpt-35-turbo", "gpt-3.5-turbo", 1)
}
- channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
c.JSON(200, gin.H{
"error": gin.H{
diff --git a/model/cache.go b/model/cache.go
new file mode 100644
index 00000000..4ca60496
--- /dev/null
+++ b/model/cache.go
@@ -0,0 +1,99 @@
+package model
+
+import (
+ "encoding/json"
+ "fmt"
+ "one-api/common"
+ "sync"
+ "time"
+)
+
+const (
+ TokenCacheSeconds = 60 * 60
+ UserId2GroupCacheSeconds = 60 * 60
+)
+
+func CacheGetTokenByKey(key string) (*Token, error) {
+ var token Token
+ if !common.RedisEnabled {
+ err := DB.Where("`key` = ?", key).First(&token).Error
+ return &token, err
+ }
+ tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
+ if err != nil {
+ err := DB.Where("`key` = ?", key).First(&token).Error
+ if err != nil {
+ return nil, err
+ }
+ jsonBytes, err := json.Marshal(token)
+ if err != nil {
+ return nil, err
+ }
+ err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
+ if err != nil {
+ common.SysError("Redis set token error: " + err.Error())
+ }
+ }
+ err = json.Unmarshal([]byte(tokenObjectString), &token)
+ return &token, err
+}
+
+func CacheGetUserGroup(id int) (group string, err error) {
+ if !common.RedisEnabled {
+ return GetUserGroup(id)
+ }
+ group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
+ if err != nil {
+ group, err = GetUserGroup(id)
+ if err != nil {
+ return "", err
+ }
+ err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
+ if err != nil {
+ common.SysError("Redis set user group error: " + err.Error())
+ }
+ }
+ return group, err
+}
+
+var channelId2channel map[int]*Channel
+var channelSyncLock sync.RWMutex
+var group2model2channels map[string]map[string][]*Channel
+
+func InitChannelCache() {
+ channelSyncLock.Lock()
+ defer channelSyncLock.Unlock()
+ channelId2channel = make(map[int]*Channel)
+ var channels []*Channel
+ DB.Find(&channels)
+ for _, channel := range channels {
+ channelId2channel[channel.Id] = channel
+ }
+ var abilities []*Ability
+ DB.Find(&abilities)
+ groups := make(map[string]bool)
+ for _, ability := range abilities {
+ groups[ability.Group] = true
+ }
+ group2model2channels = make(map[string]map[string][]*Channel)
+ for group := range groups {
+ group2model2channels[group] = make(map[string][]*Channel)
+ // TODO: implement this
+ }
+}
+
+func SyncChannelCache(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Second)
+ common.SysLog("Syncing channels from database")
+ InitChannelCache()
+ }
+}
+
+func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+ if !common.RedisEnabled {
+ return GetRandomSatisfiedChannel(group, model)
+ }
+ // TODO: implement this
+ return nil, nil
+}
diff --git a/model/option.go b/model/option.go
index b53b172d..12b1ce70 100644
--- a/model/option.go
+++ b/model/option.go
@@ -35,6 +35,7 @@ func InitOptionMap() {
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
+ common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
@@ -64,6 +65,7 @@ func InitOptionMap() {
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
+ common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -140,6 +142,8 @@ func updateOptionMap(key string, value string) (err error) {
common.AutomaticDisableChannelEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
+ case "DisplayInCurrencyEnabled":
+ common.DisplayInCurrencyEnabled = boolValue
}
}
switch key {
@@ -196,6 +200,8 @@ func updateOptionMap(key string, value string) (err error) {
common.ChatLink = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ case "QuotaPerUnit":
+ common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
}
return err
}
diff --git a/model/redemption.go b/model/redemption.go
index 155e3cfd..d60eb649 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -66,7 +66,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
common.SysError("更新兑换码状态失败:" + err.Error())
}
- RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %d 点额度", redemption.Quota))
+ RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
}()
return redemption.Quota, nil
}
diff --git a/model/token.go b/model/token.go
index 64e52dcd..8744f582 100644
--- a/model/token.go
+++ b/model/token.go
@@ -36,8 +36,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供 token")
}
- token = &Token{}
- err = DB.Where("`key` = ?", key).First(token).Error
+ token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该 token 状态不可用")
diff --git a/model/user.go b/model/user.go
index 5205662e..922b410f 100644
--- a/model/user.go
+++ b/model/user.go
@@ -93,16 +93,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error
}
if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %d 点额度", common.QuotaForNewUser))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %d 点额度", common.QuotaForInvitee))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %d 点额度", common.QuotaForInviter))
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
}
}
return nil
@@ -256,6 +256,11 @@ func GetUserQuota(id int) (quota int, err error) {
return quota, err
}
+func GetUserUsedQuota(id int) (quota int, err error) {
+ err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
+ return quota, err
+}
+
func GetUserEmail(id int) (email string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
return email, err
diff --git a/web/src/App.js b/web/src/App.js
index 61477141..c967ce2c 100644
--- a/web/src/App.js
+++ b/web/src/App.js
@@ -48,6 +48,8 @@ function App() {
localStorage.setItem('system_name', data.system_name);
localStorage.setItem('logo', data.logo);
localStorage.setItem('footer_html', data.footer_html);
+ localStorage.setItem('quota_per_unit', data.quota_per_unit);
+ localStorage.setItem('display_in_currency', data.display_in_currency);
if (data.chat_link) {
localStorage.setItem('chat_link', data.chat_link);
} else {
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 3f3a4ab0..7f9b1ea5 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -13,9 +13,11 @@ const OperationSetting = () => {
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
+ QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: '',
ChannelDisableThreshold: 0,
- LogConsumeEnabled: ''
+ LogConsumeEnabled: '',
+ DisplayInCurrencyEnabled: ''
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -118,6 +120,9 @@ const OperationSetting = () => {
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
+ if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
+ await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
+ }
break;
}
};
@@ -129,7 +134,7 @@ const OperationSetting = () => {