🔖 chore: migration constants

This commit is contained in:
MartialBE 2024-05-29 01:56:14 +08:00
parent ce12558ad6
commit 3d8a51e139
No known key found for this signature in database
GPG Key ID: 27C0267EC84B0A5C
91 changed files with 670 additions and 614 deletions

View File

@ -3,7 +3,8 @@ package cli
import ( import (
"flag" "flag"
"fmt" "fmt"
"one-api/common" "one-api/common/config"
"one-api/common/utils"
"os" "os"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -18,11 +19,11 @@ var (
export = flag.Bool("export", false, "Exports prices to a JSON file.") export = flag.Bool("export", false, "Exports prices to a JSON file.")
) )
func FlagConfig() { func InitCli() {
flag.Parse() flag.Parse()
if *printVersion { if *printVersion {
fmt.Println(common.Version) fmt.Println(config.Version)
os.Exit(0) os.Exit(0)
} }
@ -44,10 +45,19 @@ func FlagConfig() {
os.Exit(0) os.Exit(0)
} }
if Config != nil && !utils.IsFileExist(*Config) {
panic("Config file not found")
}
viper.SetConfigFile(*Config)
if err := viper.ReadInConfig(); err != nil {
panic(err)
}
} }
func help() { func help() {
fmt.Println("One API " + common.Version + " - All in one API service for OpenAI API.") fmt.Println("One API " + config.Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2024 MartialBE. All rights reserved.") fmt.Println("Copyright (C) 2024 MartialBE. All rights reserved.")
fmt.Println("Original copyright holder: JustSong") fmt.Println("Original copyright holder: JustSong")
fmt.Println("GitHub: https://github.com/MartialBE/one-api") fmt.Println("GitHub: https://github.com/MartialBE/one-api")

View File

@ -1,10 +1,13 @@
package common package common
import "fmt" import (
"fmt"
"one-api/common/config"
)
func LogQuota(quota int) string { func LogQuota(quota int) string {
if DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit) return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else { } else {
return fmt.Sprintf("%d 点额度", quota) return fmt.Sprintf("%d 点额度", quota)
} }

View File

@ -5,37 +5,22 @@ import (
"strings" "strings"
"time" "time"
"one-api/cli"
"one-api/common"
"one-api/common/utils" "one-api/common/utils"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
func InitConf() { func InitConf() {
cli.FlagConfig()
defaultConfig() defaultConfig()
setConfigFile()
setEnv() setEnv()
if viper.GetBool("debug") { if viper.GetBool("debug") {
logger.SysLog("running in debug mode") logger.SysLog("running in debug mode")
} }
common.IsMasterNode = viper.GetString("node_type") != "slave" IsMasterNode = viper.GetString("node_type") != "slave"
common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second
common.SessionSecret = utils.GetOrDefault("session_secret", common.SessionSecret) SessionSecret = utils.GetOrDefault("session_secret", SessionSecret)
}
func setConfigFile() {
if !utils.IsFileExist(*cli.Config) {
return
}
viper.SetConfigFile(*cli.Config)
if err := viper.ReadInConfig(); err != nil {
panic(err)
}
} }
func setEnv() { func setEnv() {

View File

@ -1,4 +1,4 @@
package common package config
import ( import (
"sync" "sync"

View File

@ -3,7 +3,7 @@ package channel
import ( import (
"context" "context"
"errors" "errors"
"one-api/common" "one-api/common/config"
"one-api/common/stmp" "one-api/common/stmp"
"github.com/gomarkdown/markdown" "github.com/gomarkdown/markdown"
@ -28,10 +28,10 @@ func (e *Email) Name() string {
func (e *Email) Send(ctx context.Context, title, message string) error { func (e *Email) Send(ctx context.Context, title, message string) error {
to := e.To to := e.To
if to == "" { if to == "" {
to = common.RootUserEmail to = config.RootUserEmail
} }
if common.SMTPServer == "" || common.SMTPAccount == "" || common.SMTPToken == "" || to == "" { if config.SMTPServer == "" || config.SMTPAccount == "" || config.SMTPToken == "" || to == "" {
return errors.New("smtp config is not set, skip send email notifier") return errors.New("smtp config is not set, skip send email notifier")
} }
@ -44,7 +44,7 @@ func (e *Email) Send(ctx context.Context, title, message string) error {
body := markdown.Render(doc, renderer) body := markdown.Render(doc, renderer)
emailClient := stmp.NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom) emailClient := stmp.NewStmp(config.SMTPServer, config.SMTPPort, config.SMTPAccount, config.SMTPToken, config.SMTPFrom)
return emailClient.Send(to, title, string(body)) return emailClient.Send(to, title, string(body))
} }

View File

@ -2,6 +2,7 @@ package common
import ( import (
"context" "context"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"time" "time"
@ -41,7 +42,7 @@ func InitRedisClient() (err error) {
} else { } else {
RedisEnabled = true RedisEnabled = true
// for compatibility with old versions // for compatibility with old versions
MemoryCacheEnabled = true config.MemoryCacheEnabled = true
} }
return err return err

View File

@ -3,6 +3,7 @@ package stmp
import ( import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"strings" "strings"
@ -38,7 +39,7 @@ func (s *StmpConfig) Send(to, subject, body string) error {
message.Subject(subject) message.Subject(subject)
message.SetGenHeader("References", s.getReferences()) message.SetGenHeader("References", s.getReferences())
message.SetBodyString(mail.TypeTextHTML, body) message.SetBodyString(mail.TypeTextHTML, body)
message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", common.Version)) message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", config.Version))
client, err := mail.NewClient( client, err := mail.NewClient(
s.Host, s.Host,
@ -78,11 +79,11 @@ func (s *StmpConfig) Render(to, subject, content string) error {
} }
func GetSystemStmp() (*StmpConfig, error) { func GetSystemStmp() (*StmpConfig, error) {
if common.SMTPServer == "" || common.SMTPPort == 0 || common.SMTPAccount == "" || common.SMTPToken == "" { if config.SMTPServer == "" || config.SMTPPort == 0 || config.SMTPAccount == "" || config.SMTPToken == "" {
return nil, fmt.Errorf("SMTP 信息未配置") return nil, fmt.Errorf("SMTP 信息未配置")
} }
return NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom), nil return NewStmp(config.SMTPServer, config.SMTPPort, config.SMTPAccount, config.SMTPToken, config.SMTPFrom), nil
} }
func SendPasswordResetEmail(userName, email, link string) error { func SendPasswordResetEmail(userName, email, link string) error {
@ -106,7 +107,7 @@ func SendPasswordResetEmail(userName, email, link string) error {
</p> </p>
<p style="color: #858585;">重置链接 %d 分钟内有效如果不是本人操作请忽略</p>` <p style="color: #858585;">重置链接 %d 分钟内有效如果不是本人操作请忽略</p>`
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf(contentTemp, userName, link, link, common.VerificationValidMinutes) content := fmt.Sprintf(contentTemp, userName, link, link, common.VerificationValidMinutes)
return stmp.Render(email, subject, content) return stmp.Render(email, subject, content)
@ -132,7 +133,7 @@ func SendVerificationCodeEmail(email, code string) error {
验证码 %d 分钟内有效如果不是本人操作请忽略 验证码 %d 分钟内有效如果不是本人操作请忽略
</p>` </p>`
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes) content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes)
return stmp.Render(email, subject, content) return stmp.Render(email, subject, content)
@ -162,7 +163,7 @@ func SendQuotaWarningCodeEmail(userName, email string, quota int, noMoreQuota bo
if noMoreQuota { if noMoreQuota {
subject = "您的额度已用尽" subject = "您的额度已用尽"
} }
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
content := fmt.Sprintf(contentTemp, userName, subject, quota, topUpLink, topUpLink) content := fmt.Sprintf(contentTemp, userName, subject, quota, topUpLink, topUpLink)

View File

@ -2,6 +2,7 @@ package stmp_test
import ( import (
"fmt" "fmt"
"one-api/common/config"
"testing" "testing"
"one-api/common" "one-api/common"
@ -56,7 +57,7 @@ func TestSend(t *testing.T) {
验证码 %d 分钟内有效如果不是本人操作请忽略 验证码 %d 分钟内有效如果不是本人操作请忽略
</p>` </p>`
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes) content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes)
err := stmpClient.Render(email, subject, content) err := stmpClient.Render(email, subject, content)

View File

@ -1,17 +1,17 @@
package stmp package stmp
import ( import (
"one-api/common" "one-api/common/config"
) )
func getLogo() string { func getLogo() string {
if common.Logo == "" { if config.Logo == "" {
return "" return ""
} }
return `<table class="logo" width="100%"> return `<table class="logo" width="100%">
<tr> <tr>
<td> <td>
<img src="` + common.Logo + `" width="130" style="max-width: 100%" <img src="` + config.Logo + `" width="130" style="max-width: 100%"
/> />
</td> </td>
</tr> </tr>
@ -19,11 +19,11 @@ func getLogo() string {
} }
func getSystemName() string { func getSystemName() string {
if common.SystemName == "" { if config.SystemName == "" {
return "One API" return "One API"
} }
return common.SystemName return config.SystemName
} }
func getDefaultTemplate(content string) string { func getDefaultTemplate(content string) string {

View File

@ -1,7 +1,7 @@
package telegram package telegram
import ( import (
"one-api/common" "one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"strings" "strings"
@ -24,8 +24,8 @@ func commandAffStart(b *gotgbot.Bot, ctx *ext.Context) error {
} }
messae := "您可以通过分享您的邀请码来邀请朋友,每次成功邀请将获得奖励。\n\n您的邀请码是: " + user.AffCode messae := "您可以通过分享您的邀请码来邀请朋友,每次成功邀请将获得奖励。\n\n您的邀请码是: " + user.AffCode
if common.ServerAddress != "" { if config.ServerAddress != "" {
serverAddress := strings.TrimSuffix(common.ServerAddress, "/") serverAddress := strings.TrimSuffix(config.ServerAddress, "/")
messae += "\n\n页面地址" + serverAddress + "/register?aff=" + user.AffCode messae += "\n\n页面地址" + serverAddress + "/register?aff=" + user.AffCode
} }

View File

@ -3,7 +3,7 @@ package telegram
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"strings" "strings"
@ -56,7 +56,7 @@ func getApikeyList(userId, page int) (message string, pageParams *paginationPara
} }
chatUrlTmp := "" chatUrlTmp := ""
if common.ServerAddress != "" { if config.ServerAddress != "" {
chatUrlTmp = getChatUrl() chatUrlTmp = getChatUrl()
} }
@ -75,11 +75,11 @@ func getApikeyList(userId, page int) (message string, pageParams *paginationPara
} }
func getChatUrl() string { func getChatUrl() string {
serverAddress := strings.TrimSuffix(common.ServerAddress, "/") serverAddress := strings.TrimSuffix(config.ServerAddress, "/")
chatNextUrl := fmt.Sprintf(`{"key":"setToken","url":"%s"}`, serverAddress) chatNextUrl := fmt.Sprintf(`{"key":"setToken","url":"%s"}`, serverAddress)
chatNextUrl = "https://chat.oneapi.pro/#/?settings=" + url.QueryEscape(chatNextUrl) chatNextUrl = "https://chat.oneapi.pro/#/?settings=" + url.QueryEscape(chatNextUrl)
if common.ChatLink != "" { if config.ChatLink != "" {
chatLink := strings.TrimSuffix(common.ChatLink, "/") chatLink := strings.TrimSuffix(config.ChatLink, "/")
chatNextUrl = strings.ReplaceAll(chatNextUrl, `https://chat.oneapi.pro`, chatLink) chatNextUrl = strings.ReplaceAll(chatNextUrl, `https://chat.oneapi.pro`, chatLink)
} }

View File

@ -7,7 +7,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"strings" "strings"
@ -56,13 +56,13 @@ func InitTelegramBot() {
func StartTelegramBot() { func StartTelegramBot() {
botWebhook := viper.GetString("tg.webhook_secret") botWebhook := viper.GetString("tg.webhook_secret")
if botWebhook != "" { if botWebhook != "" {
if common.ServerAddress == "" { if config.ServerAddress == "" {
logger.SysLog("Telegram bot is not enabled: Server address is not set") logger.SysLog("Telegram bot is not enabled: Server address is not set")
StopTelegramBot() StopTelegramBot()
return return
} }
TGWebHookSecret = botWebhook TGWebHookSecret = botWebhook
serverAddress := strings.TrimSuffix(common.ServerAddress, "/") serverAddress := strings.TrimSuffix(config.ServerAddress, "/")
urlPath := fmt.Sprintf("/api/telegram/%s", viper.GetString("tg.bot_api_key")) urlPath := fmt.Sprintf("/api/telegram/%s", viper.GetString("tg.bot_api_key"))
webHookOpts := &ext.AddWebhookOpts{ webHookOpts := &ext.AddWebhookOpts{

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"math" "math"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"strings" "strings"
@ -21,7 +22,7 @@ var gpt4oTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() { func InitTokenEncoders() {
if viper.GetBool("disable_token_encoders") { if viper.GetBool("disable_token_encoders") {
DISABLE_TOKEN_ENCODERS = true config.DISABLE_TOKEN_ENCODERS = true
logger.SysLog("token encoders disabled") logger.SysLog("token encoders disabled")
return return
} }
@ -46,7 +47,7 @@ func InitTokenEncoders() {
} }
func getTokenEncoder(model string) *tiktoken.Tiktoken { func getTokenEncoder(model string) *tiktoken.Tiktoken {
if DISABLE_TOKEN_ENCODERS { if config.DISABLE_TOKEN_ENCODERS {
return nil return nil
} }
@ -75,7 +76,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if DISABLE_TOKEN_ENCODERS || ApproximateTokenEnabled { if config.DISABLE_TOKEN_ENCODERS || config.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38) return int(float64(len(text)) * 0.38)
} }
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))

View File

@ -1,7 +1,7 @@
package controller package controller
import ( import (
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"one-api/types" "one-api/types"
@ -14,7 +14,7 @@ func GetSubscription(c *gin.Context) {
var err error var err error
var token *model.Token var token *model.Token
var expiredTime int64 var expiredTime int64
if common.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime expiredTime = token.ExpiredTime
@ -50,8 +50,8 @@ func GetSubscription(c *gin.Context) {
} }
quota := remainQuota + usedQuota quota := remainQuota + usedQuota
amount := float64(quota) amount := float64(quota)
if common.DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit amount /= config.QuotaPerUnit
} }
if token != nil && token.UnlimitedQuota { if token != nil && token.UnlimitedQuota {
amount = 100000000 amount = 100000000
@ -71,7 +71,7 @@ func GetUsage(c *gin.Context) {
var quota int var quota int
var err error var err error
var token *model.Token var token *model.Token
if common.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota quota = token.UsedQuota
@ -90,8 +90,8 @@ func GetUsage(c *gin.Context) {
return return
} }
amount := float64(quota) amount := float64(quota)
if common.DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit amount /= config.QuotaPerUnit
} }
usage := OpenAIUsageResponse{ usage := OpenAIUsageResponse{
Object: "list", Object: "list",

View File

@ -4,7 +4,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"one-api/providers" "one-api/providers"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
@ -109,11 +109,11 @@ func updateAllChannelsBalance() error {
return err return err
} }
for _, channel := range channels { for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled { if channel.Status != config.ChannelStatusEnabled {
continue continue
} }
// TODO: support Azure // TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { if channel.Type != config.ChannelTypeOpenAI && channel.Type != config.ChannelTypeCustom {
continue continue
} }
balance, err := updateChannelBalance(channel) balance, err := updateChannelBalance(channel)
@ -125,7 +125,7 @@ func updateAllChannelsBalance() error {
DisableChannel(channel.Id, channel.Name, "余额不足", true) DisableChannel(channel.Id, channel.Name, "余额不足", true)
} }
} }
time.Sleep(common.RequestInterval) time.Sleep(config.RequestInterval)
} }
return nil return nil
} }

View File

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/notify" "one-api/common/notify"
"one-api/common/utils" "one-api/common/utils"
@ -145,16 +145,16 @@ func testAllChannels(isNotify bool) error {
if err != nil { if err != nil {
return err return err
} }
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
go func() { go func() {
var sendMessage string var sendMessage string
for _, channel := range channels { for _, channel := range channels {
time.Sleep(common.RequestInterval) time.Sleep(config.RequestInterval)
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == config.ChannelStatusEnabled
sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr()) sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr())
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel, "") err, openaiErr := testChannel(channel, "")
@ -173,7 +173,7 @@ func testAllChannels(isNotify bool) error {
// 如果已被禁用,但是请求成功,需要判断是否需要恢复 // 如果已被禁用,但是请求成功,需要判断是否需要恢复
// 手动禁用的通道,不会自动恢复 // 手动禁用的通道,不会自动恢复
if shouldEnableChannel(err, openaiErr) { if shouldEnableChannel(err, openaiErr) {
if channel.Status == common.ChannelStatusAutoDisabled { if channel.Status == config.ChannelStatusAutoDisabled {
EnableChannel(channel.Id, channel.Name, false) EnableChannel(channel.Id, channel.Name, false)
sendMessage += "- 已被启用 \n\n" sendMessage += "- 已被启用 \n\n"
} else { } else {

View File

@ -3,7 +3,7 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/notify" "one-api/common/notify"
"one-api/model" "one-api/model"
"one-api/types" "one-api/types"
@ -13,7 +13,7 @@ import (
) )
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled { if !config.AutomaticEnableChannelEnabled {
return false return false
} }
if err != nil { if err != nil {
@ -26,7 +26,7 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
} }
func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool { func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled { if !config.AutomaticDisableChannelEnabled {
return false return false
} }
@ -74,7 +74,7 @@ func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
// disable & notify // disable & notify
func DisableChannel(channelId int, channelName string, reason string, sendNotify bool) { func DisableChannel(channelId int, channelName string, reason string, sendNotify bool) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) model.UpdateChannelStatusById(channelId, config.ChannelStatusAutoDisabled)
if !sendNotify { if !sendNotify {
return return
} }
@ -86,7 +86,7 @@ func DisableChannel(channelId int, channelName string, reason string, sendNotify
// enable & notify // enable & notify
func EnableChannel(channelId int, channelName string, sendNotify bool) { func EnableChannel(channelId int, channelName string, sendNotify bool) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) model.UpdateChannelStatusById(channelId, config.ChannelStatusEnabled)
if !sendNotify { if !sendNotify {
return return
} }

View File

@ -6,7 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
@ -33,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" { if code == "" {
return nil, errors.New("无效的参数") return nil, errors.New("无效的参数")
} }
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values) jsonData, err := json.Marshal(values)
if err != nil { if err != nil {
return nil, err return nil, err
@ -96,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
if !common.GitHubOAuthEnabled { if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@ -125,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
user.DisplayName = "GitHub User" user.DisplayName = "GitHub User"
} }
user.Email = githubUser.Email user.Email = githubUser.Email
user.Role = common.RoleCommonUser user.Role = config.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = config.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -152,7 +152,7 @@ func GitHubOAuth(c *gin.Context) {
} }
} }
if user.Status != common.UserStatusEnabled { if user.Status != config.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",
"success": false, "success": false,
@ -163,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
} }
func GitHubBind(c *gin.Context) { func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled { if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",

View File

@ -6,7 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"strconv" "strconv"
@ -41,8 +41,8 @@ type LarkUser struct {
func getLarkAppAccessToken() (string, error) { func getLarkAppAccessToken() (string, error) {
values := map[string]string{ values := map[string]string{
"app_id": common.LarkClientId, "app_id": config.LarkClientId,
"app_secret": common.LarkClientSecret, "app_secret": config.LarkClientSecret,
} }
jsonData, err := json.Marshal(values) jsonData, err := json.Marshal(values)
if err != nil { if err != nil {
@ -148,7 +148,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
} }
func LarkOAuth(c *gin.Context) { func LarkOAuth(c *gin.Context) {
if !common.LarkAuthEnabled { if !config.LarkAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过飞书登录以及注册", "message": "管理员未开启通过飞书登录以及注册",
"success": false, "success": false,
@ -191,15 +191,15 @@ func LarkOAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
if larkUser.Data.Name != "" { if larkUser.Data.Name != "" {
user.DisplayName = larkUser.Data.Name user.DisplayName = larkUser.Data.Name
} else { } else {
user.DisplayName = "Lark User" user.DisplayName = "Lark User"
} }
user.Role = common.RoleCommonUser user.Role = config.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = config.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -217,7 +217,7 @@ func LarkOAuth(c *gin.Context) {
} }
} }
if user.Status != common.UserStatusEnabled { if user.Status != config.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",
"success": false, "success": false,
@ -228,7 +228,7 @@ func LarkOAuth(c *gin.Context) {
} }
func LarkBind(c *gin.Context) { func LarkBind(c *gin.Context) {
if !common.LarkAuthEnabled { if !config.LarkAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过飞书登录以及注册", "message": "管理员未开启通过飞书登录以及注册",
"success": false, "success": false,

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/stmp" "one-api/common/stmp"
"one-api/common/telegram" "one-api/common/telegram"
"one-api/model" "one-api/model"
@ -23,60 +24,60 @@ func GetStatus(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
"version": common.Version, "version": config.Version,
"start_time": common.StartTime, "start_time": config.StartTime,
"email_verification": common.EmailVerificationEnabled, "email_verification": config.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": config.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": config.GitHubClientId,
"lark_login": common.LarkAuthEnabled, "lark_login": config.LarkAuthEnabled,
"lark_client_id": common.LarkClientId, "lark_client_id": config.LarkClientId,
"system_name": common.SystemName, "system_name": config.SystemName,
"logo": common.Logo, "logo": config.Logo,
"footer_html": common.Footer, "footer_html": config.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": config.WeChatAuthEnabled,
"server_address": common.ServerAddress, "server_address": config.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": config.TopUpLink,
"chat_link": common.ChatLink, "chat_link": config.ChatLink,
"quota_per_unit": common.QuotaPerUnit, "quota_per_unit": config.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled, "display_in_currency": config.DisplayInCurrencyEnabled,
"telegram_bot": telegram_bot, "telegram_bot": telegram_bot,
"mj_notify_enabled": common.MjNotifyEnabled, "mj_notify_enabled": config.MjNotifyEnabled,
"chat_cache_enabled": common.ChatCacheEnabled, "chat_cache_enabled": config.ChatCacheEnabled,
"chat_links": common.ChatLinks, "chat_links": config.ChatLinks,
}, },
}) })
} }
func GetNotice(c *gin.Context) { func GetNotice(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["Notice"], "data": config.OptionMap["Notice"],
}) })
} }
func GetAbout(c *gin.Context) { func GetAbout(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["About"], "data": config.OptionMap["About"],
}) })
} }
func GetHomePageContent(c *gin.Context) { func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["HomePageContent"], "data": config.OptionMap["HomePageContent"],
}) })
} }
@ -89,9 +90,9 @@ func SendEmailVerification(c *gin.Context) {
}) })
return return
} }
if common.EmailDomainRestrictionEnabled { if config.EmailDomainRestrictionEnabled {
allowed := false allowed := false
for _, domain := range common.EmailDomainWhitelist { for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) { if strings.HasSuffix(email, "@"+domain) {
allowed = true allowed = true
break break
@ -157,7 +158,7 @@ func SendPasswordResetEmail(c *gin.Context) {
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
err := stmp.SendPasswordResetEmail(userName, email, link) err := stmp.SendPasswordResetEmail(userName, email, link)
if err != nil { if err != nil {

View File

@ -3,7 +3,7 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
"strings" "strings"
@ -13,8 +13,8 @@ import (
func GetOptions(c *gin.Context) { func GetOptions(c *gin.Context) {
var options []*model.Option var options []*model.Option
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap { for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue continue
} }
@ -23,7 +23,7 @@ func GetOptions(c *gin.Context) {
Value: utils.Interface2String(v), Value: utils.Interface2String(v),
}) })
} }
common.OptionMapRWMutex.Unlock() config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@ -44,7 +44,7 @@ func UpdateOption(c *gin.Context) {
} }
switch option.Key { switch option.Key {
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" { if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret", "message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
@ -52,7 +52,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@ -60,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
if option.Value == "true" && common.WeChatServerAddress == "" { if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!", "message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@ -68,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
if option.Value == "true" && common.TurnstileSiteKey == "" { if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",

View File

@ -3,6 +3,7 @@ package controller
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
@ -199,15 +200,15 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == config.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == config.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
}) })
return return
} }
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { if cleanToken.Status == config.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
@ -20,7 +21,7 @@ type LoginRequest struct {
} }
func Login(c *gin.Context) { func Login(c *gin.Context) {
if !common.PasswordLoginEnabled { if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录", "message": "管理员关闭了密码登录",
"success": false, "success": false,
@ -107,14 +108,14 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
if !common.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
"success": false, "success": false,
}) })
return return
} }
if !common.PasswordRegisterEnabled { if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false, "success": false,
@ -137,7 +138,7 @@ func Register(c *gin.Context) {
}) })
return return
} }
if common.EmailVerificationEnabled { if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" { if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -161,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username, DisplayName: user.Username,
InviterId: inviterId, InviterId: inviterId,
} }
if common.EmailVerificationEnabled { if config.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(inviterId); err != nil {
@ -214,7 +215,7 @@ func GetUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser { if myRole <= user.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权获取同级或更高等级用户的信息", "message": "无权获取同级或更高等级用户的信息",
@ -360,14 +361,14 @@ func UpdateUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= originUser.Role && myRole != common.RoleRootUser { if myRole <= originUser.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息", "message": "无权更新同权限等级或更高权限等级的用户信息",
}) })
return return
} }
if myRole <= updatedUser.Role && myRole != common.RoleRootUser { if myRole <= updatedUser.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级", "message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@ -479,7 +480,7 @@ func DeleteSelf(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
if user.Role == common.RoleRootUser { if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "不能删除超级管理员账户", "message": "不能删除超级管理员账户",
@ -579,7 +580,7 @@ func ManageUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser { if myRole <= user.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息", "message": "无权更新同权限等级或更高权限等级的用户信息",
@ -588,8 +589,8 @@ func ManageUser(c *gin.Context) {
} }
switch req.Action { switch req.Action {
case "disable": case "disable":
user.Status = common.UserStatusDisabled user.Status = config.UserStatusDisabled
if user.Role == common.RoleRootUser { if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法禁用超级管理员用户", "message": "无法禁用超级管理员用户",
@ -597,9 +598,9 @@ func ManageUser(c *gin.Context) {
return return
} }
case "enable": case "enable":
user.Status = common.UserStatusEnabled user.Status = config.UserStatusEnabled
case "delete": case "delete":
if user.Role == common.RoleRootUser { if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法删除超级管理员用户", "message": "无法删除超级管理员用户",
@ -614,37 +615,37 @@ func ManageUser(c *gin.Context) {
return return
} }
case "promote": case "promote":
if myRole != common.RoleRootUser { if myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "普通管理员用户无法提升其他用户为管理员", "message": "普通管理员用户无法提升其他用户为管理员",
}) })
return return
} }
if user.Role >= common.RoleAdminUser { if user.Role >= config.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "该用户已经是管理员", "message": "该用户已经是管理员",
}) })
return return
} }
user.Role = common.RoleAdminUser user.Role = config.RoleAdminUser
case "demote": case "demote":
if user.Role == common.RoleRootUser { if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法降级超级管理员用户", "message": "无法降级超级管理员用户",
}) })
return return
} }
if user.Role == common.RoleCommonUser { if user.Role == config.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "该用户已经是普通用户", "message": "该用户已经是普通用户",
}) })
return return
} }
user.Role = common.RoleCommonUser user.Role = config.RoleCommonUser
} }
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {

View File

@ -5,7 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
@ -23,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" { if code == "" {
return "", errors.New("无效的参数") return "", errors.New("无效的参数")
} }
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
req.Header.Set("Authorization", common.WeChatServerToken) req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{ client := http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
@ -51,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
if !common.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,
@ -80,11 +80,11 @@ func WeChatAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User" user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser user.Role = config.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = config.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -102,7 +102,7 @@ func WeChatAuth(c *gin.Context) {
} }
} }
if user.Status != common.UserStatusEnabled { if user.Status != config.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",
"success": false, "success": false,
@ -113,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
} }
func WeChatBind(c *gin.Context) { func WeChatBind(c *gin.Context) {
if !common.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,

12
main.go
View File

@ -3,6 +3,7 @@ package main
import ( import (
"embed" "embed"
"fmt" "fmt"
"one-api/cli"
"one-api/common" "one-api/common"
"one-api/common/config" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
@ -31,9 +32,10 @@ var buildFS embed.FS
var indexPage []byte var indexPage []byte
func main() { func main() {
cli.InitCli()
config.InitConf() config.InitConf()
logger.SetupLogger() logger.SetupLogger()
logger.SysLog("One API " + common.Version + " started") logger.SysLog("One API " + config.Version + " started")
// Initialize SQL Database // Initialize SQL Database
model.SetupDB() model.SetupDB()
defer model.CloseDB() defer model.CloseDB()
@ -60,10 +62,10 @@ func main() {
func initMemoryCache() { func initMemoryCache() {
if viper.GetBool("memory_cache_enabled") { if viper.GetBool("memory_cache_enabled") {
common.MemoryCacheEnabled = true config.MemoryCacheEnabled = true
} }
if !common.MemoryCacheEnabled { if !config.MemoryCacheEnabled {
return return
} }
@ -91,7 +93,7 @@ func initHttpServer() {
server.Use(middleware.RequestId()) server.Use(middleware.RequestId())
middleware.SetUpLogger(server) middleware.SetUpLogger(server)
store := cookie.NewStore([]byte(common.SessionSecret)) store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage) router.SetRouter(server, buildFS, indexPage)
@ -105,7 +107,7 @@ func initHttpServer() {
func SyncChannelCache(frequency int) { func SyncChannelCache(frequency int) {
// 只有 从 服务器端获取数据的时候才会用到 // 只有 从 服务器端获取数据的时候才会用到
if common.IsMasterNode { if config.IsMasterNode {
logger.SysLog("master node does't synchronize the channel") logger.SysLog("master node does't synchronize the channel")
return return
} }

View File

@ -2,7 +2,7 @@ package middleware
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
"strings" "strings"
@ -44,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if status.(int) == common.UserStatusDisabled { if status.(int) == config.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "用户已被封禁", "message": "用户已被封禁",
@ -68,19 +68,19 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) { func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser) authHelper(c, config.RoleCommonUser)
} }
} }
func AdminAuth() func(c *gin.Context) { func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleAdminUser) authHelper(c, config.RoleAdminUser)
} }
} }
func RootAuth() func(c *gin.Context) { func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleRootUser) authHelper(c, config.RoleRootUser)
} }
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"time" "time"
@ -45,7 +46,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
} }
if listLength < int64(maxRequestNum) { if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else { } else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr) oldTime, err := time.Parse(timeFormat, oldTimeStr)
@ -64,14 +65,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number! // time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration { if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests) c.Status(http.StatusTooManyRequests)
c.Abort() c.Abort()
return return
} else { } else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} }
} }
} }
@ -92,7 +93,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
} else { } else {
// It's safe to call multi times. // It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) { return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark) memoryRateLimiter(c, maxRequestNum, duration, mark)
} }

View File

@ -6,7 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
) )
@ -16,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc { func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if common.TurnstileCheckEnabled { if config.TurnstileCheckEnabled {
session := sessions.Default(c) session := sessions.Default(c)
turnstileChecked := session.Get("turnstile") turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil { if turnstileChecked != nil {
@ -33,7 +33,7 @@ func TurnstileCheck() gin.HandlerFunc {
return return
} }
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
"secret": {common.TurnstileSecretKey}, "secret": {config.TurnstileSecretKey},
"response": {response}, "response": {response},
"remoteip": {c.ClientIP()}, "remoteip": {c.ClientIP()},
}) })

View File

@ -2,6 +2,7 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/common/config"
"strings" "strings"
) )
@ -66,7 +67,7 @@ func (channel *Channel) AddAbilities() error {
Group: group, Group: group,
Model: model, Model: model,
ChannelId: channel.Id, ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled, Enabled: channel.Status == config.ChannelStatusEnabled,
Priority: channel.Priority, Priority: channel.Priority,
Weight: channel.Weight, Weight: channel.Weight,
} }

View File

@ -3,7 +3,7 @@ package model
import ( import (
"errors" "errors"
"math/rand" "math/rand"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
"strings" "strings"
@ -38,7 +38,7 @@ func FilterOnlyChat() ChannelsFilterFunc {
} }
func (cc *ChannelsChooser) Cooldowns(channelId int) bool { func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
if common.RetryCooldownSeconds == 0 { if config.RetryCooldownSeconds == 0 {
return false return false
} }
cc.Lock() cc.Lock()
@ -47,7 +47,7 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
return false return false
} }
cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds) cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(config.RetryCooldownSeconds)
return true return true
} }
@ -159,7 +159,7 @@ var ChannelGroup = ChannelsChooser{}
func (cc *ChannelsChooser) Load() { func (cc *ChannelsChooser) Load() {
var channels []*Channel var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) DB.Where("status = ?", config.ChannelStatusEnabled).Find(&channels)
abilities, err := GetAbilityChannelGroup() abilities, err := GetAbilityChannelGroup()
if err != nil { if err != nil {
@ -173,7 +173,7 @@ func (cc *ChannelsChooser) Load() {
for _, channel := range channels { for _, channel := range channels {
if *channel.Weight == 0 { if *channel.Weight == 0 {
channel.Weight = &common.DefaultChannelWeight channel.Weight = &config.DefaultChannelWeight
} }
newChannels[channel.Id] = &ChannelChoice{ newChannels[channel.Id] = &ChannelChoice{
Channel: channel, Channel: channel,

View File

@ -1,7 +1,7 @@
package model package model
import ( import (
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
"strings" "strings"
@ -270,11 +270,11 @@ func (channel *Channel) Delete() error {
func (channel *Channel) StatusToStr() string { func (channel *Channel) StatusToStr() string {
switch channel.Status { switch channel.Status {
case common.ChannelStatusEnabled: case config.ChannelStatusEnabled:
return "启用" return "启用"
case common.ChannelStatusAutoDisabled: case config.ChannelStatusAutoDisabled:
return "自动禁用" return "自动禁用"
case common.ChannelStatusManuallyDisabled: case config.ChannelStatusManuallyDisabled:
return "手动禁用" return "手动禁用"
} }
@ -282,7 +282,7 @@ func (channel *Channel) StatusToStr() string {
} }
func UpdateChannelStatusById(id int, status int) { func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(id, status == config.ChannelStatusEnabled)
if err != nil { if err != nil {
logger.SysError("failed to update ability status: " + err.Error()) logger.SysError("failed to update ability status: " + err.Error())
} }
@ -298,7 +298,7 @@ func UpdateChannelStatusById(id int, status int) {
} }
func UpdateChannelUsedQuota(id int, quota int) { func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return return
} }
@ -318,7 +318,7 @@ func DeleteChannelByStatus(status int64) (int64, error) {
} }
func DeleteDisabledChannel() (int64, error) { func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) result := DB.Where("status = ? or status = ?", config.ChannelStatusAutoDisabled, config.ChannelStatusManuallyDisabled).Delete(&Channel{})
// 同时删除Ability // 同时删除Ability
DB.Where("enabled = ?", false).Delete(&Ability{}) DB.Where("enabled = ?", false).Delete(&Ability{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error

View File

@ -3,6 +3,7 @@ package model
import ( import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
@ -43,11 +44,11 @@ func PaginateAndOrder[T modelable](db *gorm.DB, params *PaginationParams, result
params.Page = 1 params.Page = 1
} }
if params.Size < 1 { if params.Size < 1 {
params.Size = common.ItemsPerPage params.Size = config.ItemsPerPage
} }
if params.Size > common.MaxRecentItems { if params.Size > config.MaxRecentItems {
return nil, fmt.Errorf("size 参数不能超过 %d", common.MaxRecentItems) return nil, fmt.Errorf("size 参数不能超过 %d", config.MaxRecentItems)
} }
offset := (params.Page - 1) * params.Size offset := (params.Page - 1) * params.Size

View File

@ -3,7 +3,7 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
@ -37,7 +37,7 @@ const (
) )
func RecordLog(userId int, logType int, content string) { func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled { if logType == LogTypeConsume && !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
@ -55,7 +55,7 @@ func RecordLog(userId int, logType int, content string) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) {
logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
@ -156,12 +156,12 @@ func GetUserLogsList(userId int, params *LogsListParams) (*DataResult[Log], erro
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }

View File

@ -3,6 +3,7 @@ package model
import ( import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
"strconv" "strconv"
@ -24,12 +25,12 @@ func SetupDB() {
logger.FatalLog("failed to initialize database: " + err.Error()) logger.FatalLog("failed to initialize database: " + err.Error())
} }
ChannelGroup.Load() ChannelGroup.Load()
common.RootUserEmail = GetRootUserEmail() config.RootUserEmail = GetRootUserEmail()
if viper.GetBool("batch_update_enabled") { if viper.GetBool("batch_update_enabled") {
common.BatchUpdateEnabled = true config.BatchUpdateEnabled = true
common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5) config.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
InitBatchUpdater() InitBatchUpdater()
} }
} }
@ -46,8 +47,8 @@ func createRootAccountIfNeed() error {
rootUser := User{ rootUser := User{
Username: "root", Username: "root",
Password: hashedPassword, Password: hashedPassword,
Role: common.RoleRootUser, Role: config.RoleRootUser,
Status: common.UserStatusEnabled, Status: config.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: utils.GetUUID(), AccessToken: utils.GetUUID(),
Quota: 100000000, Quota: 100000000,
@ -102,7 +103,7 @@ func InitDB() (err error) {
sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode { if !config.IsMasterNode {
return nil return nil
} }
logger.SysLog("database migration started") logger.SysLog("database migration started")

View File

@ -2,6 +2,7 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"strconv" "strconv"
"strings" "strings"
@ -26,63 +27,63 @@ func GetOption(key string) (option Option, err error) {
} }
func InitOptionMap() { func InitOptionMap() {
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string) config.OptionMap = make(map[string]string)
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
common.OptionMap["LarkAuthEnabled"] = strconv.FormatBool(common.LarkAuthEnabled) config.OptionMap["LarkAuthEnabled"] = strconv.FormatBool(config.LarkAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = "" config.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = "" config.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
common.OptionMap["SMTPAccount"] = "" config.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = "" config.OptionMap["SMTPToken"] = ""
common.OptionMap["Notice"] = "" config.OptionMap["Notice"] = ""
common.OptionMap["About"] = "" config.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = "" config.OptionMap["HomePageContent"] = ""
common.OptionMap["Footer"] = common.Footer config.OptionMap["Footer"] = config.Footer
common.OptionMap["SystemName"] = common.SystemName config.OptionMap["SystemName"] = config.SystemName
common.OptionMap["Logo"] = common.Logo config.OptionMap["Logo"] = config.Logo
common.OptionMap["ServerAddress"] = "" config.OptionMap["ServerAddress"] = ""
common.OptionMap["GitHubClientId"] = "" config.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" config.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["WeChatServerAddress"] = "" config.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = "" config.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = "" config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = "" config.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink config.OptionMap["TopUpLink"] = config.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink config.OptionMap["ChatLink"] = config.ChatLink
common.OptionMap["ChatLinks"] = common.ChatLinks config.OptionMap["ChatLinks"] = config.ChatLinks
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds) config.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(config.RetryCooldownSeconds)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled) config.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(config.MjNotifyEnabled)
common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled) config.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(config.ChatCacheEnabled)
common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute) config.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(config.ChatCacheExpireMinute)
common.OptionMapRWMutex.Unlock() config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
@ -121,64 +122,64 @@ func UpdateOption(key string, value string) error {
} }
var optionIntMap = map[string]*int{ var optionIntMap = map[string]*int{
"SMTPPort": &common.SMTPPort, "SMTPPort": &config.SMTPPort,
"QuotaForNewUser": &common.QuotaForNewUser, "QuotaForNewUser": &config.QuotaForNewUser,
"QuotaForInviter": &common.QuotaForInviter, "QuotaForInviter": &config.QuotaForInviter,
"QuotaForInvitee": &common.QuotaForInvitee, "QuotaForInvitee": &config.QuotaForInvitee,
"QuotaRemindThreshold": &common.QuotaRemindThreshold, "QuotaRemindThreshold": &config.QuotaRemindThreshold,
"PreConsumedQuota": &common.PreConsumedQuota, "PreConsumedQuota": &config.PreConsumedQuota,
"RetryTimes": &common.RetryTimes, "RetryTimes": &config.RetryTimes,
"RetryCooldownSeconds": &common.RetryCooldownSeconds, "RetryCooldownSeconds": &config.RetryCooldownSeconds,
"ChatCacheExpireMinute": &common.ChatCacheExpireMinute, "ChatCacheExpireMinute": &config.ChatCacheExpireMinute,
} }
var optionBoolMap = map[string]*bool{ var optionBoolMap = map[string]*bool{
"PasswordRegisterEnabled": &common.PasswordRegisterEnabled, "PasswordRegisterEnabled": &config.PasswordRegisterEnabled,
"PasswordLoginEnabled": &common.PasswordLoginEnabled, "PasswordLoginEnabled": &config.PasswordLoginEnabled,
"EmailVerificationEnabled": &common.EmailVerificationEnabled, "EmailVerificationEnabled": &config.EmailVerificationEnabled,
"GitHubOAuthEnabled": &common.GitHubOAuthEnabled, "GitHubOAuthEnabled": &config.GitHubOAuthEnabled,
"WeChatAuthEnabled": &common.WeChatAuthEnabled, "WeChatAuthEnabled": &config.WeChatAuthEnabled,
"LarkAuthEnabled": &common.LarkAuthEnabled, "LarkAuthEnabled": &config.LarkAuthEnabled,
"TurnstileCheckEnabled": &common.TurnstileCheckEnabled, "TurnstileCheckEnabled": &config.TurnstileCheckEnabled,
"RegisterEnabled": &common.RegisterEnabled, "RegisterEnabled": &config.RegisterEnabled,
"EmailDomainRestrictionEnabled": &common.EmailDomainRestrictionEnabled, "EmailDomainRestrictionEnabled": &config.EmailDomainRestrictionEnabled,
"AutomaticDisableChannelEnabled": &common.AutomaticDisableChannelEnabled, "AutomaticDisableChannelEnabled": &config.AutomaticDisableChannelEnabled,
"AutomaticEnableChannelEnabled": &common.AutomaticEnableChannelEnabled, "AutomaticEnableChannelEnabled": &config.AutomaticEnableChannelEnabled,
"ApproximateTokenEnabled": &common.ApproximateTokenEnabled, "ApproximateTokenEnabled": &config.ApproximateTokenEnabled,
"LogConsumeEnabled": &common.LogConsumeEnabled, "LogConsumeEnabled": &config.LogConsumeEnabled,
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled, "DisplayInCurrencyEnabled": &config.DisplayInCurrencyEnabled,
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled, "DisplayTokenStatEnabled": &config.DisplayTokenStatEnabled,
"MjNotifyEnabled": &common.MjNotifyEnabled, "MjNotifyEnabled": &config.MjNotifyEnabled,
"ChatCacheEnabled": &common.ChatCacheEnabled, "ChatCacheEnabled": &config.ChatCacheEnabled,
} }
var optionStringMap = map[string]*string{ var optionStringMap = map[string]*string{
"SMTPServer": &common.SMTPServer, "SMTPServer": &config.SMTPServer,
"SMTPAccount": &common.SMTPAccount, "SMTPAccount": &config.SMTPAccount,
"SMTPFrom": &common.SMTPFrom, "SMTPFrom": &config.SMTPFrom,
"SMTPToken": &common.SMTPToken, "SMTPToken": &config.SMTPToken,
"ServerAddress": &common.ServerAddress, "ServerAddress": &config.ServerAddress,
"GitHubClientId": &common.GitHubClientId, "GitHubClientId": &config.GitHubClientId,
"GitHubClientSecret": &common.GitHubClientSecret, "GitHubClientSecret": &config.GitHubClientSecret,
"Footer": &common.Footer, "Footer": &config.Footer,
"SystemName": &common.SystemName, "SystemName": &config.SystemName,
"Logo": &common.Logo, "Logo": &config.Logo,
"WeChatServerAddress": &common.WeChatServerAddress, "WeChatServerAddress": &config.WeChatServerAddress,
"WeChatServerToken": &common.WeChatServerToken, "WeChatServerToken": &config.WeChatServerToken,
"WeChatAccountQRCodeImageURL": &common.WeChatAccountQRCodeImageURL, "WeChatAccountQRCodeImageURL": &config.WeChatAccountQRCodeImageURL,
"TurnstileSiteKey": &common.TurnstileSiteKey, "TurnstileSiteKey": &config.TurnstileSiteKey,
"TurnstileSecretKey": &common.TurnstileSecretKey, "TurnstileSecretKey": &config.TurnstileSecretKey,
"TopUpLink": &common.TopUpLink, "TopUpLink": &config.TopUpLink,
"ChatLink": &common.ChatLink, "ChatLink": &config.ChatLink,
"ChatLinks": &common.ChatLinks, "ChatLinks": &config.ChatLinks,
"LarkClientId": &common.LarkClientId, "LarkClientId": &config.LarkClientId,
"LarkClientSecret": &common.LarkClientSecret, "LarkClientSecret": &config.LarkClientSecret,
} }
func updateOptionMap(key string, value string) (err error) { func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock() defer config.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value config.OptionMap[key] = value
if ptr, ok := optionIntMap[key]; ok { if ptr, ok := optionIntMap[key]; ok {
*ptr, _ = strconv.Atoi(value) *ptr, _ = strconv.Atoi(value)
return return
@ -196,13 +197,13 @@ func updateOptionMap(key string, value string) (err error) {
switch key { switch key {
case "EmailDomainWhitelist": case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",") config.EmailDomainWhitelist = strings.Split(value, ",")
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "ChannelDisableThreshold": case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit": case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
} }
return err return err
} }

View File

@ -1,7 +1,7 @@
package model package model
import ( import (
"one-api/common" "one-api/common/config"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"gorm.io/gorm" "gorm.io/gorm"
@ -114,211 +114,211 @@ type ModelType struct {
func GetDefaultPrice() []*Price { func GetDefaultPrice() []*Price {
ModelTypes := map[string]ModelType{ ModelTypes := map[string]ModelType{
// $0.03 / 1K tokens $0.06 / 1K tokens // $0.03 / 1K tokens $0.06 / 1K tokens
"gpt-4": {[]float64{15, 30}, common.ChannelTypeOpenAI}, "gpt-4": {[]float64{15, 30}, config.ChannelTypeOpenAI},
"gpt-4-0314": {[]float64{15, 30}, common.ChannelTypeOpenAI}, "gpt-4-0314": {[]float64{15, 30}, config.ChannelTypeOpenAI},
"gpt-4-0613": {[]float64{15, 30}, common.ChannelTypeOpenAI}, "gpt-4-0613": {[]float64{15, 30}, config.ChannelTypeOpenAI},
// $0.06 / 1K tokens $0.12 / 1K tokens // $0.06 / 1K tokens $0.12 / 1K tokens
"gpt-4-32k": {[]float64{30, 60}, common.ChannelTypeOpenAI}, "gpt-4-32k": {[]float64{30, 60}, config.ChannelTypeOpenAI},
"gpt-4-32k-0314": {[]float64{30, 60}, common.ChannelTypeOpenAI}, "gpt-4-32k-0314": {[]float64{30, 60}, config.ChannelTypeOpenAI},
"gpt-4-32k-0613": {[]float64{30, 60}, common.ChannelTypeOpenAI}, "gpt-4-32k-0613": {[]float64{30, 60}, config.ChannelTypeOpenAI},
// $0.01 / 1K tokens $0.03 / 1K tokens // $0.01 / 1K tokens $0.03 / 1K tokens
"gpt-4-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
"gpt-4-turbo": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-turbo": {[]float64{5, 15}, config.ChannelTypeOpenAI},
"gpt-4-turbo-2024-04-09": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-turbo-2024-04-09": {[]float64{5, 15}, config.ChannelTypeOpenAI},
"gpt-4-1106-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-1106-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
"gpt-4-0125-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-0125-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
"gpt-4-turbo-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-turbo-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
"gpt-4-vision-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, "gpt-4-vision-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
// $0.005 / 1K tokens $0.015 / 1K tokens // $0.005 / 1K tokens $0.015 / 1K tokens
"gpt-4o": {[]float64{2.5, 7.5}, common.ChannelTypeOpenAI}, "gpt-4o": {[]float64{2.5, 7.5}, config.ChannelTypeOpenAI},
// $0.0005 / 1K tokens $0.0015 / 1K tokens // $0.0005 / 1K tokens $0.0015 / 1K tokens
"gpt-3.5-turbo": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo": {[]float64{0.25, 0.75}, config.ChannelTypeOpenAI},
"gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, config.ChannelTypeOpenAI},
// $0.0015 / 1K tokens $0.002 / 1K tokens // $0.0015 / 1K tokens $0.002 / 1K tokens
"gpt-3.5-turbo-0301": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, config.ChannelTypeOpenAI},
"gpt-3.5-turbo-0613": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, config.ChannelTypeOpenAI},
"gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, config.ChannelTypeOpenAI},
// $0.003 / 1K tokens $0.004 / 1K tokens // $0.003 / 1K tokens $0.004 / 1K tokens
"gpt-3.5-turbo-16k": {[]float64{1.5, 2}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, config.ChannelTypeOpenAI},
"gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, config.ChannelTypeOpenAI},
// $0.001 / 1K tokens $0.002 / 1K tokens // $0.001 / 1K tokens $0.002 / 1K tokens
"gpt-3.5-turbo-1106": {[]float64{0.5, 1}, common.ChannelTypeOpenAI}, "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, config.ChannelTypeOpenAI},
// $0.0020 / 1K tokens // $0.0020 / 1K tokens
"davinci-002": {[]float64{1, 1}, common.ChannelTypeOpenAI}, "davinci-002": {[]float64{1, 1}, config.ChannelTypeOpenAI},
// $0.0004 / 1K tokens // $0.0004 / 1K tokens
"babbage-002": {[]float64{0.2, 0.2}, common.ChannelTypeOpenAI}, "babbage-002": {[]float64{0.2, 0.2}, config.ChannelTypeOpenAI},
// $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"whisper-1": {[]float64{15, 15}, common.ChannelTypeOpenAI}, "whisper-1": {[]float64{15, 15}, config.ChannelTypeOpenAI},
// $0.015 / 1K characters // $0.015 / 1K characters
"tts-1": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI}, "tts-1": {[]float64{7.5, 7.5}, config.ChannelTypeOpenAI},
"tts-1-1106": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI}, "tts-1-1106": {[]float64{7.5, 7.5}, config.ChannelTypeOpenAI},
// $0.030 / 1K characters // $0.030 / 1K characters
"tts-1-hd": {[]float64{15, 15}, common.ChannelTypeOpenAI}, "tts-1-hd": {[]float64{15, 15}, config.ChannelTypeOpenAI},
"tts-1-hd-1106": {[]float64{15, 15}, common.ChannelTypeOpenAI}, "tts-1-hd-1106": {[]float64{15, 15}, config.ChannelTypeOpenAI},
"text-embedding-ada-002": {[]float64{0.05, 0.05}, common.ChannelTypeOpenAI}, "text-embedding-ada-002": {[]float64{0.05, 0.05}, config.ChannelTypeOpenAI},
// $0.00002 / 1K tokens // $0.00002 / 1K tokens
"text-embedding-3-small": {[]float64{0.01, 0.01}, common.ChannelTypeOpenAI}, "text-embedding-3-small": {[]float64{0.01, 0.01}, config.ChannelTypeOpenAI},
// $0.00013 / 1K tokens // $0.00013 / 1K tokens
"text-embedding-3-large": {[]float64{0.065, 0.065}, common.ChannelTypeOpenAI}, "text-embedding-3-large": {[]float64{0.065, 0.065}, config.ChannelTypeOpenAI},
"text-moderation-stable": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI}, "text-moderation-stable": {[]float64{0.1, 0.1}, config.ChannelTypeOpenAI},
"text-moderation-latest": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI}, "text-moderation-latest": {[]float64{0.1, 0.1}, config.ChannelTypeOpenAI},
// $0.016 - $0.020 / image // $0.016 - $0.020 / image
"dall-e-2": {[]float64{8, 8}, common.ChannelTypeOpenAI}, "dall-e-2": {[]float64{8, 8}, config.ChannelTypeOpenAI},
// $0.040 - $0.120 / image // $0.040 - $0.120 / image
"dall-e-3": {[]float64{20, 20}, common.ChannelTypeOpenAI}, "dall-e-3": {[]float64{20, 20}, config.ChannelTypeOpenAI},
// $0.80/million tokens $2.40/million tokens // $0.80/million tokens $2.40/million tokens
"claude-instant-1.2": {[]float64{0.4, 1.2}, common.ChannelTypeAnthropic}, "claude-instant-1.2": {[]float64{0.4, 1.2}, config.ChannelTypeAnthropic},
// $8.00/million tokens $24.00/million tokens // $8.00/million tokens $24.00/million tokens
"claude-2.0": {[]float64{4, 12}, common.ChannelTypeAnthropic}, "claude-2.0": {[]float64{4, 12}, config.ChannelTypeAnthropic},
"claude-2.1": {[]float64{4, 12}, common.ChannelTypeAnthropic}, "claude-2.1": {[]float64{4, 12}, config.ChannelTypeAnthropic},
// $15 / M $75 / M // $15 / M $75 / M
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, common.ChannelTypeAnthropic}, "claude-3-opus-20240229": {[]float64{7.5, 22.5}, config.ChannelTypeAnthropic},
// $3 / M $15 / M // $3 / M $15 / M
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, common.ChannelTypeAnthropic}, "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, config.ChannelTypeAnthropic},
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
"claude-3-haiku-20240307": {[]float64{0.125, 0.625}, common.ChannelTypeAnthropic}, "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, config.ChannelTypeAnthropic},
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens // ¥0.004 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, common.ChannelTypeBaidu}, "ERNIE-Speed": {[]float64{0.2857, 0.5714}, config.ChannelTypeBaidu},
// ¥0.012 / 1k tokens ¥0.012 / 1k tokens // ¥0.012 / 1k tokens ¥0.012 / 1k tokens
"ERNIE-Bot": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu}, "ERNIE-Bot": {[]float64{0.8572, 0.8572}, config.ChannelTypeBaidu},
"ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu}, "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, config.ChannelTypeBaidu},
// 0.024元/千tokens 0.048元/千tokens // 0.024元/千tokens 0.048元/千tokens
"ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, common.ChannelTypeBaidu}, "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, config.ChannelTypeBaidu},
// ¥0.008 / 1k tokens ¥0.008 / 1k tokens // ¥0.008 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaidu}, "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeBaidu},
// ¥0.12 / 1k tokens ¥0.12 / 1k tokens // ¥0.12 / 1k tokens ¥0.12 / 1k tokens
"ERNIE-Bot-4": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu}, "ERNIE-Bot-4": {[]float64{8.572, 8.572}, config.ChannelTypeBaidu},
"ERNIE-4.0": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu}, "ERNIE-4.0": {[]float64{8.572, 8.572}, config.ChannelTypeBaidu},
// ¥0.002 / 1k tokens // ¥0.002 / 1k tokens
"Embedding-V1": {[]float64{0.1429, 0.1429}, common.ChannelTypeBaidu}, "Embedding-V1": {[]float64{0.1429, 0.1429}, config.ChannelTypeBaidu},
// ¥0.004 / 1k tokens // ¥0.004 / 1k tokens
"BLOOMZ-7B": {[]float64{0.2857, 0.2857}, common.ChannelTypeBaidu}, "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, config.ChannelTypeBaidu},
"PaLM-2": {[]float64{1, 1}, common.ChannelTypePaLM}, "PaLM-2": {[]float64{1, 1}, config.ChannelTypePaLM},
// $0.50 / 1 million tokens $1.50 / 1 million tokens // $0.50 / 1 million tokens $1.50 / 1 million tokens
// 0.0005$ / 1k tokens 0.0015$ / 1k tokens // 0.0005$ / 1k tokens 0.0015$ / 1k tokens
"gemini-pro": {[]float64{0.25, 0.75}, common.ChannelTypeGemini}, "gemini-pro": {[]float64{0.25, 0.75}, config.ChannelTypeGemini},
"gemini-pro-vision": {[]float64{0.25, 0.75}, common.ChannelTypeGemini}, "gemini-pro-vision": {[]float64{0.25, 0.75}, config.ChannelTypeGemini},
"gemini-1.0-pro": {[]float64{0.25, 0.75}, common.ChannelTypeGemini}, "gemini-1.0-pro": {[]float64{0.25, 0.75}, config.ChannelTypeGemini},
// $7 / 1 million tokens $21 / 1 million tokens // $7 / 1 million tokens $21 / 1 million tokens
"gemini-1.5-pro": {[]float64{1.75, 5.25}, common.ChannelTypeGemini}, "gemini-1.5-pro": {[]float64{1.75, 5.25}, config.ChannelTypeGemini},
"gemini-1.5-pro-latest": {[]float64{1.75, 5.25}, common.ChannelTypeGemini}, "gemini-1.5-pro-latest": {[]float64{1.75, 5.25}, config.ChannelTypeGemini},
"gemini-1.5-flash": {[]float64{0.175, 0.265}, common.ChannelTypeGemini}, "gemini-1.5-flash": {[]float64{0.175, 0.265}, config.ChannelTypeGemini},
"gemini-1.5-flash-latest": {[]float64{0.175, 0.265}, common.ChannelTypeGemini}, "gemini-1.5-flash-latest": {[]float64{0.175, 0.265}, config.ChannelTypeGemini},
"gemini-ultra": {[]float64{1, 1}, common.ChannelTypeGemini}, "gemini-ultra": {[]float64{1, 1}, config.ChannelTypeGemini},
// ¥0.005 / 1k tokens // ¥0.005 / 1k tokens
"glm-3-turbo": {[]float64{0.3572, 0.3572}, common.ChannelTypeZhipu}, "glm-3-turbo": {[]float64{0.3572, 0.3572}, config.ChannelTypeZhipu},
// ¥0.1 / 1k tokens // ¥0.1 / 1k tokens
"glm-4": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu}, "glm-4": {[]float64{7.143, 7.143}, config.ChannelTypeZhipu},
"glm-4v": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu}, "glm-4v": {[]float64{7.143, 7.143}, config.ChannelTypeZhipu},
// ¥0.0005 / 1k tokens // ¥0.0005 / 1k tokens
"embedding-2": {[]float64{0.0357, 0.0357}, common.ChannelTypeZhipu}, "embedding-2": {[]float64{0.0357, 0.0357}, config.ChannelTypeZhipu},
// ¥0.25 / 1张图片 // ¥0.25 / 1张图片
"cogview-3": {[]float64{17.8571, 17.8571}, common.ChannelTypeZhipu}, "cogview-3": {[]float64{17.8571, 17.8571}, config.ChannelTypeZhipu},
// ¥0.008 / 1k tokens // ¥0.008 / 1k tokens
"qwen-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli}, "qwen-turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeAli},
// ¥0.02 / 1k tokens // ¥0.02 / 1k tokens
"qwen-plus": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli}, "qwen-plus": {[]float64{1.4286, 1.4286}, config.ChannelTypeAli},
"qwen-vl-max": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli}, "qwen-vl-max": {[]float64{1.4286, 1.4286}, config.ChannelTypeAli},
// 0.12元/1,000tokens // 0.12元/1,000tokens
"qwen-max": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli}, "qwen-max": {[]float64{8.5714, 8.5714}, config.ChannelTypeAli},
"qwen-max-longcontext": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli}, "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, config.ChannelTypeAli},
// 0.008元/1,000tokens // 0.008元/1,000tokens
"qwen-vl-plus": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli}, "qwen-vl-plus": {[]float64{0.5715, 0.5715}, config.ChannelTypeAli},
// ¥0.0007 / 1k tokens // ¥0.0007 / 1k tokens
"text-embedding-v1": {[]float64{0.05, 0.05}, common.ChannelTypeAli}, "text-embedding-v1": {[]float64{0.05, 0.05}, config.ChannelTypeAli},
// ¥0.018 / 1k tokens // ¥0.018 / 1k tokens
"SparkDesk": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, "SparkDesk": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
"SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
"SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
"SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
"SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
// ¥0.012 / 1k tokens // ¥0.012 / 1k tokens
"360GPT_S2_V9": {[]float64{0.8572, 0.8572}, common.ChannelType360}, "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, config.ChannelType360},
// ¥0.001 / 1k tokens // ¥0.001 / 1k tokens
"embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, config.ChannelType360},
"embedding_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, "embedding_s1_v1": {[]float64{0.0715, 0.0715}, config.ChannelType360},
"semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, config.ChannelType360},
// ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"hunyuan": {[]float64{7.143, 7.143}, common.ChannelTypeTencent}, "hunyuan": {[]float64{7.143, 7.143}, config.ChannelTypeTencent},
// https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// ¥0.01 / 1k tokens // ¥0.01 / 1k tokens
"ChatStd": {[]float64{0.7143, 0.7143}, common.ChannelTypeTencent}, "ChatStd": {[]float64{0.7143, 0.7143}, config.ChannelTypeTencent},
//¥0.1 / 1k tokens //¥0.1 / 1k tokens
"ChatPro": {[]float64{7.143, 7.143}, common.ChannelTypeTencent}, "ChatPro": {[]float64{7.143, 7.143}, config.ChannelTypeTencent},
"Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
"Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, common.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, config.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
"Baichuan2-53B": {[]float64{1.4286, 1.4286}, common.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens "Baichuan2-53B": {[]float64{1.4286, 1.4286}, config.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
"Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, common.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, config.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
"abab5.5s-chat": {[]float64{0.3572, 0.3572}, common.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens "abab5.5s-chat": {[]float64{0.3572, 0.3572}, config.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
"abab5.5-chat": {[]float64{1.0714, 1.0714}, common.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens "abab5.5-chat": {[]float64{1.0714, 1.0714}, config.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
"abab6-chat": {[]float64{14.2857, 14.2857}, common.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens "abab6-chat": {[]float64{14.2857, 14.2857}, config.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
"embo-01": {[]float64{0.0357, 0.0357}, common.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens "embo-01": {[]float64{0.0357, 0.0357}, config.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
"deepseek-coder": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens "deepseek-coder": {[]float64{0.75, 0.75}, config.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
"deepseek-chat": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens "deepseek-chat": {[]float64{0.75, 0.75}, config.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
"moonshot-v1-8k": {[]float64{0.8572, 0.8572}, common.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, config.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
"moonshot-v1-32k": {[]float64{1.7143, 1.7143}, common.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, config.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
"moonshot-v1-128k": {[]float64{4.2857, 4.2857}, common.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, config.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
"open-mistral-7b": {[]float64{0.125, 0.125}, common.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens "open-mistral-7b": {[]float64{0.125, 0.125}, config.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
"open-mixtral-8x7b": {[]float64{0.35, 0.35}, common.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens "open-mixtral-8x7b": {[]float64{0.35, 0.35}, config.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
"mistral-small-latest": {[]float64{1, 3}, common.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens "mistral-small-latest": {[]float64{1, 3}, config.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
"mistral-medium-latest": {[]float64{1.35, 4.05}, common.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens "mistral-medium-latest": {[]float64{1.35, 4.05}, config.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
"mistral-large-latest": {[]float64{4, 12}, common.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens "mistral-large-latest": {[]float64{4, 12}, config.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
"mistral-embed": {[]float64{0.05, 0.05}, common.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens "mistral-embed": {[]float64{0.05, 0.05}, config.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
// $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
"llama2-70b-4096": {[]float64{0.35, 0.4}, common.ChannelTypeGroq}, "llama2-70b-4096": {[]float64{0.35, 0.4}, config.ChannelTypeGroq},
// $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
"llama2-7b-2048": {[]float64{0.05, 0.05}, common.ChannelTypeGroq}, "llama2-7b-2048": {[]float64{0.05, 0.05}, config.ChannelTypeGroq},
"gemma-7b-it": {[]float64{0.05, 0.05}, common.ChannelTypeGroq}, "gemma-7b-it": {[]float64{0.05, 0.05}, config.ChannelTypeGroq},
// $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
"mixtral-8x7b-32768": {[]float64{0.135, 0.135}, common.ChannelTypeGroq}, "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, config.ChannelTypeGroq},
// 2.5 元 / 1M tokens 0.0025 / 1k tokens // 2.5 元 / 1M tokens 0.0025 / 1k tokens
"yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, common.ChannelTypeLingyi}, "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, config.ChannelTypeLingyi},
// 12 元 / 1M tokens 0.012 / 1k tokens // 12 元 / 1M tokens 0.012 / 1k tokens
"yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, common.ChannelTypeLingyi}, "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, config.ChannelTypeLingyi},
// 6 元 / 1M tokens 0.006 / 1k tokens // 6 元 / 1M tokens 0.006 / 1k tokens
"yi-vl-plus": {[]float64{0.4286, 0.4286}, common.ChannelTypeLingyi}, "yi-vl-plus": {[]float64{0.4286, 0.4286}, config.ChannelTypeLingyi},
"@cf/stabilityai/stable-diffusion-xl-base-1.0": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/stabilityai/stable-diffusion-xl-base-1.0": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@cf/lykon/dreamshaper-8-lcm": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/lykon/dreamshaper-8-lcm": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@cf/bytedance/stable-diffusion-xl-lightning": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/bytedance/stable-diffusion-xl-lightning": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@cf/qwen/qwen1.5-7b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/qwen/qwen1.5-7b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@cf/qwen/qwen1.5-14b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/qwen/qwen1.5-14b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@hf/thebloke/deepseek-coder-6.7b-base-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@hf/thebloke/deepseek-coder-6.7b-base-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@hf/google/gemma-7b-it": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@hf/google/gemma-7b-it": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
"@cf/openai/whisper": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/openai/whisper": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
//$0.50 /1M TOKENS $1.50/1M TOKENS //$0.50 /1M TOKENS $1.50/1M TOKENS
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere}, "command-r": {[]float64{0.25, 0.75}, config.ChannelTypeCohere},
//$3 /1M TOKENS $15/1M TOKENS //$3 /1M TOKENS $15/1M TOKENS
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere}, "command-r-plus": {[]float64{1.5, 7.5}, config.ChannelTypeCohere},
// 0.065 // 0.065
"sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI}, "sd3": {[]float64{32.5, 32.5}, config.ChannelTypeStabilityAI},
// 0.04 // 0.04
"sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI}, "sd3-turbo": {[]float64{20, 20}, config.ChannelTypeStabilityAI},
// 0.03 // 0.03
"stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI}, "stable-image-core": {[]float64{15, 15}, config.ChannelTypeStabilityAI},
// hunyuan // hunyuan
"hunyuan-lite": {[]float64{0, 0}, common.ChannelTypeHunyuan}, "hunyuan-lite": {[]float64{0, 0}, config.ChannelTypeHunyuan},
"hunyuan-standard": {[]float64{0.3214, 0.3571}, common.ChannelTypeHunyuan}, "hunyuan-standard": {[]float64{0.3214, 0.3571}, config.ChannelTypeHunyuan},
"hunyuan-standard-256k": {[]float64{1.0714, 4.2857}, common.ChannelTypeHunyuan}, "hunyuan-standard-256k": {[]float64{1.0714, 4.2857}, config.ChannelTypeHunyuan},
"hunyuan-pro": {[]float64{2.1429, 7.1429}, common.ChannelTypeHunyuan}, "hunyuan-pro": {[]float64{2.1429, 7.1429}, config.ChannelTypeHunyuan},
} }
var prices []*Price var prices []*Price
@ -355,7 +355,7 @@ func GetDefaultPrice() []*Price {
prices = append(prices, &Price{ prices = append(prices, &Price{
Model: model, Model: model,
Type: TimesPriceType, Type: TimesPriceType,
ChannelType: common.ChannelTypeMidjourney, ChannelType: config.ChannelTypeMidjourney,
Input: mjPrice, Input: mjPrice,
Output: mjPrice, Output: mjPrice,
}) })

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"gorm.io/gorm" "gorm.io/gorm"
@ -69,7 +70,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return errors.New("无效的兑换码") return errors.New("无效的兑换码")
} }
if redemption.Status != common.RedemptionCodeStatusEnabled { if redemption.Status != config.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用") return errors.New("该兑换码已被使用")
} }
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@ -77,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
return err return err
} }
redemption.RedeemedTime = utils.GetTimestamp() redemption.RedeemedTime = utils.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed redemption.Status = config.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error err = tx.Save(redemption).Error
return err return err
}) })

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/stmp" "one-api/common/stmp"
"one-api/common/utils" "one-api/common/utils"
@ -49,7 +50,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e
// 获取状态为可用的令牌 // 获取状态为可用的令牌
func GetUserEnabledTokens(userId int) (tokens []*Token, err error) { func GetUserEnabledTokens(userId int) (tokens []*Token, err error) {
err = DB.Where("user_id = ? and status = ?", userId, common.TokenStatusEnabled).Find(&tokens).Error err = DB.Where("user_id = ? and status = ?", userId, config.TokenStatusEnabled).Find(&tokens).Error
return tokens, err return tokens, err
} }
@ -65,17 +66,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
return nil, errors.New("令牌验证失败") return nil, errors.New("令牌验证失败")
} }
if token.Status == common.TokenStatusExhausted { if token.Status == config.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽") return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired { } else if token.Status == config.TokenStatusExpired {
return nil, errors.New("该令牌已过期") return nil, errors.New("该令牌已过期")
} }
if token.Status != common.TokenStatusEnabled { if token.Status != config.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = config.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
logger.SysError("failed to update token status" + err.Error()) logger.SysError("failed to update token status" + err.Error())
@ -86,7 +87,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled { if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted // in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted token.Status = config.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
logger.SysError("failed to update token status" + err.Error()) logger.SysError("failed to update token status" + err.Error())
@ -128,7 +129,7 @@ func GetTokenByName(name string, user_id int) (*Token, error) {
} }
func (token *Token) Insert() error { func (token *Token) Insert() error {
if token.ChatCache && !common.ChatCacheEnabled { if token.ChatCache && !config.ChatCacheEnabled {
token.ChatCache = false token.ChatCache = false
} }
@ -138,7 +139,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
if token.ChatCache && !common.ChatCacheEnabled { if token.ChatCache && !config.ChatCacheEnabled {
token.ChatCache = false token.ChatCache = false
} }
@ -178,7 +179,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota) addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil return nil
} }
@ -200,7 +201,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil return nil
} }
@ -236,7 +237,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota { if userQuota < quota {
return errors.New("用户额度不足") return errors.New("用户额度不足")
} }
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0 noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go sendQuotaWarningEmail(token.UserId, userQuota, noMoreQuota) go sendQuotaWarningEmail(token.UserId, userQuota, noMoreQuota)

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
"strings" "strings"
@ -116,7 +117,7 @@ func (user *User) Insert(inviterId int) error {
return err return err
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = config.QuotaForNewUser
user.AccessToken = utils.GetUUID() user.AccessToken = utils.GetUUID()
user.AffCode = utils.GetRandomString(4) user.AffCode = utils.GetRandomString(4)
user.CreatedTime = utils.GetTimestamp() user.CreatedTime = utils.GetTimestamp()
@ -124,17 +125,17 @@ func (user *User) Insert(inviterId int) error {
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if common.QuotaForNewUser > 0 { if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if common.QuotaForInvitee > 0 { if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
} }
if common.QuotaForInviter > 0 { if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
return nil return nil
@ -150,8 +151,8 @@ func (user *User) Update(updatePassword bool) error {
} }
err = DB.Model(user).Updates(user).Error err = DB.Model(user).Updates(user).Error
if err == nil && user.Role == common.RoleRootUser { if err == nil && user.Role == config.RoleRootUser {
common.RootUserEmail = user.Email config.RootUserEmail = user.Email
} }
return err return err
@ -196,7 +197,7 @@ func (user *User) ValidateAndFill() (err error) {
} }
} }
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != config.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
} }
return nil return nil
@ -310,7 +311,7 @@ func IsAdmin(userId int) bool {
logger.SysError("no such user " + err.Error()) logger.SysError("no such user " + err.Error())
return false return false
} }
return user.Role >= common.RoleAdminUser return user.Role >= config.RoleAdminUser
} }
func IsUserEnabled(userId int) (bool, error) { func IsUserEnabled(userId int) (bool, error) {
@ -322,7 +323,7 @@ func IsUserEnabled(userId int) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
return user.Status == common.UserStatusEnabled, nil return user.Status == config.UserStatusEnabled, nil
} }
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
@ -366,7 +367,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota) addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil return nil
} }
@ -382,7 +383,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota) addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil return nil
} }
@ -395,12 +396,12 @@ func decreaseUserQuota(id int, quota int) (err error) {
} }
func GetRootUserEmail() (email string) { func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) DB.Model(&User{}).Where("role = ?", config.RoleRootUser).Select("email").Find(&email)
return email return email
} }
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1) addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return return

View File

@ -1,7 +1,7 @@
package model package model
import ( import (
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"sync" "sync"
"time" "time"
@ -29,7 +29,7 @@ func init() {
func InitBatchUpdater() { func InitBatchUpdater() {
go func() { go func() {
for { for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate() batchUpdate()
} }
}() }()

View File

@ -2,7 +2,7 @@ package ali_test
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/test" "one-api/common/test"
"one-api/model" "one-api/model"
) )
@ -20,5 +20,5 @@ func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown fun
} }
func getAliChannel(baseUrl string) model.Channel { func getAliChannel(baseUrl string) model.Channel {
return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "") return test.GetChannel(config.ChannelTypeAli, baseUrl, "", "", "")
} }

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
"one-api/types" "one-api/types"
@ -55,7 +56,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe
} }
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,11 +3,12 @@ package ali
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/providers/openai" "one-api/providers/openai"
"one-api/types" "one-api/types"
"time" "time"
@ -14,7 +15,7 @@ func (p *AzureProvider) CreateImageGenerations(request *types.ImageRequest) (*ty
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
} }
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeImagesGenerations, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -82,7 +83,7 @@ func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest)
} }
func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeAudioSpeech) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeAudioSpeech)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -2,7 +2,7 @@ package baichuan
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/providers/openai" "one-api/providers/openai"
"one-api/types" "one-api/types"
@ -11,7 +11,7 @@ import (
func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request) requestBody := p.getChatRequestBody(request)
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, requestBody) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, requestBody)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -51,7 +51,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet
request.StreamOptions = nil request.StreamOptions = nil
} }
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -54,7 +55,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
} }
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,11 +3,12 @@ package baidu
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/model" "one-api/model"
"one-api/types" "one-api/types"
@ -110,25 +111,25 @@ func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
func (p *BaseProvider) GetAPIUri(relayMode int) string { func (p *BaseProvider) GetAPIUri(relayMode int) string {
switch relayMode { switch relayMode {
case common.RelayModeChatCompletions: case config.RelayModeChatCompletions:
return p.Config.ChatCompletions return p.Config.ChatCompletions
case common.RelayModeCompletions: case config.RelayModeCompletions:
return p.Config.Completions return p.Config.Completions
case common.RelayModeEmbeddings: case config.RelayModeEmbeddings:
return p.Config.Embeddings return p.Config.Embeddings
case common.RelayModeAudioSpeech: case config.RelayModeAudioSpeech:
return p.Config.AudioSpeech return p.Config.AudioSpeech
case common.RelayModeAudioTranscription: case config.RelayModeAudioTranscription:
return p.Config.AudioTranscriptions return p.Config.AudioTranscriptions
case common.RelayModeAudioTranslation: case config.RelayModeAudioTranslation:
return p.Config.AudioTranslations return p.Config.AudioTranslations
case common.RelayModeModerations: case config.RelayModeModerations:
return p.Config.Moderation return p.Config.Moderation
case common.RelayModeImagesGenerations: case config.RelayModeImagesGenerations:
return p.Config.ImagesGenerations return p.Config.ImagesGenerations
case common.RelayModeImagesEdits: case config.RelayModeImagesEdits:
return p.Config.ImagesEdit return p.Config.ImagesEdit
case common.RelayModeImagesVariations: case config.RelayModeImagesVariations:
return p.Config.ImagesVariations return p.Config.ImagesVariations
default: default:
return "" return ""

View File

@ -3,6 +3,7 @@ package bedrock
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/providers/bedrock/category" "one-api/providers/bedrock/category"
"one-api/types" "one-api/types"
@ -48,7 +49,7 @@ func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) (
return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError) return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError)
} }
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/image" "one-api/common/image"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
@ -64,7 +65,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
} }
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
"one-api/providers/base" "one-api/providers/base"
@ -56,7 +57,7 @@ func (p *CohereProvider) CreateChatCompletionStream(request *types.ChatCompletio
} }
func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
"one-api/types" "one-api/types"
@ -56,7 +57,7 @@ func (p *CozeProvider) CreateChatCompletionStream(request *types.ChatCompletionR
} }
func (p *CozeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *CozeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -2,7 +2,7 @@ package groq
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/providers/openai" "one-api/providers/openai"
"one-api/types" "one-api/types"
@ -11,7 +11,7 @@ import (
func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
p.getChatRequestBody(request) p.getChatRequestBody(request)
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -51,7 +51,7 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
request.StreamOptions = nil request.StreamOptions = nil
} }
p.getChatRequestBody(request) p.getChatRequestBody(request)
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -53,7 +54,7 @@ func (p *HunyuanProvider) CreateChatCompletionStream(request *types.ChatCompleti
} }
func (p *HunyuanProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *HunyuanProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
action, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) action, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/requester" "one-api/common/requester"
"one-api/model" "one-api/model"
@ -48,7 +48,7 @@ func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyRe
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
} }
delete(mapResult, "accountFilter") delete(mapResult, "accountFilter")
if !common.MjNotifyEnabled { if !config.MjNotifyEnabled {
delete(mapResult, "notifyHook") delete(mapResult, "notifyHook")
} }
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -55,7 +56,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti
} }
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,11 +3,12 @@ package minimax
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *MiniMaxProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *MiniMaxProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -56,7 +57,7 @@ func (p *MistralProvider) CreateChatCompletionStream(request *types.ChatCompleti
} }
func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,11 +3,12 @@ package mistral
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,6 +3,7 @@ package moonshot
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/providers/openai" "one-api/providers/openai"
"one-api/types" "one-api/types"
@ -10,7 +11,7 @@ import (
func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
request.ClearEmptyMessages() request.ClearEmptyMessages()
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -51,7 +52,7 @@ func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionReq
func (p *MoonshotProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { func (p *MoonshotProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
request.ClearEmptyMessages() request.ClearEmptyMessages()
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/image" "one-api/common/image"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
@ -56,7 +57,7 @@ func (p *OllamaProvider) CreateChatCompletionStream(request *types.ChatCompletio
} }
func (p *OllamaProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *OllamaProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,11 +3,12 @@ package ollama
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *OllamaProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *OllamaProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/model" "one-api/model"
"one-api/types" "one-api/types"
@ -32,11 +33,11 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
// 创建 OpenAIProvider // 创建 OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction // https://platform.openai.com/docs/api-reference/introduction
func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider { func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
config := getOpenAIConfig(baseURL) openaiConfig := getOpenAIConfig(baseURL)
OpenAIProvider := &OpenAIProvider{ OpenAIProvider := &OpenAIProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
Config: config, Config: openaiConfig,
Channel: channel, Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle), Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle),
}, },
@ -44,7 +45,7 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
BalanceAction: true, BalanceAction: true,
} }
if channel.Type == common.ChannelTypeOpenAI { if channel.Type == config.ChannelTypeOpenAI {
OpenAIProvider.SupportStreamOptions = true OpenAIProvider.SupportStreamOptions = true
} }
@ -109,7 +110,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion) requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion)
} }
} else if p.Channel.Type == common.ChannelTypeCustom && p.Channel.Other != "" { } else if p.Channel.Type == config.ChannelTypeCustom && p.Channel.Other != "" {
requestURL = strings.Replace(requestURL, "v1", p.Channel.Other, 1) requestURL = strings.Replace(requestURL, "v1", p.Channel.Other, 1)
} }

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -17,7 +18,7 @@ type OpenAIStreamHandler struct {
} }
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -67,7 +68,7 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
// 避免误传导致报错 // 避免误传导致报错
request.StreamOptions = nil request.StreamOptions = nil
} }
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,13 +5,14 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
) )
func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -50,7 +51,7 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest
// 避免误传导致报错 // 避免误传导致报错
request.StreamOptions = nil request.StreamOptions = nil
} }
req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -2,12 +2,12 @@ package openai
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeEmbeddings, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,12 +5,13 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
) )
func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request) req, errWithCode := p.getRequestImageBody(config.RelayModeEdits, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,6 +3,7 @@ package openai
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
@ -11,7 +12,7 @@ func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*t
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
} }
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeImagesGenerations, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -2,12 +2,12 @@ package openai
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request) req, errWithCode := p.getRequestImageBody(config.RelayModeImagesVariations, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -2,13 +2,13 @@ package openai
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeModerations, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -2,13 +2,13 @@ package openai
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
) )
func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request) req, errWithCode := p.GetRequestTextBody(config.RelayModeAudioSpeech, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"regexp" "regexp"
@ -15,7 +16,7 @@ import (
) )
func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request) req, errWithCode := p.getRequestAudioBody(config.RelayModeAudioTranscription, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -4,11 +4,12 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request) req, errWithCode := p.getRequestAudioBody(config.RelayModeAudioTranslation, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
"one-api/types" "one-api/types"
@ -55,7 +56,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR
} }
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -1,7 +1,7 @@
package providers package providers
import ( import (
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"one-api/providers/ali" "one-api/providers/ali"
"one-api/providers/azure" "one-api/providers/azure"
@ -44,32 +44,32 @@ var providerFactories = make(map[int]ProviderFactory)
// 在程序启动时,添加所有的供应商工厂 // 在程序启动时,添加所有的供应商工厂
func init() { func init() {
providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{} providerFactories[config.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{}
providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{} providerFactories[config.ChannelTypeAzure] = azure.AzureProviderFactory{}
providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{} providerFactories[config.ChannelTypeAli] = ali.AliProviderFactory{}
providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{} providerFactories[config.ChannelTypeTencent] = tencent.TencentProviderFactory{}
providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{} providerFactories[config.ChannelTypeBaidu] = baidu.BaiduProviderFactory{}
providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{} providerFactories[config.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{}
providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{} providerFactories[config.ChannelTypePaLM] = palm.PalmProviderFactory{}
providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{} providerFactories[config.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{}
providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{} providerFactories[config.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{}
providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{} providerFactories[config.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{} providerFactories[config.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{} providerFactories[config.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{}
providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{} providerFactories[config.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{}
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{} providerFactories[config.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{} providerFactories[config.ChannelTypeMistral] = mistral.MistralProviderFactory{}
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{} providerFactories[config.ChannelTypeGroq] = groq.GroqProviderFactory{}
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{} providerFactories[config.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{} providerFactories[config.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{} providerFactories[config.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{} providerFactories[config.ChannelTypeCohere] = cohere.CohereProviderFactory{}
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{} providerFactories[config.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
providerFactories[common.ChannelTypeCoze] = coze.CozeProviderFactory{} providerFactories[config.ChannelTypeCoze] = coze.CozeProviderFactory{}
providerFactories[common.ChannelTypeOllama] = ollama.OllamaProviderFactory{} providerFactories[config.ChannelTypeOllama] = ollama.OllamaProviderFactory{}
providerFactories[common.ChannelTypeMoonshot] = moonshot.MoonshotProviderFactory{} providerFactories[config.ChannelTypeMoonshot] = moonshot.MoonshotProviderFactory{}
providerFactories[common.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{} providerFactories[config.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{}
providerFactories[common.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{} providerFactories[config.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{}
} }
@ -79,7 +79,7 @@ func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface
var provider base.ProviderInterface var provider base.ProviderInterface
if !ok { if !ok {
// 处理未找到的供应商工厂 // 处理未找到的供应商工厂
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := config.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/storage" "one-api/common/storage"
"one-api/common/utils" "one-api/common/utils"
"one-api/types" "one-api/types"
@ -20,7 +21,7 @@ func convertModelName(modelName string) string {
} }
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeImagesGenerations)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
"one-api/types" "one-api/types"
@ -55,7 +56,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti
} }
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
"one-api/types" "one-api/types"
@ -59,7 +60,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio
} }
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) { func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -54,7 +55,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion
} }
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,11 +3,12 @@ package zhipu
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
) )
func (p *ZhipuProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { func (p *ZhipuProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -3,12 +3,13 @@ package zhipu
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/types" "one-api/types"
"time" "time"
) )
func (p *ZhipuProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { func (p *ZhipuProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations) url, errWithCode := p.GetSupportedAPIUri(config.RelayModeImagesGenerations)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils" "one-api/common/utils"
@ -93,7 +94,7 @@ func fetchChannelById(channelId int) (*model.Channel, error) {
if err != nil { if err != nil {
return nil, errors.New("无效的渠道 Id") return nil, errors.New("无效的渠道 Id")
} }
if channel.Status != common.ChannelStatusEnabled { if channel.Status != config.ChannelStatusEnabled {
return nil, errors.New("该渠道已被禁用") return nil, errors.New("该渠道已被禁用")
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/relay/relay_util" "one-api/relay/relay_util"
@ -50,7 +51,7 @@ func Relay(c *gin.Context) {
channel := relay.getProvider().GetChannel() channel := relay.getProvider().GetChannel()
go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr) go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
retryTimes := common.RetryTimes retryTimes := config.RetryTimes
if done || !shouldRetry(c, apiErr.StatusCode) { if done || !shouldRetry(c, apiErr.StatusCode) {
logger.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode)) logger.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
retryTimes = 0 retryTimes = 0

View File

@ -10,6 +10,7 @@ import (
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/controller" "one-api/controller"
"one-api/model" "one-api/model"
provider "one-api/providers/midjourney" provider "one-api/providers/midjourney"
@ -112,7 +113,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid
midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = "" midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" { if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId midjourneyTask.ImageUrl = config.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" { if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
} }

View File

@ -3,6 +3,7 @@ package relay
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/model" "one-api/model"
"one-api/providers/azure" "one-api/providers/azure"
"one-api/providers/openai" "one-api/providers/openai"
@ -20,7 +21,7 @@ func RelayOnly(c *gin.Context) {
} }
channel := provider.GetChannel() channel := provider.GetChannel()
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure { if channel.Type != config.ChannelTypeOpenAI && channel.Type != config.ChannelTypeAzure {
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai") common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai")
return return
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
@ -57,7 +58,7 @@ func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
return props return props
} }
if common.ChatCacheEnabled && c.GetBool("chat_cache") { if config.ChatCacheEnabled && c.GetBool("chat_cache") {
props.Cache = true props.Cache = true
} }
@ -113,7 +114,7 @@ func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens in
p.CompletionTokens = completionTokens p.CompletionTokens = completionTokens
p.ModelName = modelName p.ModelName = modelName
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute)) return p.Driver.Set(p.getHash(), p, int64(config.ChatCacheExpireMinute))
} }
func (p *ChatCacheProps) GetCache() *ChatCacheProps { func (p *ChatCacheProps) GetCache() *ChatCacheProps {
@ -125,7 +126,7 @@ func (p *ChatCacheProps) GetCache() *ChatCacheProps {
} }
func (p *ChatCacheProps) needCache() bool { func (p *ChatCacheProps) needCache() bool {
return common.ChatCacheEnabled && p.Cache return config.ChatCacheEnabled && p.Cache
} }
func (p *ChatCacheProps) getHash() string { func (p *ChatCacheProps) getHash() string {

View File

@ -3,7 +3,7 @@ package relay_util
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/common/utils" "one-api/common/utils"
"one-api/model" "one-api/model"
@ -107,7 +107,7 @@ func (p *Pricing) GetPrice(modelName string) *model.Price {
return &model.Price{ return &model.Price{
Type: model.TokensPriceType, Type: model.TokensPriceType,
ChannelType: common.ChannelTypeUnknown, ChannelType: config.ChannelTypeUnknown,
Input: model.DefaultPrice, Input: model.DefaultPrice,
Output: model.DefaultPrice, Output: model.DefaultPrice,
} }

View File

@ -7,6 +7,7 @@ import (
"math" "math"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/types" "one-api/types"
@ -45,7 +46,7 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
if quota.price.Type == model.TimesPriceType { if quota.price.Type == model.TimesPriceType {
quota.preConsumedQuota = int(1000 * quota.inputRatio) quota.preConsumedQuota = int(1000 * quota.inputRatio)
} else { } else {
quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio) quota.preConsumedQuota = int(float64(quota.promptTokens+config.PreConsumedQuota) * quota.inputRatio)
} }
errWithCode := quota.preQuotaConsumption() errWithCode := quota.preQuotaConsumption()

View File

@ -1,35 +1,37 @@
package relay_util package relay_util
import "one-api/common" import (
"one-api/common/config"
)
var UnknownOwnedBy = "未知" var UnknownOwnedBy = "未知"
var ModelOwnedBy map[int]string var ModelOwnedBy map[int]string
func init() { func init() {
ModelOwnedBy = map[int]string{ ModelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI", config.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic", config.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu", config.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM", config.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini", config.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu", config.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali", config.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei", config.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360", config.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent", config.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan", config.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax", config.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek", config.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot", config.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral", config.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq", config.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu", config.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeMidjourney: "Midjourney", config.ChannelTypeMidjourney: "Midjourney",
common.ChannelTypeCloudflareAI: "Cloudflare AI", config.ChannelTypeCloudflareAI: "Cloudflare AI",
common.ChannelTypeCohere: "Cohere", config.ChannelTypeCohere: "Cohere",
common.ChannelTypeStabilityAI: "Stability AI", config.ChannelTypeStabilityAI: "Stability AI",
common.ChannelTypeCoze: "Coze", config.ChannelTypeCoze: "Coze",
common.ChannelTypeOllama: "Ollama", config.ChannelTypeOllama: "Ollama",
common.ChannelTypeHunyuan: "Hunyuan", config.ChannelTypeHunyuan: "Hunyuan",
} }
} }

View File

@ -4,7 +4,7 @@ import (
"embed" "embed"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"strings" "strings"
@ -17,7 +17,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetDashboardRouter(router) SetDashboardRouter(router)
SetRelayRouter(router) SetRelayRouter(router)
frontendBaseUrl := viper.GetString("frontend_base_url") frontendBaseUrl := viper.GetString("frontend_base_url")
if common.IsMasterNode && frontendBaseUrl != "" { if config.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = "" frontendBaseUrl = ""
logger.SysLog("FRONTEND_BASE_URL is ignored on master node") logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
} }