diff --git a/README.md b/README.md index ff9e0bc0..b53936c4 100644 --- a/README.md +++ b/README.md @@ -366,6 +366,7 @@ graph LR + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 60700ec8..e4cbf8bf 100644 --- a/common/constants.go +++ b/common/constants.go @@ -98,6 +98,8 @@ var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second +var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") + const ( RequestIdKey = "X-Oneapi-Request-Id" ) diff --git a/common/image/image.go b/common/image/image.go index eae76286..de8fefd3 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -15,6 +15,9 @@ import ( _ "golang.org/x/image/webp" ) +// Regex to match data URL pattern +var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) + func IsImageUrl(url string) (bool, error) { resp, err := http.Head(url) if err != nil { @@ -44,9 +47,13 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { } func GetImageFromUrl(url string) (mimeType string, data string, err error) { - // openai's image_url support base64 encoded image - if strings.HasPrefix(url, "data:image/jpeg;base64,") { - return "image/jpeg", strings.TrimPrefix(url, "data:image/jpeg;base64,"), nil + // Check if the URL is a data URL + matches := dataURLPattern.FindStringSubmatch(url) + if len(matches) == 3 { + // URL is a data URL + mimeType = "image/" + matches[1] + data = matches[2] + return } isImage, err := IsImageUrl(url) diff --git a/common/model-ratio.go b/common/model-ratio.go index fa2adaa1..2908be17 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -52,6 +52,8 @@ var ModelRatio = map[string]float64{ "gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "davinci-002" 1, // $0.002 / 1K tokens + "babbage-002" 0.2, // $0.0004 / 1K tokens "text-ada-001": 0.2, "text-babbage-001": 0.25, "text-curie-001": 1, diff --git a/common/utils.go b/common/utils.go index 21bec8f5..9a7038e2 100644 --- a/common/utils.go +++ b/common/utils.go @@ -196,6 +196,13 @@ func GetOrDefault(env string, defaultValue int) int { return num } +func GetOrDefaultString(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} + func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } diff --git a/controller/model.go b/controller/model.go index 6a759b63..6cb530db 100644 --- a/controller/model.go +++ b/controller/model.go @@ -342,6 +342,24 @@ func init() { Root: "code-davinci-edit-001", Parent: nil, }, + { + Id: "davinci-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "davinci-002", + Parent: nil, + }, + { + Id: "babbage-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "babbage-002", + Parent: nil, + }, { Id: "claude-instant-1", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index ec55d4b6..d8ab58d6 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -63,24 +63,24 @@ type GeminiChatGenerationConfig struct { func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - //SafetySettings: []GeminiChatSafetySettings{ - // { - // Category: "HARM_CATEGORY_HARASSMENT", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - // { - // Category: "HARM_CATEGORY_HATE_SPEECH", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - // { - // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - // { - // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", - // Threshold: "BLOCK_ONLY_HIGH", - // }, - //}, + SafetySettings: []GeminiChatSafetySettings{ + { + Category: "HARM_CATEGORY_HARASSMENT", + Threshold: common.GeminiSafetySetting, + }, + { + Category: "HARM_CATEGORY_HATE_SPEECH", + Threshold: common.GeminiSafetySetting, + }, + { + Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + Threshold: common.GeminiSafetySetting, + }, + { + Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + Threshold: common.GeminiSafetySetting, + }, + }, GenerationConfig: GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, diff --git a/controller/relay-image.go b/controller/relay-image.go index 7e1fed39..14a2983b 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -168,6 +168,9 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var textResponse ImageResponse defer func(ctx context.Context) { + if resp.StatusCode != http.StatusOK { + return + } err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) diff --git a/i18n/en.json b/i18n/en.json index b0deb83a..7b51909b 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -526,5 +526,7 @@ "模型版本": "Model version", "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", "点击查看": "click to view", - "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" + "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!", + "测试所有渠道": "Test all channels", + "更新已启用渠道余额": "Update the balance of enabled channels" } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 5d68e2da..a2adfd32 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -523,10 +523,10 @@ const ChannelsTable = () => { 添加新的渠道 + loading={loading || updatingBalance}>更新已启用渠道余额