From af8908db540d0ad4650e7595b79b0ccb066a9a38 Mon Sep 17 00:00:00 2001 From: Tisfeng Date: Mon, 1 Jan 2024 16:42:19 +0800 Subject: [PATCH] feat: able to change gemini safety setting (#867) * perf: adjust gemini safety settings, set BLOCK_NONE by default * feat: able to adjust by env variable --------- Co-authored-by: JustSong --- README.md | 1 + common/constants.go | 2 ++ common/utils.go | 7 +++++++ controller/relay-gemini.go | 36 ++++++++++++++++++------------------ 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index ff9e0bc0..b53936c4 100644 --- a/README.md +++ b/README.md @@ -366,6 +366,7 @@ graph LR + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 60700ec8..e4cbf8bf 100644 --- a/common/constants.go +++ b/common/constants.go @@ -98,6 +98,8 @@ var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second +var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") + const ( RequestIdKey = "X-Oneapi-Request-Id" ) diff --git a/common/utils.go b/common/utils.go index 21bec8f5..9a7038e2 100644 --- a/common/utils.go +++ b/common/utils.go @@ -196,6 +196,13 @@ func GetOrDefault(env string, defaultValue int) int { return num } +func GetOrDefaultString(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} + func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index ec55d4b6..d8ab58d6 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -63,24 +63,24 @@ type GeminiChatGenerationConfig struct { func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - //SafetySettings: []GeminiChatSafetySettings{ - // { - // Category: "HARM_CATEGORY_HARASSMENT", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - // { - // Category: "HARM_CATEGORY_HATE_SPEECH", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - // { - // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - // { - // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - //}, + SafetySettings: []GeminiChatSafetySettings{ + { + Category: "HARM_CATEGORY_HARASSMENT", + Threshold: common.GeminiSafetySetting, + }, + { + Category: "HARM_CATEGORY_HATE_SPEECH", + Threshold: common.GeminiSafetySetting, + }, + { + Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + Threshold: common.GeminiSafetySetting, + }, + { + Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + Threshold: common.GeminiSafetySetting, + }, + }, GenerationConfig: GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP,