fix: fix baidu system prompt (close #1079)

This commit is contained in:
JustSong 2024-03-13 22:56:54 +08:00
parent e99150bdb9
commit 79d0cd378a
8 changed files with 92 additions and 76 deletions

View File

@ -1,7 +1,7 @@
package config package config
import ( import (
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/env"
"os" "os"
"strconv" "strconv"
"sync" "sync"
@ -94,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = helper.GetOrDefaultEnvString("THEME", "default") var Theme = env.String("THEME", "default")
var ValidThemes = map[string]bool{ var ValidThemes = map[string]bool{
"default": true, "default": true,
"berry": true, "berry": true,
@ -112,10 +112,10 @@ var ValidThemes = map[string]bool{
// All duration's unit is seconds // All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration // Shouldn't larger then RateLimitKeyExpirationDuration
var ( var (
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60 GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60 GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10 UploadRateLimitNum = 10
@ -130,8 +130,8 @@ var (
var RateLimitKeyExpirationDuration = 20 * time.Minute var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false) var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10) var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024) var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128) var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)

View File

@ -1,10 +1,12 @@
package common package common
import "github.com/songquanpeng/one-api/common/helper" import (
"github.com/songquanpeng/one-api/common/env"
)
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var UsingMySQL = false var UsingMySQL = false
var SQLitePath = "one-api.db" var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)

42
common/env/helper.go vendored Normal file
View File

@ -0,0 +1,42 @@
package env
import (
"os"
"strconv"
)
func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
return defaultValue
}
return num
}
func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
return defaultValue
}
return num
}
func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@ -3,12 +3,10 @@ package helper
import ( import (
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/songquanpeng/one-api/common/logger"
"html/template" "html/template"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"os"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
@ -195,44 +193,6 @@ func Max(a int, b int) int {
} }
} }
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func AssignOrDefault(value string, defaultValue string) string { func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 { if len(value) != 0 {
return value return value

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"io" "io"
"log" "log"
"os" "os"
@ -54,7 +55,9 @@ func SysError(s string) {
} }
func Debug(ctx context.Context, msg string) { func Debug(ctx context.Context, msg string) {
logHelper(ctx, loggerDEBUG, msg) if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg)
}
} }
func Info(ctx context.Context, msg string) { func Info(ctx context.Context, msg string) {

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
@ -81,9 +82,9 @@ func InitDB() (err error) {
if err != nil { if err != nil {
return err return err
} }
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode { if !config.IsMasterNode {
return nil return nil

View File

@ -32,9 +32,16 @@ type Message struct {
} }
type ChatRequest struct { type ChatRequest struct {
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Stream bool `json:"stream"` Temperature float64 `json:"temperature,omitempty"`
UserId string `json:"user_id,omitempty"` TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
} }
type Error struct { type Error struct {
@ -45,28 +52,28 @@ type Error struct {
var baiduTokenStore sync.Map var baiduTokenStore sync.Map
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages)) baiduRequest := ChatRequest{
Messages: make([]Message, 0, len(request.Messages)),
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: request.MaxTokens,
UserId: request.User,
}
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, Message{ baiduRequest.System = message.StringContent()
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else { } else {
messages = append(messages, Message{ baiduRequest.Messages = append(baiduRequest.Messages, Message{
Role: message.Role, Role: message.Role,
Content: message.StringContent(), Content: message.StringContent(),
}) })
} }
} }
return &ChatRequest{ return &baiduRequest
Messages: messages,
Stream: request.Stream,
}
} }
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {

View File

@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
} }
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
} }