diff --git a/common/constants.go b/common/constants.go
index 7a1694c5..643c2b34 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -15,6 +15,8 @@ var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
+var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
+var DisplayInCurrencyEnabled = false
var UsingSQLite = false
diff --git a/common/logger.go b/common/logger.go
index 0b8b2cfb..3658dbdb 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -42,3 +42,11 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
+
+func LogQuota(quota int) string {
+ if DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d 点额度", quota)
+ }
+}
diff --git a/controller/billing.go b/controller/billing.go
index 2f0d90fe..18a34dc9 100644
--- a/controller/billing.go
+++ b/controller/billing.go
@@ -2,6 +2,7 @@ package controller
import (
"github.com/gin-gonic/gin"
+ "one-api/common"
"one-api/model"
)
@@ -18,23 +19,38 @@ func GetSubscription(c *gin.Context) {
})
return
}
+ amount := float64(quota)
+ if common.DisplayInCurrencyEnabled {
+ amount /= common.QuotaPerUnit
+ }
subscription := OpenAISubscriptionResponse{
Object: "billing_subscription",
HasPaymentMethod: true,
- SoftLimitUSD: float64(quota),
- HardLimitUSD: float64(quota),
- SystemHardLimitUSD: float64(quota),
+ SoftLimitUSD: amount,
+ HardLimitUSD: amount,
+ SystemHardLimitUSD: amount,
}
c.JSON(200, subscription)
return
}
func GetUsage(c *gin.Context) {
- //userId := c.GetInt("id")
- // TODO: get usage from database
+ userId := c.GetInt("id")
+ quota, err := model.GetUserUsedQuota(userId)
+ if err != nil {
+ openAIError := OpenAIError{
+ Message: err.Error(),
+ Type: "one_api_error",
+ }
+ c.JSON(200, gin.H{
+ "error": openAIError,
+ })
+ return
+ }
+ amount := float64(quota)
usage := OpenAIUsageResponse{
Object: "list",
- TotalUsage: 0,
+ TotalUsage: amount,
}
c.JSON(200, usage)
return
diff --git a/controller/misc.go b/controller/misc.go
index 10f1f99e..755ccbd4 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -14,21 +14,23 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": common.ServerAddress,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "chat_link": common.ChatLink,
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": common.ServerAddress,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "chat_link": common.ChatLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
},
})
return
diff --git a/controller/relay-text.go b/controller/relay-text.go
index f8b41812..e47c517b 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -138,7 +138,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
tokenName := c.GetString("token_name")
userId := c.GetInt("id")
- model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %d 点额度(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, quota, modelRatio, groupRatio))
+ model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio))
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
diff --git a/controller/user.go b/controller/user.go
index c861651a..3060edd4 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -384,7 +384,7 @@ func UpdateUser(c *gin.Context) {
return
}
if originUser.Quota != updatedUser.Quota {
- model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %d 点修改为 %d 点", originUser.Quota, updatedUser.Quota))
+ model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
diff --git a/model/cache.go b/model/cache.go
index a405c77b..4ca60496 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -16,12 +16,12 @@ const (
func CacheGetTokenByKey(key string) (*Token, error) {
var token Token
if !common.RedisEnabled {
- err := DB.Where("`key` = ?", key).First(token).Error
+ err := DB.Where("`key` = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
- err := DB.Where("`key` = ?", key).First(token).Error
+ err := DB.Where("`key` = ?", key).First(&token).Error
if err != nil {
return nil, err
}
diff --git a/model/option.go b/model/option.go
index b53b172d..12b1ce70 100644
--- a/model/option.go
+++ b/model/option.go
@@ -35,6 +35,7 @@ func InitOptionMap() {
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
+ common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
@@ -64,6 +65,7 @@ func InitOptionMap() {
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
+ common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -140,6 +142,8 @@ func updateOptionMap(key string, value string) (err error) {
common.AutomaticDisableChannelEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
+ case "DisplayInCurrencyEnabled":
+ common.DisplayInCurrencyEnabled = boolValue
}
}
switch key {
@@ -196,6 +200,8 @@ func updateOptionMap(key string, value string) (err error) {
common.ChatLink = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ case "QuotaPerUnit":
+ common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
}
return err
}
diff --git a/model/redemption.go b/model/redemption.go
index 155e3cfd..d60eb649 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -66,7 +66,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
common.SysError("更新兑换码状态失败:" + err.Error())
}
- RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %d 点额度", redemption.Quota))
+ RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
}()
return redemption.Quota, nil
}
diff --git a/model/user.go b/model/user.go
index 5205662e..922b410f 100644
--- a/model/user.go
+++ b/model/user.go
@@ -93,16 +93,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error
}
if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %d 点额度", common.QuotaForNewUser))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %d 点额度", common.QuotaForInvitee))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %d 点额度", common.QuotaForInviter))
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
}
}
return nil
@@ -256,6 +256,11 @@ func GetUserQuota(id int) (quota int, err error) {
return quota, err
}
+func GetUserUsedQuota(id int) (quota int, err error) {
+ err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
+ return quota, err
+}
+
func GetUserEmail(id int) (email string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
return email, err
diff --git a/web/src/App.js b/web/src/App.js
index 61477141..c967ce2c 100644
--- a/web/src/App.js
+++ b/web/src/App.js
@@ -48,6 +48,8 @@ function App() {
localStorage.setItem('system_name', data.system_name);
localStorage.setItem('logo', data.logo);
localStorage.setItem('footer_html', data.footer_html);
+ localStorage.setItem('quota_per_unit', data.quota_per_unit);
+ localStorage.setItem('display_in_currency', data.display_in_currency);
if (data.chat_link) {
localStorage.setItem('chat_link', data.chat_link);
} else {
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 3f3a4ab0..7f9b1ea5 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -13,9 +13,11 @@ const OperationSetting = () => {
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
+ QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: '',
ChannelDisableThreshold: 0,
- LogConsumeEnabled: ''
+ LogConsumeEnabled: '',
+ DisplayInCurrencyEnabled: ''
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -118,6 +120,9 @@ const OperationSetting = () => {
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
+ if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
+ await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
+ }
break;
}
};
@@ -129,7 +134,7 @@ const OperationSetting = () => {