diff --git a/common/config/config.go b/common/config/config.go index c62a6ac6..83cfa933 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,7 +1,7 @@ package config import ( - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/env" "os" "strconv" "sync" @@ -94,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second -var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") -var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var Theme = env.String("THEME", "default") var ValidThemes = map[string]bool{ "default": true, "berry": true, @@ -112,10 +112,10 @@ var ValidThemes = map[string]bool{ // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -130,8 +130,8 @@ var ( var RateLimitKeyExpirationDuration = 20 * time.Minute -var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false) -var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10) -var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) -var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024) -var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128) +var EnableMetric = env.Bool("ENABLE_METRIC", false) +var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) +var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) diff --git a/common/database.go b/common/database.go index df60bdd5..f2db759f 100644 --- a/common/database.go +++ b/common/database.go @@ -1,10 +1,12 @@ package common -import "github.com/songquanpeng/one-api/common/helper" +import ( + "github.com/songquanpeng/one-api/common/env" +) var UsingSQLite = false var UsingPostgreSQL = false var UsingMySQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/env/helper.go b/common/env/helper.go new file mode 100644 index 00000000..fdb9f827 --- /dev/null +++ b/common/env/helper.go @@ -0,0 +1,42 @@ +package env + +import ( + "os" + "strconv" +) + +func Bool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} + +func Int(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + return defaultValue + } + return num +} + +func Float64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + return defaultValue + } + return num +} + +func String(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/helper/helper.go b/common/helper/helper.go index 23578842..76db5042 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -3,12 +3,10 @@ package helper import ( "fmt" "github.com/google/uuid" - "github.com/songquanpeng/one-api/common/logger" "html/template" "log" "math/rand" "net" - "os" "os/exec" "runtime" "strconv" @@ -195,44 +193,6 @@ func Max(a int, b int) int { } } -func GetOrDefaultEnvBool(env string, defaultValue bool) bool { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) == "true" -} - -func GetOrDefaultEnvInt(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.ParseFloat(os.Getenv(env), 64) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - func AssignOrDefault(value string, defaultValue string) string { if len(value) != 0 { return value diff --git a/common/logger/logger.go b/common/logger/logger.go index 41b98ca3..ad0a0bea 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "io" "log" "os" @@ -54,7 +55,9 @@ func SysError(s string) { } func Debug(ctx context.Context, msg string) { - logHelper(ctx, loggerDEBUG, msg) + if config.DebugEnabled { + logHelper(ctx, loggerDEBUG, msg) + } } func Info(ctx context.Context, msg string) { diff --git a/model/main.go b/model/main.go index f27cdb6f..05150fd9 100644 --- a/model/main.go +++ b/model/main.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/driver/mysql" @@ -81,9 +82,9 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) if !config.IsMasterNode { return nil diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go index 4f2b13fc..9ca9e47d 100644 --- a/relay/channel/baidu/main.go +++ b/relay/channel/baidu/main.go @@ -32,9 +32,16 @@ type Message struct { } type ChatRequest struct { - Messages []Message `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` + Messages []Message `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PenaltyScore float64 `json:"penalty_score,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + DisableSearch bool `json:"disable_search,omitempty"` + EnableCitation bool `json:"enable_citation,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + UserId string `json:"user_id,omitempty"` } type Error struct { @@ -45,28 +52,28 @@ type Error struct { var baiduTokenStore sync.Map func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { - messages := make([]Message, 0, len(request.Messages)) + baiduRequest := ChatRequest{ + Messages: make([]Message, 0, len(request.Messages)), + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + DisableSearch: false, + EnableCitation: false, + MaxOutputTokens: request.MaxTokens, + UserId: request.User, + } for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) + baiduRequest.System = message.StringContent() } else { - messages = append(messages, Message{ + baiduRequest.Messages = append(baiduRequest.Messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &ChatRequest{ - Messages: messages, - Stream: request.Stream, - } + return &baiduRequest } func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { diff --git a/relay/controller/text.go b/relay/controller/text.go index 781170f4..ba008713 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } + logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) requestBody = bytes.NewBuffer(jsonData) }