🔖 chore: migration logger package
This commit is contained in:
parent
79524108a3
commit
ce12558ad6
@ -2,7 +2,7 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/relay/relay_util"
|
"one-api/relay/relay_util"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
@ -12,7 +12,7 @@ func ExportPrices() {
|
|||||||
prices := relay_util.GetPricesList("default")
|
prices := relay_util.GetPricesList("default")
|
||||||
|
|
||||||
if len(prices) == 0 {
|
if len(prices) == 0 {
|
||||||
common.SysError("No prices found")
|
logger.SysError("No prices found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,22 +27,22 @@ func ExportPrices() {
|
|||||||
// 导出到当前目录下的 prices.json 文件
|
// 导出到当前目录下的 prices.json 文件
|
||||||
file, err := os.Create("prices.json")
|
file, err := os.Create("prices.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Failed to create file: " + err.Error())
|
logger.SysError("Failed to create file: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
jsonData, err := json.MarshalIndent(prices, "", " ")
|
jsonData, err := json.MarshalIndent(prices, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Failed to encode prices: " + err.Error())
|
logger.SysError("Failed to encode prices: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = file.Write(jsonData)
|
_, err = file.Write(jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Failed to write to file: " + err.Error())
|
logger.SysError("Failed to write to file: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
common.SysLog("Prices exported to prices.json")
|
logger.SysLog("Prices exported to prices.json")
|
||||||
}
|
}
|
||||||
|
11
common/common.go
Normal file
11
common/common.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func LogQuota(quota int) string {
|
||||||
|
if DisplayInCurrencyEnabled {
|
||||||
|
return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%d 点额度", quota)
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"one-api/common/logger"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ func InitConf() {
|
|||||||
setEnv()
|
setEnv()
|
||||||
|
|
||||||
if viper.GetBool("debug") {
|
if viper.GetBool("debug") {
|
||||||
common.SysLog("running in debug mode")
|
logger.SysLog("running in debug mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
common.IsMasterNode = viper.GetString("node_type") != "slave"
|
common.IsMasterNode = viper.GetString("node_type") != "slave"
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -65,7 +66,7 @@ func AbortWithMessage(c *gin.Context, statusCode int, message string) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
LogError(c.Request.Context(), message)
|
logger.LogError(c.Request.Context(), message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func APIRespondWithError(c *gin.Context, status int, err error) {
|
func APIRespondWithError(c *gin.Context, status int, err error) {
|
||||||
|
@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"one-api/common/logger"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -9,7 +10,7 @@ func SafeGoroutine(f func()) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
|
logger.SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
f()
|
f()
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
var GroupRatio = map[string]float64{
|
var GroupRatio = map[string]float64{
|
||||||
"default": 1,
|
"default": 1,
|
||||||
@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
|
|||||||
func GroupRatio2JSONString() string {
|
func GroupRatio2JSONString() string {
|
||||||
jsonBytes, err := json.Marshal(GroupRatio)
|
jsonBytes, err := json.Marshal(GroupRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling model ratio: " + err.Error())
|
logger.SysError("error marshalling model ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
|
|||||||
func GetGroupRatio(name string) float64 {
|
func GetGroupRatio(name string) float64 {
|
||||||
ratio, ok := GroupRatio[name]
|
ratio, ok := GroupRatio[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
SysError("group ratio not found: " + name)
|
logger.SysError("group ratio not found: " + name)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return ratio
|
return ratio
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package common
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -129,11 +129,3 @@ func FatalLog(v ...any) {
|
|||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogQuota(quota int) string {
|
|
||||||
if DisplayInCurrencyEnabled {
|
|
||||||
return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%d 点额度", quota)
|
|
||||||
}
|
|
||||||
}
|
|
@ -2,7 +2,7 @@ package notify
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/common/notify/channel"
|
"one-api/common/notify/channel"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
@ -23,13 +23,13 @@ func InitNotifier() {
|
|||||||
|
|
||||||
func InitEmailNotifier() {
|
func InitEmailNotifier() {
|
||||||
if viper.GetBool("notify.email.disable") {
|
if viper.GetBool("notify.email.disable") {
|
||||||
common.SysLog("email notifier disabled")
|
logger.SysLog("email notifier disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
smtp_to := viper.GetString("notify.email.smtp_to")
|
smtp_to := viper.GetString("notify.email.smtp_to")
|
||||||
emailNotifier := channel.NewEmail(smtp_to)
|
emailNotifier := channel.NewEmail(smtp_to)
|
||||||
AddNotifiers(emailNotifier)
|
AddNotifiers(emailNotifier)
|
||||||
common.SysLog("email notifier enable")
|
logger.SysLog("email notifier enable")
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitDingTalkNotifier() {
|
func InitDingTalkNotifier() {
|
||||||
@ -49,7 +49,7 @@ func InitDingTalkNotifier() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
AddNotifiers(dingTalkNotifier)
|
AddNotifiers(dingTalkNotifier)
|
||||||
common.SysLog("dingtalk notifier enable")
|
logger.SysLog("dingtalk notifier enable")
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitLarkNotifier() {
|
func InitLarkNotifier() {
|
||||||
@ -69,7 +69,7 @@ func InitLarkNotifier() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
AddNotifiers(larkNotifier)
|
AddNotifiers(larkNotifier)
|
||||||
common.SysLog("lark notifier enable")
|
logger.SysLog("lark notifier enable")
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitPushdeerNotifier() {
|
func InitPushdeerNotifier() {
|
||||||
@ -81,7 +81,7 @@ func InitPushdeerNotifier() {
|
|||||||
pushdeerNotifier := channel.NewPushdeer(pushkey, viper.GetString("notify.pushdeer.url"))
|
pushdeerNotifier := channel.NewPushdeer(pushkey, viper.GetString("notify.pushdeer.url"))
|
||||||
|
|
||||||
AddNotifiers(pushdeerNotifier)
|
AddNotifiers(pushdeerNotifier)
|
||||||
common.SysLog("pushdeer notifier enable")
|
logger.SysLog("pushdeer notifier enable")
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitTelegramNotifier() {
|
func InitTelegramNotifier() {
|
||||||
@ -95,5 +95,5 @@ func InitTelegramNotifier() {
|
|||||||
telegramNotifier := channel.NewTelegram(bot_token, chat_id, httpProxy)
|
telegramNotifier := channel.NewTelegram(bot_token, chat_id, httpProxy)
|
||||||
|
|
||||||
AddNotifiers(telegramNotifier)
|
AddNotifiers(telegramNotifier)
|
||||||
common.SysLog("telegram notifier enable")
|
logger.SysLog("telegram notifier enable")
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package notify
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (n *Notify) Send(ctx context.Context, title, message string) {
|
func (n *Notify) Send(ctx context.Context, title, message string) {
|
||||||
@ -17,14 +17,14 @@ func (n *Notify) Send(ctx context.Context, title, message string) {
|
|||||||
}
|
}
|
||||||
err := channel.Send(ctx, title, message)
|
err := channel.Send(ctx, title, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("%s err: %s", channelName, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("%s err: %s", channelName, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Send(title, message string) {
|
func Send(title, message string) {
|
||||||
//lint:ignore SA1029 reason: 需要使用该类型作为错误处理
|
//lint:ignore SA1029 reason: 需要使用该类型作为错误处理
|
||||||
ctx := context.WithValue(context.Background(), common.RequestIdKey, "NotifyTask")
|
ctx := context.WithValue(context.Background(), logger.RequestIdKey, "NotifyTask")
|
||||||
|
|
||||||
notifyChannels.Send(ctx, title, message)
|
notifyChannels.Send(ctx, title, message)
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"one-api/common/logger"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
@ -16,17 +17,17 @@ func InitRedisClient() (err error) {
|
|||||||
redisConn := viper.GetString("redis_conn_string")
|
redisConn := viper.GetString("redis_conn_string")
|
||||||
|
|
||||||
if redisConn == "" {
|
if redisConn == "" {
|
||||||
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
|
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if viper.GetInt("sync_frequency") == 0 {
|
if viper.GetInt("sync_frequency") == 0 {
|
||||||
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
|
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
SysLog("Redis is enabled")
|
logger.SysLog("Redis is enabled")
|
||||||
opt, err := redis.ParseURL(redisConn)
|
opt, err := redis.ParseURL(redisConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
RDB = redis.NewClient(opt)
|
RDB = redis.NewClient(opt)
|
||||||
@ -36,7 +37,7 @@ func InitRedisClient() (err error) {
|
|||||||
|
|
||||||
_, err = RDB.Ping(ctx).Result()
|
_, err = RDB.Ping(ctx).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog("Redis ping test failed: " + err.Error())
|
logger.FatalLog("Redis ping test failed: " + err.Error())
|
||||||
} else {
|
} else {
|
||||||
RedisEnabled = true
|
RedisEnabled = true
|
||||||
// for compatibility with old versions
|
// for compatibility with old versions
|
||||||
@ -49,7 +50,7 @@ func InitRedisClient() (err error) {
|
|||||||
func ParseRedisOption() *redis.Options {
|
func ParseRedisOption() *redis.Options {
|
||||||
opt, err := redis.ParseURL(viper.GetString("redis_conn_string"))
|
opt, err := redis.ParseURL(viper.GetString("redis_conn_string"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||||
}
|
}
|
||||||
return opt
|
return opt
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ func GetWSClient(proxyAddr string) *websocket.Dialer {
|
|||||||
if proxyAddr != "" {
|
if proxyAddr != "" {
|
||||||
err := setWSProxy(dialer, proxyAddr)
|
err := setWSProxy(dialer, proxyAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
logger.SysError(err.Error())
|
||||||
return dialer
|
return dialer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package storage
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) string {
|
func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) string {
|
||||||
@ -17,7 +17,7 @@ func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) stri
|
|||||||
}
|
}
|
||||||
url, err := drive.Upload(data, fileName)
|
url, err := drive.Upload(data, fileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("%s err: %s", driveName, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("%s err: %s", driveName, err.Error()))
|
||||||
} else {
|
} else {
|
||||||
return url
|
return url
|
||||||
}
|
}
|
||||||
@ -28,7 +28,7 @@ func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) stri
|
|||||||
|
|
||||||
func Upload(data []byte, fileName string) string {
|
func Upload(data []byte, fileName string) string {
|
||||||
//lint:ignore SA1029 reason: 需要使用该类型作为错误处理
|
//lint:ignore SA1029 reason: 需要使用该类型作为错误处理
|
||||||
ctx := context.WithValue(context.Background(), common.RequestIdKey, "Upload")
|
ctx := context.WithValue(context.Background(), logger.RequestIdKey, "Upload")
|
||||||
|
|
||||||
return storageDrives.Upload(ctx, data, fileName)
|
return storageDrives.Upload(ctx, data, fileName)
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -29,20 +30,20 @@ var TGEnabled = false
|
|||||||
|
|
||||||
func InitTelegramBot() {
|
func InitTelegramBot() {
|
||||||
if TGEnabled {
|
if TGEnabled {
|
||||||
common.SysLog("Telegram bot has been started")
|
logger.SysLog("Telegram bot has been started")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
botKey := viper.GetString("tg.bot_api_key")
|
botKey := viper.GetString("tg.bot_api_key")
|
||||||
if botKey == "" {
|
if botKey == "" {
|
||||||
common.SysLog("Telegram bot is not enabled")
|
logger.SysLog("Telegram bot is not enabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
TGBot, err = gotgbot.NewBot(botKey, getBotOpts())
|
TGBot, err = gotgbot.NewBot(botKey, getBotOpts())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("failed to create new telegram bot: " + err.Error())
|
logger.SysLog("failed to create new telegram bot: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ func StartTelegramBot() {
|
|||||||
botWebhook := viper.GetString("tg.webhook_secret")
|
botWebhook := viper.GetString("tg.webhook_secret")
|
||||||
if botWebhook != "" {
|
if botWebhook != "" {
|
||||||
if common.ServerAddress == "" {
|
if common.ServerAddress == "" {
|
||||||
common.SysLog("Telegram bot is not enabled: Server address is not set")
|
logger.SysLog("Telegram bot is not enabled: Server address is not set")
|
||||||
StopTelegramBot()
|
StopTelegramBot()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -70,7 +71,7 @@ func StartTelegramBot() {
|
|||||||
|
|
||||||
err := TGupdater.AddWebhook(TGBot, urlPath, webHookOpts)
|
err := TGupdater.AddWebhook(TGBot, urlPath, webHookOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("Telegram bot failed to add webhook:" + err.Error())
|
logger.SysLog("Telegram bot failed to add webhook:" + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,7 +81,7 @@ func StartTelegramBot() {
|
|||||||
SecretToken: TGWebHookSecret,
|
SecretToken: TGWebHookSecret,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("Telegram bot failed to set webhook:" + err.Error())
|
logger.SysLog("Telegram bot failed to set webhook:" + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -96,13 +97,13 @@ func StartTelegramBot() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("Telegram bot failed to start polling:" + err.Error())
|
logger.SysLog("Telegram bot failed to start polling:" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Idle, to keep updates coming in, and avoid bot stopping.
|
// Idle, to keep updates coming in, and avoid bot stopping.
|
||||||
go TGupdater.Idle()
|
go TGupdater.Idle()
|
||||||
common.SysLog(fmt.Sprintf("Telegram bot %s has been started...:", TGBot.User.Username))
|
logger.SysLog(fmt.Sprintf("Telegram bot %s has been started...:", TGBot.User.Username))
|
||||||
TGEnabled = true
|
TGEnabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,7 +136,7 @@ func setDispatcher() *ext.Dispatcher {
|
|||||||
dispatcher := ext.NewDispatcher(&ext.DispatcherOpts{
|
dispatcher := ext.NewDispatcher(&ext.DispatcherOpts{
|
||||||
// If an error is returned by a handler, log it and continue going.
|
// If an error is returned by a handler, log it and continue going.
|
||||||
Error: func(b *gotgbot.Bot, ctx *ext.Context, err error) ext.DispatcherAction {
|
Error: func(b *gotgbot.Bot, ctx *ext.Context, err error) ext.DispatcherAction {
|
||||||
common.SysLog("telegram an error occurred while handling update: " + err.Error())
|
logger.SysLog("telegram an error occurred while handling update: " + err.Error())
|
||||||
return ext.DispatcherActionNoop
|
return ext.DispatcherActionNoop
|
||||||
},
|
},
|
||||||
MaxRoutines: ext.DefaultMaxRoutines,
|
MaxRoutines: ext.DefaultMaxRoutines,
|
||||||
@ -173,7 +174,7 @@ func getMenu() []gotgbot.BotCommand {
|
|||||||
customMenu, err := model.GetTelegramMenus()
|
customMenu, err := model.GetTelegramMenus()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("Failed to get custom menu, error: " + err.Error())
|
logger.SysLog("Failed to get custom menu, error: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(customMenu) > 0 {
|
if len(customMenu) > 0 {
|
||||||
@ -234,7 +235,7 @@ func getHttpClient() (httpClient *http.Client) {
|
|||||||
|
|
||||||
proxyURL, err := url.Parse(proxyAddr)
|
proxyURL, err := url.Parse(proxyAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("failed to parse TG proxy URL: " + err.Error())
|
logger.SysLog("failed to parse TG proxy URL: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch proxyURL.Scheme {
|
switch proxyURL.Scheme {
|
||||||
@ -247,7 +248,7 @@ func getHttpClient() (httpClient *http.Client) {
|
|||||||
case "socks5":
|
case "socks5":
|
||||||
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error())
|
logger.SysLog("failed to create TG SOCKS5 dialer: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
httpClient = &http.Client{
|
httpClient = &http.Client{
|
||||||
@ -258,7 +259,7 @@ func getHttpClient() (httpClient *http.Client) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
common.SysLog("unknown TG proxy type: " + proxyAddr)
|
logger.SysLog("unknown TG proxy type: " + proxyAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"one-api/common/logger"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"one-api/common/image"
|
"one-api/common/image"
|
||||||
@ -21,27 +22,27 @@ var gpt4oTokenEncoder *tiktoken.Tiktoken
|
|||||||
func InitTokenEncoders() {
|
func InitTokenEncoders() {
|
||||||
if viper.GetBool("disable_token_encoders") {
|
if viper.GetBool("disable_token_encoders") {
|
||||||
DISABLE_TOKEN_ENCODERS = true
|
DISABLE_TOKEN_ENCODERS = true
|
||||||
SysLog("token encoders disabled")
|
logger.SysLog("token encoders disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
SysLog("initializing token encoders")
|
logger.SysLog("initializing token encoders")
|
||||||
var err error
|
var err error
|
||||||
gpt35TokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
|
gpt35TokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
gpt4TokenEncoder, err = tiktoken.EncodingForModel("gpt-4")
|
gpt4TokenEncoder, err = tiktoken.EncodingForModel("gpt-4")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
gpt4oTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
gpt4oTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
SysLog("token encoders initialized")
|
logger.SysLog("token encoders initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
@ -64,7 +65,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|||||||
var err error
|
var err error
|
||||||
tokenEncoder, err = tiktoken.EncodingForModel(model)
|
tokenEncoder, err = tiktoken.EncodingForModel(model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||||
tokenEncoder = gpt35TokenEncoder
|
tokenEncoder = gpt35TokenEncoder
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -119,7 +120,7 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model string) in
|
|||||||
imageTokens, err := countImageTokens(url, detail)
|
imageTokens, err := countImageTokens(url, detail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//Due to the excessive length of the error information, only extract and record the most critical part.
|
//Due to the excessive length of the error information, only extract and record the most critical part.
|
||||||
SysError("error counting image tokens: " + err.Error())
|
logger.SysError("error counting image tokens: " + err.Error())
|
||||||
} else {
|
} else {
|
||||||
tokenNum += imageTokens
|
tokenNum += imageTokens
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/notify"
|
"one-api/common/notify"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@ -70,7 +71,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
|
|
||||||
// 转换为JSON字符串
|
// 转换为JSON字符串
|
||||||
jsonBytes, _ := json.Marshal(response)
|
jsonBytes, _ := json.Marshal(response)
|
||||||
common.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes)))
|
logger.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes)))
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -233,8 +234,8 @@ func AutomaticallyTestChannels(frequency int) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
common.SysLog("testing all channels")
|
logger.SysLog("testing all channels")
|
||||||
_ = testAllChannels(false)
|
_ = testAllChannels(false)
|
||||||
common.SysLog("channel test finished")
|
logger.SysLog("channel test finished")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -48,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -64,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||||
res2, err := client.Do(req)
|
res2, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@ -58,7 +59,7 @@ func getLarkAppAccessToken() (string, error) {
|
|||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return "", errors.New("无法连接至飞书服务器,请稍后重试!")
|
return "", errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -100,7 +101,7 @@ func getLarkUserAccessToken(code string) (string, error) {
|
|||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return "", errors.New("无法连接至飞书服务器,请稍后重试!")
|
return "", errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -135,7 +136,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
|
|||||||
}
|
}
|
||||||
res2, err := client.Do(req)
|
res2, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(err.Error())
|
logger.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
var larkUser LarkUser
|
var larkUser LarkUser
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
provider "one-api/providers/midjourney"
|
provider "one-api/providers/midjourney"
|
||||||
@ -45,9 +46,9 @@ func ActivateUpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateMidjourneyTaskBulk() {
|
func UpdateMidjourneyTaskBulk() {
|
||||||
ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask")
|
ctx := context.WithValue(context.Background(), logger.RequestIdKey, "MidjourneyTask")
|
||||||
for {
|
for {
|
||||||
common.LogInfo(ctx, "running")
|
logger.LogInfo(ctx, "running")
|
||||||
|
|
||||||
tasks := model.GetAllUnFinishTasks()
|
tasks := model.GetAllUnFinishTasks()
|
||||||
|
|
||||||
@ -56,11 +57,11 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
for len(activeMidjourneyTask) > 0 {
|
for len(activeMidjourneyTask) > 0 {
|
||||||
<-activeMidjourneyTask
|
<-activeMidjourneyTask
|
||||||
}
|
}
|
||||||
common.LogInfo(ctx, "no tasks, waiting...")
|
logger.LogInfo(ctx, "no tasks, waiting...")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||||
taskChannelM := make(map[int][]string)
|
taskChannelM := make(map[int][]string)
|
||||||
taskM := make(map[string]*model.Midjourney)
|
taskM := make(map[string]*model.Midjourney)
|
||||||
nullTaskIds := make([]int, 0)
|
nullTaskIds := make([]int, 0)
|
||||||
@ -79,9 +80,9 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@ -89,7 +90,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -100,7 +101,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
||||||
@ -110,7 +111,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
})
|
})
|
||||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 设置超时时间
|
// 设置超时时间
|
||||||
@ -122,22 +123,22 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||||
resp, err := requester.HTTPClient.Do(req)
|
resp, err := requester.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var responseItems []provider.MidjourneyDto
|
var responseItems []provider.MidjourneyDto
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
@ -176,17 +177,17 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
err = model.CacheUpdateUserQuota(task.UserId)
|
err = model.CacheUpdateUserQuota(task.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
@ -195,7 +196,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package cron
|
package cron
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ import (
|
|||||||
func InitCron() {
|
func InitCron() {
|
||||||
scheduler, err := gocron.NewScheduler()
|
scheduler, err := gocron.NewScheduler()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("Cron scheduler error: " + err.Error())
|
logger.SysLog("Cron scheduler error: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,12 +24,12 @@ func InitCron() {
|
|||||||
)),
|
)),
|
||||||
gocron.NewTask(func() {
|
gocron.NewTask(func() {
|
||||||
model.RemoveChatCache(time.Now().Unix())
|
model.RemoveChatCache(time.Now().Unix())
|
||||||
common.SysLog("删除过期缓存数据")
|
logger.SysLog("删除过期缓存数据")
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("Cron job error: " + err.Error())
|
logger.SysLog("Cron job error: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
15
main.go
15
main.go
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/config"
|
"one-api/common/config"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/notify"
|
"one-api/common/notify"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/common/storage"
|
"one-api/common/storage"
|
||||||
@ -31,8 +32,8 @@ var indexPage []byte
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
config.InitConf()
|
config.InitConf()
|
||||||
common.SetupLogger()
|
logger.SetupLogger()
|
||||||
common.SysLog("One API " + common.Version + " started")
|
logger.SysLog("One API " + common.Version + " started")
|
||||||
// Initialize SQL Database
|
// Initialize SQL Database
|
||||||
model.SetupDB()
|
model.SetupDB()
|
||||||
defer model.CloseDB()
|
defer model.CloseDB()
|
||||||
@ -69,8 +70,8 @@ func initMemoryCache() {
|
|||||||
syncFrequency := viper.GetInt("sync_frequency")
|
syncFrequency := viper.GetInt("sync_frequency")
|
||||||
model.TokenCacheSeconds = syncFrequency
|
model.TokenCacheSeconds = syncFrequency
|
||||||
|
|
||||||
common.SysLog("memory cache enabled")
|
logger.SysLog("memory cache enabled")
|
||||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", syncFrequency))
|
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", syncFrequency))
|
||||||
go model.SyncOptions(syncFrequency)
|
go model.SyncOptions(syncFrequency)
|
||||||
go SyncChannelCache(syncFrequency)
|
go SyncChannelCache(syncFrequency)
|
||||||
}
|
}
|
||||||
@ -98,19 +99,19 @@ func initHttpServer() {
|
|||||||
|
|
||||||
err := server.Run(":" + port)
|
err := server.Run(":" + port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
logger.FatalLog("failed to start HTTP server: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SyncChannelCache(frequency int) {
|
func SyncChannelCache(frequency int) {
|
||||||
// 只有 从 服务器端获取数据的时候才会用到
|
// 只有 从 服务器端获取数据的时候才会用到
|
||||||
if common.IsMasterNode {
|
if common.IsMasterNode {
|
||||||
common.SysLog("master node does't synchronize the channel")
|
logger.SysLog("master node does't synchronize the channel")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
common.SysLog("syncing channels from database")
|
logger.SysLog("syncing channels from database")
|
||||||
model.ChannelGroup.Load()
|
model.ChannelGroup.Load()
|
||||||
relay_util.PricingInstance.Init()
|
relay_util.PricingInstance.Init()
|
||||||
}
|
}
|
||||||
|
@ -3,14 +3,14 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetUpLogger(server *gin.Engine) {
|
func SetUpLogger(server *gin.Engine) {
|
||||||
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||||
var requestID string
|
var requestID string
|
||||||
if param.Keys != nil {
|
if param.Keys != nil {
|
||||||
requestID = param.Keys[common.RequestIdKey].(string)
|
requestID = param.Keys[logger.RequestIdKey].(string)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||||
|
@ -3,7 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -13,8 +13,8 @@ func RelayPanicRecover() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
logger.SysError(fmt.Sprintf("panic detected: %v", err))
|
||||||
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/MartialBE/one-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/MartialBE/one-api", err),
|
||||||
|
@ -2,7 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -12,11 +12,11 @@ import (
|
|||||||
func RequestId() func(c *gin.Context) {
|
func RequestId() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := utils.GetTimeString() + utils.GetRandomString(8)
|
id := utils.GetTimeString() + utils.GetRandomString(8)
|
||||||
c.Set(common.RequestIdKey, id)
|
c.Set(logger.RequestIdKey, id)
|
||||||
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
|
||||||
ctx = context.WithValue(ctx, "requestStartTime", time.Now())
|
ctx = context.WithValue(ctx, "requestStartTime", time.Now())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
c.Header(common.RequestIdKey, id)
|
c.Header(logger.RequestIdKey, id)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type turnstileCheckResponse struct {
|
type turnstileCheckResponse struct {
|
||||||
@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
"remoteip": {c.ClientIP()},
|
"remoteip": {c.ClientIP()},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
logger.SysError(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
var res turnstileCheckResponse
|
var res turnstileCheckResponse
|
||||||
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
logger.SysError(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -10,10 +10,10 @@ import (
|
|||||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": utils.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": utils.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), message)
|
logger.LogError(c.Request.Context(), message)
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -162,7 +163,7 @@ func (cc *ChannelsChooser) Load() {
|
|||||||
|
|
||||||
abilities, err := GetAbilityChannelGroup()
|
abilities, err := GetAbilityChannelGroup()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("get enabled abilities failed: " + err.Error())
|
logger.SysLog("get enabled abilities failed: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,5 +217,5 @@ func (cc *ChannelsChooser) Load() {
|
|||||||
cc.Channels = newChannels
|
cc.Channels = newChannels
|
||||||
cc.Match = newMatchList
|
cc.Match = newMatchList
|
||||||
cc.Unlock()
|
cc.Unlock()
|
||||||
common.SysLog("channels Load success")
|
logger.SysLog("channels Load success")
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -34,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set token error: " + err.Error())
|
logger.SysError("Redis set token error: " + err.Error())
|
||||||
}
|
}
|
||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
@ -54,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(TokenCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(TokenCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user group error: " + err.Error())
|
logger.SysError("Redis set user group error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return group, err
|
return group, err
|
||||||
@ -72,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user quota error: " + err.Error())
|
logger.SysError("Redis set user quota error: " + err.Error())
|
||||||
}
|
}
|
||||||
return quota, err
|
return quota, err
|
||||||
}
|
}
|
||||||
@ -119,7 +120,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
|||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(TokenCacheSeconds)*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(TokenCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user enabled error: " + err.Error())
|
logger.SysError("Redis set user enabled error: " + err.Error())
|
||||||
}
|
}
|
||||||
return userEnabled, err
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -240,7 +241,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|||||||
ResponseTime: int(responseTime),
|
ResponseTime: int(responseTime),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update response time: " + err.Error())
|
logger.SysError("failed to update response time: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,7 +251,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
|
|||||||
Balance: balance,
|
Balance: balance,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update balance: " + err.Error())
|
logger.SysError("failed to update balance: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,11 +284,11 @@ func (channel *Channel) StatusToStr() string {
|
|||||||
func UpdateChannelStatusById(id int, status int) {
|
func UpdateChannelStatusById(id int, status int) {
|
||||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update ability status: " + err.Error())
|
logger.SysError("failed to update ability status: " + err.Error())
|
||||||
}
|
}
|
||||||
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel status: " + err.Error())
|
logger.SysError("failed to update channel status: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -307,7 +308,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
|
|||||||
func updateChannelUsedQuota(id int, quota int) {
|
func updateChannelUsedQuota(id int, quota int) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
logger.SysError("failed to update channel used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -48,12 +49,12 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to record log: " + err.Error())
|
logger.SysError("failed to record log: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) {
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -73,7 +74,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
logger.LogError(ctx, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package model
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -20,7 +21,7 @@ var DB *gorm.DB
|
|||||||
func SetupDB() {
|
func SetupDB() {
|
||||||
err := InitDB()
|
err := InitDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to initialize database: " + err.Error())
|
logger.FatalLog("failed to initialize database: " + err.Error())
|
||||||
}
|
}
|
||||||
ChannelGroup.Load()
|
ChannelGroup.Load()
|
||||||
common.RootUserEmail = GetRootUserEmail()
|
common.RootUserEmail = GetRootUserEmail()
|
||||||
@ -28,7 +29,7 @@ func SetupDB() {
|
|||||||
if viper.GetBool("batch_update_enabled") {
|
if viper.GetBool("batch_update_enabled") {
|
||||||
common.BatchUpdateEnabled = true
|
common.BatchUpdateEnabled = true
|
||||||
common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
|
common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
|
||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
InitBatchUpdater()
|
InitBatchUpdater()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -37,7 +38,7 @@ func createRootAccountIfNeed() error {
|
|||||||
var user User
|
var user User
|
||||||
//if user.Status != common.UserStatusEnabled {
|
//if user.Status != common.UserStatusEnabled {
|
||||||
if err := DB.First(&user).Error; err != nil {
|
if err := DB.First(&user).Error; err != nil {
|
||||||
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||||
hashedPassword, err := common.Password2Hash("123456")
|
hashedPassword, err := common.Password2Hash("123456")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -61,7 +62,7 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
dsn := viper.GetString("sql_dsn")
|
dsn := viper.GetString("sql_dsn")
|
||||||
if strings.HasPrefix(dsn, "postgres://") {
|
if strings.HasPrefix(dsn, "postgres://") {
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
common.SysLog("using PostgreSQL as database")
|
logger.SysLog("using PostgreSQL as database")
|
||||||
common.UsingPostgreSQL = true
|
common.UsingPostgreSQL = true
|
||||||
return gorm.Open(postgres.New(postgres.Config{
|
return gorm.Open(postgres.New(postgres.Config{
|
||||||
DSN: dsn,
|
DSN: dsn,
|
||||||
@ -71,13 +72,13 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Use MySQL
|
// Use MySQL
|
||||||
common.SysLog("using MySQL as database")
|
logger.SysLog("using MySQL as database")
|
||||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Use SQLite
|
// Use SQLite
|
||||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
||||||
common.UsingSQLite = true
|
common.UsingSQLite = true
|
||||||
config := fmt.Sprintf("?_busy_timeout=%d", utils.GetOrDefault("sqlite_busy_timeout", 3000))
|
config := fmt.Sprintf("?_busy_timeout=%d", utils.GetOrDefault("sqlite_busy_timeout", 3000))
|
||||||
return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{
|
return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{
|
||||||
@ -104,7 +105,7 @@ func InitDB() (err error) {
|
|||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
common.SysLog("database migration started")
|
logger.SysLog("database migration started")
|
||||||
|
|
||||||
migration(DB)
|
migration(DB)
|
||||||
|
|
||||||
@ -152,11 +153,11 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
common.SysLog("database migrated")
|
logger.SysLog("database migrated")
|
||||||
err = createRootAccountIfNeed()
|
err = createRootAccountIfNeed()
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
common.FatalLog(err)
|
logger.FatalLog(err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
|
|
||||||
"github.com/go-gormigrate/gormigrate/v2"
|
"github.com/go-gormigrate/gormigrate/v2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -22,7 +22,7 @@ func removeKeyIndexMigration() *gormigrate.Migration {
|
|||||||
|
|
||||||
err := tx.Migrator().DropIndex(&Channel{}, "idx_channels_key")
|
err := tx.Migrator().DropIndex(&Channel{}, "idx_channels_key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("remove idx_channels_key Failure: " + err.Error())
|
logger.SysLog("remove idx_channels_key Failure: " + err.Error())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -90,7 +91,7 @@ func loadOptionsFromDatabase() {
|
|||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
err := updateOptionMap(option.Key, option.Value)
|
err := updateOptionMap(option.Key, option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update option map: " + err.Error())
|
logger.SysError("failed to update option map: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -98,7 +99,7 @@ func loadOptionsFromDatabase() {
|
|||||||
func SyncOptions(frequency int) {
|
func SyncOptions(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
common.SysLog("syncing options from database")
|
logger.SysLog("syncing options from database")
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/stmp"
|
"one-api/common/stmp"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
|
|
||||||
@ -58,7 +59,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
}
|
}
|
||||||
token, err = CacheGetTokenByKey(key)
|
token, err = CacheGetTokenByKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("CacheGetTokenByKey failed: " + err.Error())
|
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("无效的令牌")
|
return nil, errors.New("无效的令牌")
|
||||||
}
|
}
|
||||||
@ -77,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
token.Status = common.TokenStatusExpired
|
token.Status = common.TokenStatusExpired
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token status" + err.Error())
|
logger.SysError("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌已过期")
|
return nil, errors.New("该令牌已过期")
|
||||||
@ -88,7 +89,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
token.Status = common.TokenStatusExhausted
|
token.Status = common.TokenStatusExhausted
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token status" + err.Error())
|
logger.SysError("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌额度已用尽")
|
return nil, errors.New("该令牌额度已用尽")
|
||||||
@ -254,12 +255,12 @@ func sendQuotaWarningEmail(userId int, userQuota int, noMoreQuota bool) {
|
|||||||
user := User{Id: userId}
|
user := User{Id: userId}
|
||||||
|
|
||||||
if err := user.FillUserById(); err != nil {
|
if err := user.FillUserById(); err != nil {
|
||||||
common.SysError("failed to fetch user email: " + err.Error())
|
logger.SysError("failed to fetch user email: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
common.SysError("user email is empty")
|
logger.SysError("user email is empty")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,7 +272,7 @@ func sendQuotaWarningEmail(userId int, userQuota int, noMoreQuota bool) {
|
|||||||
err := stmp.SendQuotaWarningCodeEmail(userName, user.Email, userQuota, noMoreQuota)
|
err := stmp.SendQuotaWarningCodeEmail(userName, user.Email, userQuota, noMoreQuota)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to send email" + err.Error())
|
logger.SysError("failed to send email" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -306,7 +307,7 @@ func IsAdmin(userId int) bool {
|
|||||||
var user User
|
var user User
|
||||||
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("no such user " + err.Error())
|
logger.SysError("no such user " + err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return user.Role >= common.RoleAdminUser
|
return user.Role >= common.RoleAdminUser
|
||||||
@ -415,7 +416,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
|||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update user used quota and request count: " + err.Error())
|
logger.SysError("failed to update user used quota and request count: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -426,14 +427,14 @@ func updateUserUsedQuota(id int, quota int) {
|
|||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update user used quota: " + err.Error())
|
logger.SysError("failed to update user used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserRequestCount(id int, count int) {
|
func updateUserRequestCount(id int, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update user request count: " + err.Error())
|
logger.SysError("failed to update user request count: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func batchUpdate() {
|
func batchUpdate() {
|
||||||
common.SysLog("batch update started")
|
logger.SysLog("batch update started")
|
||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateLocks[i].Lock()
|
batchUpdateLocks[i].Lock()
|
||||||
store := batchUpdateStores[i]
|
store := batchUpdateStores[i]
|
||||||
@ -57,12 +58,12 @@ func batchUpdate() {
|
|||||||
case BatchUpdateTypeUserQuota:
|
case BatchUpdateTypeUserQuota:
|
||||||
err := increaseUserQuota(key, value)
|
err := increaseUserQuota(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to batch update user quota: " + err.Error())
|
logger.SysError("failed to batch update user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeTokenQuota:
|
case BatchUpdateTypeTokenQuota:
|
||||||
err := increaseTokenQuota(key, value)
|
err := increaseTokenQuota(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to batch update token quota: " + err.Error())
|
logger.SysError("failed to batch update token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeUsedQuota:
|
case BatchUpdateTypeUsedQuota:
|
||||||
updateUserUsedQuota(key, value)
|
updateUserUsedQuota(key, value)
|
||||||
@ -73,5 +74,5 @@ func batchUpdate() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common.SysLog("batch update finished")
|
logger.SysLog("batch update finished")
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
@ -71,7 +72,7 @@ func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyRe
|
|||||||
|
|
||||||
resp, errWith := p.Requester.SendRequestRaw(req)
|
resp, errWith := p.Requester.SendRequestRaw(req)
|
||||||
if errWith != nil {
|
if errWith != nil {
|
||||||
common.SysError("do request failed: " + errWith.Error())
|
logger.SysError("do request failed: " + errWith.Error())
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
}
|
}
|
||||||
statusCode := resp.StatusCode
|
statusCode := resp.StatusCode
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
@ -94,7 +94,7 @@ func (p *XunfeiProvider) getAPIVersion(modelName string) string {
|
|||||||
}
|
}
|
||||||
apiVersion = "v1.1"
|
apiVersion = "v1.1"
|
||||||
|
|
||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
logger.SysLog("api_version not found, use default: " + apiVersion)
|
||||||
return apiVersion
|
return apiVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,7 +130,7 @@ func (p *XunfeiProvider) buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret st
|
|||||||
}
|
}
|
||||||
ul, err := url.Parse(hostUrl)
|
ul, err := url.Parse(hostUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("url parse error: " + err.Error())
|
logger.SysError("url parse error: " + err.Error())
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
date := time.Now().UTC().Format(time.RFC1123)
|
date := time.Now().UTC().Format(time.RFC1123)
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
@ -95,7 +95,7 @@ func (p *ZhipuProvider) getZhipuToken() string {
|
|||||||
|
|
||||||
split := strings.Split(apikey, ".")
|
split := strings.Split(apikey, ".")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
common.SysError("invalid zhipu key: " + apikey)
|
logger.SysError("invalid zhipu key: " + apikey)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
@ -115,7 +116,7 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, erro
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
return nil, errors.New(message)
|
return nil, errors.New(message)
|
||||||
@ -250,14 +251,14 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) {
|
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) {
|
||||||
common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message))
|
logger.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message))
|
||||||
if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||||
controller.DisableChannel(channelId, channelName, err.Message, true)
|
controller.DisableChannel(channelId, channelName, err.Message, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(logger.RequestIdKey)
|
||||||
err.OpenAIError.Message = utils.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
err.OpenAIError.Message = utils.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||||
c.JSON(err.StatusCode, gin.H{
|
c.JSON(err.StatusCode, gin.H{
|
||||||
"error": err.OpenAIError,
|
"error": err.OpenAIError,
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/relay_util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
@ -51,7 +52,7 @@ func Relay(c *gin.Context) {
|
|||||||
|
|
||||||
retryTimes := common.RetryTimes
|
retryTimes := common.RetryTimes
|
||||||
if done || !shouldRetry(c, apiErr.StatusCode) {
|
if done || !shouldRetry(c, apiErr.StatusCode) {
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
|
logger.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
|
||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +64,7 @@ func Relay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
channel = relay.getProvider().GetChannel()
|
channel = relay.getProvider().GetChannel()
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i))
|
logger.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i))
|
||||||
apiErr, done = RelayHandler(relay)
|
apiErr, done = RelayHandler(relay)
|
||||||
if apiErr == nil {
|
if apiErr == nil {
|
||||||
return
|
return
|
||||||
|
@ -6,7 +6,7 @@ package midjourney
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common/logger"
|
||||||
provider "one-api/providers/midjourney"
|
provider "one-api/providers/midjourney"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ func RelayMidjourney(c *gin.Context) {
|
|||||||
"code": err.Code,
|
"code": err.Code,
|
||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
logger.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/common/utils"
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"sort"
|
"sort"
|
||||||
@ -30,7 +31,7 @@ type BatchPrices struct {
|
|||||||
|
|
||||||
// NewPricing creates a new Pricing instance
|
// NewPricing creates a new Pricing instance
|
||||||
func NewPricing() {
|
func NewPricing() {
|
||||||
common.SysLog("Initializing Pricing")
|
logger.SysLog("Initializing Pricing")
|
||||||
|
|
||||||
PricingInstance = &Pricing{
|
PricingInstance = &Pricing{
|
||||||
Prices: make(map[string]*model.Price),
|
Prices: make(map[string]*model.Price),
|
||||||
@ -40,16 +41,16 @@ func NewPricing() {
|
|||||||
err := PricingInstance.Init()
|
err := PricingInstance.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Failed to initialize Pricing:" + err.Error())
|
logger.SysError("Failed to initialize Pricing:" + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化时,需要检测是否有更新
|
// 初始化时,需要检测是否有更新
|
||||||
if viper.GetBool("auto_price_updates") || len(PricingInstance.Prices) == 0 {
|
if viper.GetBool("auto_price_updates") || len(PricingInstance.Prices) == 0 {
|
||||||
common.SysLog("Checking for pricing updates")
|
logger.SysLog("Checking for pricing updates")
|
||||||
prices := model.GetDefaultPrice()
|
prices := model.GetDefaultPrice()
|
||||||
PricingInstance.SyncPricing(prices, false)
|
PricingInstance.SyncPricing(prices, false)
|
||||||
common.SysLog("Pricing initialized")
|
logger.SysLog("Pricing initialized")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
@ -154,7 +155,7 @@ func (q *Quota) Undo(c *gin.Context) {
|
|||||||
// return pre-consumed quota
|
// return pre-consumed quota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
|
err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
logger.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
}
|
}
|
||||||
@ -166,7 +167,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
|
|||||||
go func(ctx context.Context) {
|
go func(ctx context.Context) {
|
||||||
err := q.completedQuotaConsumption(usage, tokenName, ctx)
|
err := q.completedQuotaConsumption(usage, tokenName, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, err.Error())
|
logger.LogError(ctx, err.Error())
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/logger"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -18,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
|||||||
frontendBaseUrl := viper.GetString("frontend_base_url")
|
frontendBaseUrl := viper.GetString("frontend_base_url")
|
||||||
if common.IsMasterNode && frontendBaseUrl != "" {
|
if common.IsMasterNode && frontendBaseUrl != "" {
|
||||||
frontendBaseUrl = ""
|
frontendBaseUrl = ""
|
||||||
common.SysLog("FRONTEND_BASE_URL is ignored on master node")
|
logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
|
||||||
}
|
}
|
||||||
if frontendBaseUrl == "" {
|
if frontendBaseUrl == "" {
|
||||||
SetWebRouter(router, buildFS, indexPage)
|
SetWebRouter(router, buildFS, indexPage)
|
||||||
|
Loading…
Reference in New Issue
Block a user