fix: fix baidu system prompt (close #1079)
This commit is contained in:
parent
e99150bdb9
commit
79d0cd378a
@ -1,7 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
@ -94,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
|
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
|
||||||
|
|
||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
|
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
|
||||||
|
|
||||||
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||||
|
|
||||||
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
|
var Theme = env.String("THEME", "default")
|
||||||
var ValidThemes = map[string]bool{
|
var ValidThemes = map[string]bool{
|
||||||
"default": true,
|
"default": true,
|
||||||
"berry": true,
|
"berry": true,
|
||||||
@ -112,10 +112,10 @@ var ValidThemes = map[string]bool{
|
|||||||
// All duration's unit is seconds
|
// All duration's unit is seconds
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
var (
|
var (
|
||||||
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
|
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
UploadRateLimitNum = 10
|
||||||
@ -130,8 +130,8 @@ var (
|
|||||||
|
|
||||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||||
|
|
||||||
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
|
var EnableMetric = env.Bool("ENABLE_METRIC", false)
|
||||||
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
|
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
|
||||||
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
||||||
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
||||||
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
|
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "github.com/songquanpeng/one-api/common/helper"
|
import (
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
|
)
|
||||||
|
|
||||||
var UsingSQLite = false
|
var UsingSQLite = false
|
||||||
var UsingPostgreSQL = false
|
var UsingPostgreSQL = false
|
||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db"
|
var SQLitePath = "one-api.db"
|
||||||
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
|
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)
|
||||||
|
42
common/env/helper.go
vendored
Normal file
42
common/env/helper.go
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Bool(env string, defaultValue bool) bool {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func Int(env string, defaultValue int) int {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.Atoi(os.Getenv(env))
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func Float64(env string, defaultValue float64) float64 {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func String(env string, defaultValue string) string {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env)
|
||||||
|
}
|
@ -3,12 +3,10 @@ package helper
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
|
||||||
"html/template"
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -195,44 +193,6 @@ func Max(a int, b int) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env) == "true"
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultEnvInt(env string, defaultValue int) int {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.Atoi(os.Getenv(env))
|
|
||||||
if err != nil {
|
|
||||||
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
|
||||||
if err != nil {
|
|
||||||
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultEnvString(env string, defaultValue string) string {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env)
|
|
||||||
}
|
|
||||||
|
|
||||||
func AssignOrDefault(value string, defaultValue string) string {
|
func AssignOrDefault(value string, defaultValue string) string {
|
||||||
if len(value) != 0 {
|
if len(value) != 0 {
|
||||||
return value
|
return value
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@ -54,7 +55,9 @@ func SysError(s string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Debug(ctx context.Context, msg string) {
|
func Debug(ctx context.Context, msg string) {
|
||||||
logHelper(ctx, loggerDEBUG, msg)
|
if config.DebugEnabled {
|
||||||
|
logHelper(ctx, loggerDEBUG, msg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Info(ctx context.Context, msg string) {
|
func Info(ctx context.Context, msg string) {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
@ -81,9 +82,9 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
|
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
|
||||||
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
|
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
|
||||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
|
||||||
|
|
||||||
if !config.IsMasterNode {
|
if !config.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
|
@ -32,9 +32,16 @@ type Message struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
Stream bool `json:"stream"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
UserId string `json:"user_id,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
DisableSearch bool `json:"disable_search,omitempty"`
|
||||||
|
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||||
|
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
||||||
|
UserId string `json:"user_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
@ -45,28 +52,28 @@ type Error struct {
|
|||||||
var baiduTokenStore sync.Map
|
var baiduTokenStore sync.Map
|
||||||
|
|
||||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||||
messages := make([]Message, 0, len(request.Messages))
|
baiduRequest := ChatRequest{
|
||||||
|
Messages: make([]Message, 0, len(request.Messages)),
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
PenaltyScore: request.FrequencyPenalty,
|
||||||
|
Stream: request.Stream,
|
||||||
|
DisableSearch: false,
|
||||||
|
EnableCitation: false,
|
||||||
|
MaxOutputTokens: request.MaxTokens,
|
||||||
|
UserId: request.User,
|
||||||
|
}
|
||||||
for _, message := range request.Messages {
|
for _, message := range request.Messages {
|
||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
messages = append(messages, Message{
|
baiduRequest.System = message.StringContent()
|
||||||
Role: "user",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
messages = append(messages, Message{
|
baiduRequest.Messages = append(baiduRequest.Messages, Message{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Content: message.StringContent(),
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &ChatRequest{
|
return &baiduRequest
|
||||||
Messages: messages,
|
|
||||||
Stream: request.Stream,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
|
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||||
|
@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user