fix: fix baidu system prompt (close #1079)
This commit is contained in:
parent
e99150bdb9
commit
79d0cd378a
@ -1,7 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
@ -94,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
|
||||
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
|
||||
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
|
||||
|
||||
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
|
||||
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
|
||||
var Theme = env.String("THEME", "default")
|
||||
var ValidThemes = map[string]bool{
|
||||
"default": true,
|
||||
"berry": true,
|
||||
@ -112,10 +112,10 @@ var ValidThemes = map[string]bool{
|
||||
// All duration's unit is seconds
|
||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||
var (
|
||||
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||
|
||||
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||
|
||||
UploadRateLimitNum = 10
|
||||
@ -130,8 +130,8 @@ var (
|
||||
|
||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||
|
||||
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
|
||||
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
|
||||
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
||||
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
||||
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
|
||||
var EnableMetric = env.Bool("ENABLE_METRIC", false)
|
||||
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
|
||||
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
||||
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
||||
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
|
||||
|
@ -1,10 +1,12 @@
|
||||
package common
|
||||
|
||||
import "github.com/songquanpeng/one-api/common/helper"
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
)
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var UsingMySQL = false
|
||||
|
||||
var SQLitePath = "one-api.db"
|
||||
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
|
||||
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)
|
||||
|
42
common/env/helper.go
vendored
Normal file
42
common/env/helper.go
vendored
Normal file
@ -0,0 +1,42 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func Bool(env string, defaultValue bool) bool {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return os.Getenv(env) == "true"
|
||||
}
|
||||
|
||||
func Int(env string, defaultValue int) int {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
num, err := strconv.Atoi(os.Getenv(env))
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func Float64(env string, defaultValue float64) float64 {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func String(env string, defaultValue string) string {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return os.Getenv(env)
|
||||
}
|
@ -3,12 +3,10 @@ package helper
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@ -195,44 +193,6 @@ func Max(a int, b int) int {
|
||||
}
|
||||
}
|
||||
|
||||
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return os.Getenv(env) == "true"
|
||||
}
|
||||
|
||||
func GetOrDefaultEnvInt(env string, defaultValue int) int {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
num, err := strconv.Atoi(os.Getenv(env))
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
||||
return defaultValue
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
|
||||
return defaultValue
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func GetOrDefaultEnvString(env string, defaultValue string) string {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return os.Getenv(env)
|
||||
}
|
||||
|
||||
func AssignOrDefault(value string, defaultValue string) string {
|
||||
if len(value) != 0 {
|
||||
return value
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
@ -54,7 +55,9 @@ func SysError(s string) {
|
||||
}
|
||||
|
||||
func Debug(ctx context.Context, msg string) {
|
||||
if config.DebugEnabled {
|
||||
logHelper(ctx, loggerDEBUG, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func Info(ctx context.Context, msg string) {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"gorm.io/driver/mysql"
|
||||
@ -81,9 +82,9 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
|
||||
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
|
||||
|
||||
if !config.IsMasterNode {
|
||||
return nil
|
||||
|
@ -33,7 +33,14 @@ type Message struct {
|
||||
|
||||
type ChatRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
@ -45,28 +52,28 @@ type Error struct {
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
baiduRequest := ChatRequest{
|
||||
Messages: make([]Message, 0, len(request.Messages)),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
PenaltyScore: request.FrequencyPenalty,
|
||||
Stream: request.Stream,
|
||||
DisableSearch: false,
|
||||
EnableCitation: false,
|
||||
MaxOutputTokens: request.MaxTokens,
|
||||
UserId: request.User,
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, Message{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
baiduRequest.System = message.StringContent()
|
||||
} else {
|
||||
messages = append(messages, Message{
|
||||
baiduRequest.Messages = append(baiduRequest.Messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &ChatRequest{
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
}
|
||||
return &baiduRequest
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
|
@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user