From b747cdbc6fde8eab634183bf2abae86e696a63cc Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 26 Feb 2024 22:52:16 +0800 Subject: [PATCH 01/11] fix: fix getAndValidateTextRequest failed: unexpected end of JSON input (close #1043) --- common/gin.go | 18 +++++++++++++++--- controller/relay.go | 14 +++++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/common/gin.go b/common/gin.go index bed2c2b1..b6ef96a6 100644 --- a/common/gin.go +++ b/common/gin.go @@ -8,12 +8,24 @@ import ( "strings" ) -func UnmarshalBodyReusable(c *gin.Context, v any) error { +const KeyRequestBody = "key_request_body" + +func GetRequestBody(c *gin.Context) ([]byte, error) { + requestBody, _ := c.Get(KeyRequestBody) + if requestBody != nil { + return requestBody.([]byte), nil + } requestBody, err := io.ReadAll(c.Request.Body) if err != nil { - return err + return nil, err } - err = c.Request.Body.Close() + _ = c.Request.Body.Close() + c.Set(KeyRequestBody, requestBody) + return requestBody.([]byte), nil +} + +func UnmarshalBodyReusable(c *gin.Context, v any) error { + requestBody, err := GetRequestBody(c) if err != nil { return err } diff --git a/controller/relay.go b/controller/relay.go index 240042b6..499e8ddc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,9 +1,11 @@ package controller import ( + "bytes" "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -13,6 +15,7 @@ import ( "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) @@ -50,8 +53,8 @@ func Relay(c *gin.Context) { go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes - if !shouldRetry(bizErr.StatusCode) { - logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode) + if !shouldRetry(c, bizErr.StatusCode) { + logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) retryTimes = 0 } for i := retryTimes; i > 0; i-- { @@ -65,6 +68,8 @@ func Relay(c *gin.Context) { continue } middleware.SetupContextForSelectedChannel(c, channel, originalModel) + requestBody, err := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) bizErr = relay(c, relayMode) if bizErr == nil { return @@ -85,7 +90,10 @@ func Relay(c *gin.Context) { } } -func shouldRetry(statusCode int) bool { +func shouldRetry(c *gin.Context, statusCode int) bool { + if _, ok := c.Get("specific_channel_id"); ok { + return false + } if statusCode == http.StatusTooManyRequests { return true } From eac6a0b9aada6a3d59d070cd0eddf741a063d2dd Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 2 Mar 2024 00:03:29 +0800 Subject: [PATCH 02/11] fix: fix version is blank --- .github/workflows/linux-release.yml | 2 +- .github/workflows/macos-release.yml | 2 +- .github/workflows/windows-release.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index 98edc471..e81ab09f 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -38,7 +38,7 @@ jobs: - name: Build Backend (amd64) run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api + go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api - name: Build Backend (arm64) run: | diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 9142609f..13415276 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -38,7 +38,7 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos + go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index c058f41d..8b1160b4 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -41,7 +41,7 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe + go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') From 614c2e044266330b79c7a6d07b338b4fb5e14a80 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 2 Mar 2024 00:55:48 +0800 Subject: [PATCH 03/11] feat: support baichuan's models now (close #1057) --- README.md | 1 + common/constants.go | 2 ++ common/logger/logger.go | 9 +++++++++ controller/model.go | 12 ++++++++++++ controller/relay.go | 4 ++++ relay/channel/baichuan/constants.go | 7 +++++++ relay/channel/openai/adaptor.go | 5 +++++ relay/controller/text.go | 3 ++- web/berry/src/constants/ChannelConstants.js | 6 ++++++ web/berry/src/views/Channel/type/Config.js | 12 ++++++++++++ web/default/src/constants/channel.constants.js | 1 + web/default/src/pages/Channel/EditChannel.js | 3 +++ 12 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 relay/channel/baichuan/constants.go diff --git a/README.md b/README.md index ff1fffd2..ae2ffac5 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [360 智脑](https://ai.360.cn) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) + [x] [Moonshot AI](https://platform.moonshot.cn/) + + [x] [百川大模型](https://platform.baichuan-ai.com) + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) + [ ] [MINIMAX](https://api.minimax.chat/) (WIP) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 diff --git a/common/constants.go b/common/constants.go index ccaa3560..85fcca18 100644 --- a/common/constants.go +++ b/common/constants.go @@ -64,6 +64,7 @@ const ( ChannelTypeTencent = 23 ChannelTypeGemini = 24 ChannelTypeMoonshot = 25 + ChannelTypeBaichuan = 26 ) var ChannelBaseURLs = []string{ @@ -93,6 +94,7 @@ var ChannelBaseURLs = []string{ "https://hunyuan.cloud.tencent.com", // 23 "https://generativelanguage.googleapis.com", // 24 "https://api.moonshot.cn", // 25 + "https://api.baichuan-ai.com", // 26 } const ( diff --git a/common/logger/logger.go b/common/logger/logger.go index f970ee61..8232b2fc 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -13,6 +13,7 @@ import ( ) const ( + loggerDEBUG = "DEBUG" loggerINFO = "INFO" loggerWarn = "WARN" loggerError = "ERR" @@ -55,6 +56,10 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } +func Debug(ctx context.Context, msg string) { + logHelper(ctx, loggerDEBUG, msg) +} + func Info(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } @@ -67,6 +72,10 @@ func Error(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func Debugf(ctx context.Context, format string, a ...any) { + Debug(ctx, fmt.Sprintf(format, a...)) +} + func Infof(ctx context.Context, format string, a ...any) { Info(ctx, fmt.Sprintf(format, a...)) } diff --git a/controller/model.go b/controller/model.go index f5760901..42ebb598 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel/ai360" + "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -98,6 +99,17 @@ func init() { Parent: nil, }) } + for _, modelName := range baichuan.ModelList { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "baichuan", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/controller/relay.go b/controller/relay.go index 499e8ddc..278c0b32 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -41,6 +41,10 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := constant.Path2RelayMode(c.Request.URL.Path) + if config.DebugEnabled { + requestBody, _ := common.GetRequestBody(c) + logger.Debugf(ctx, "request body: %s", string(requestBody)) + } bizErr := relay(c, relayMode) if bizErr == nil { return diff --git a/relay/channel/baichuan/constants.go b/relay/channel/baichuan/constants.go new file mode 100644 index 00000000..cb20a1ff --- /dev/null +++ b/relay/channel/baichuan/constants.go @@ -0,0 +1,7 @@ +package baichuan + +var ModelList = []string{ + "Baichuan2-Turbo", + "Baichuan2-Turbo-192k", + "Baichuan-Text-Embedding", +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 1313e317..0b727d2e 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/ai360" + "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -84,6 +85,8 @@ func (a *Adaptor) GetModelList() []string { return ai360.ModelList case common.ChannelTypeMoonshot: return moonshot.ModelList + case common.ChannelTypeBaichuan: + return baichuan.ModelList default: return ModelList } @@ -97,6 +100,8 @@ func (a *Adaptor) GetChannelName() string { return "360" case common.ChannelTypeMoonshot: return "moonshot" + case common.ChannelTypeBaichuan: + return "baichuan" default: return "openai" } diff --git a/relay/controller/text.go b/relay/controller/text.go index cc460511..59c5f637 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -55,7 +55,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { var requestBody io.Reader if meta.APIType == constant.APITypeOpenAI { // no need to convert request for openai - if isModelMapped { + shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan + if shouldResetRequestBody { jsonStr, err := json.Marshal(textRequest) if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index aeff5190..86c3e3aa 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -71,6 +71,12 @@ export const CHANNEL_OPTIONS = { value: 23, color: 'default' }, + 26: { + key: 26, + text: '百川大模型', + value: 23, + color: 'default' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index a091c8d6..c7e759b5 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -145,6 +145,18 @@ const typeConfig = { }, modelGroup: "google gemini", }, + 25: { + input: { + models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'], + }, + modelGroup: "moonshot", + }, + 26: { + input: { + models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'], + }, + modelGroup: "baichuan", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index 16da1b97..0cf06327 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -11,6 +11,7 @@ export const CHANNEL_OPTIONS = [ { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, + { key: 26, text: '百川大模型', value: 26, color: 'orange' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 4f4633ff..7a33b47f 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -102,6 +102,9 @@ const EditChannel = () => { case 25: localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']; break; + case 26: + localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } From df1fd9aa81084e435da96b41398ef950e1697ca7 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 2 Mar 2024 01:24:28 +0800 Subject: [PATCH 04/11] feat: support minimax's models now (close #354) --- common/constants.go | 2 ++ common/model-ratio.go | 8 ++++++++ controller/model.go | 12 ++++++++++++ relay/channel/minimax/constants.go | 7 +++++++ relay/channel/minimax/main.go | 14 ++++++++++++++ relay/channel/openai/adaptor.go | 13 +++++++++++-- web/berry/src/constants/ChannelConstants.js | 8 +++++++- web/berry/src/views/Channel/type/Config.js | 6 ++++++ web/default/src/constants/channel.constants.js | 1 + web/default/src/pages/Channel/EditChannel.js | 3 +++ 10 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 relay/channel/minimax/constants.go create mode 100644 relay/channel/minimax/main.go diff --git a/common/constants.go b/common/constants.go index 85fcca18..f67dc146 100644 --- a/common/constants.go +++ b/common/constants.go @@ -65,6 +65,7 @@ const ( ChannelTypeGemini = 24 ChannelTypeMoonshot = 25 ChannelTypeBaichuan = 26 + ChannelTypeMinimax = 27 ) var ChannelBaseURLs = []string{ @@ -95,6 +96,7 @@ var ChannelBaseURLs = []string{ "https://generativelanguage.googleapis.com", // 24 "https://api.moonshot.cn", // 25 "https://api.baichuan-ai.com", // 26 + "https://api.minimax.chat", // 27 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index 2e7aae71..d1f70de8 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -127,6 +127,14 @@ var ModelRatio = map[string]float64{ "moonshot-v1-8k": 0.012 * RMB, "moonshot-v1-32k": 0.024 * RMB, "moonshot-v1-128k": 0.06 * RMB, + // https://platform.baichuan-ai.com/price + "Baichuan2-Turbo": 0.008 * RMB, + "Baichuan2-Turbo-192k": 0.016 * RMB, + "Baichuan2-53B": 0.02 * RMB, + // https://api.minimax.chat/document/price + "abab6": 0.1 * RMB, + "abab5.5": 0.015 * RMB, + "abab5.5s": 0.005 * RMB, } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index 42ebb598..0f33f919 100644 --- a/controller/model.go +++ b/controller/model.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" + "github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -110,6 +111,17 @@ func init() { Parent: nil, }) } + for _, modelName := range minimax.ModelList { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "minimax", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/relay/channel/minimax/constants.go b/relay/channel/minimax/constants.go new file mode 100644 index 00000000..c3da5b2d --- /dev/null +++ b/relay/channel/minimax/constants.go @@ -0,0 +1,7 @@ +package minimax + +var ModelList = []string{ + "abab5.5s-chat", + "abab5.5-chat", + "abab6-chat", +} diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go new file mode 100644 index 00000000..a01821c2 --- /dev/null +++ b/relay/channel/minimax/main.go @@ -0,0 +1,14 @@ +package minimax + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/util" +) + +func GetRequestURL(meta *util.RelayMeta) (string, error) { + if meta.Mode == constant.RelayModeChatCompletions { + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 0b727d2e..6afe2b2f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" + "github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -25,7 +26,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { } func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - if meta.ChannelType == common.ChannelTypeAzure { + switch meta.ChannelType { + case common.ChannelTypeAzure: // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api requestURL := strings.Split(meta.RequestURLPath, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) @@ -39,8 +41,11 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil + case common.ChannelTypeMinimax: + return minimax.GetRequestURL(meta) + default: + return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil } - return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { @@ -87,6 +92,8 @@ func (a *Adaptor) GetModelList() []string { return moonshot.ModelList case common.ChannelTypeBaichuan: return baichuan.ModelList + case common.ChannelTypeMinimax: + return minimax.ModelList default: return ModelList } @@ -102,6 +109,8 @@ func (a *Adaptor) GetChannelName() string { return "moonshot" case common.ChannelTypeBaichuan: return "baichuan" + case common.ChannelTypeMinimax: + return "minimax" default: return "openai" } diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 86c3e3aa..98ceaebf 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -74,7 +74,13 @@ export const CHANNEL_OPTIONS = { 26: { key: 26, text: '百川大模型', - value: 23, + value: 26, + color: 'default' + }, + 27: { + key: 27, + text: 'MiniMax', + value: 27, color: 'default' }, 8: { diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index c7e759b5..0e89868b 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -157,6 +157,12 @@ const typeConfig = { }, modelGroup: "baichuan", }, + 27: { + input: { + models: ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'], + }, + modelGroup: "minimax", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index 0cf06327..beb0adb1 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -12,6 +12,7 @@ export const CHANNEL_OPTIONS = [ { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, { key: 26, text: '百川大模型', value: 26, color: 'orange' }, + { key: 27, text: 'MiniMax', value: 27, color: 'red' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 7a33b47f..b9214fd8 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -105,6 +105,9 @@ const EditChannel = () => { case 26: localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding']; break; + case 27: + localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } From 76467285e86c34e95aa4c4ea8038083a658ee5ce Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 2 Mar 2024 01:25:21 +0800 Subject: [PATCH 05/11] docs: update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ae2ffac5..a92142ae 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Moonshot AI](https://platform.moonshot.cn/) + [x] [百川大模型](https://platform.baichuan-ai.com) + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) - + [ ] [MINIMAX](https://api.minimax.chat/) (WIP) + + [x] [MINIMAX](https://api.minimax.chat/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 From f9490bb72e22003148dc39716d16c24e64c47962 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 2 Mar 2024 01:32:04 +0800 Subject: [PATCH 06/11] fix: able to use updated default ratio --- common/model-ratio.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index d1f70de8..3be9118d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -132,9 +132,18 @@ var ModelRatio = map[string]float64{ "Baichuan2-Turbo-192k": 0.016 * RMB, "Baichuan2-53B": 0.02 * RMB, // https://api.minimax.chat/document/price - "abab6": 0.1 * RMB, - "abab5.5": 0.015 * RMB, - "abab5.5s": 0.005 * RMB, + "abab6-chat": 0.1 * RMB, + "abab5.5-chat": 0.015 * RMB, + "abab5.5s-chat": 0.005 * RMB, +} + +var DefaultModelRatio map[string]float64 + +func init() { + DefaultModelRatio = make(map[string]float64) + for k, v := range ModelRatio { + DefaultModelRatio[k] = v + } } func ModelRatio2JSONString() string { @@ -155,6 +164,9 @@ func GetModelRatio(name string) float64 { name = strings.TrimSuffix(name, "-internet") } ratio, ok := ModelRatio[name] + if !ok { + ratio, ok = DefaultModelRatio[name] + } if !ok { logger.SysError("model ratio not found: " + name) return 30 From 1d0b7fb5ae73a9eefe28da5b2cf5f6b6af335c02 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 2 Mar 2024 03:05:25 +0800 Subject: [PATCH 07/11] feat: support chatglm-4 (close #1045, close #952, close #952, close #943) --- common/model-ratio.go | 20 +++++++----- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/main.go | 10 ++++-- relay/channel/openai/model.go | 1 + relay/channel/tencent/main.go | 1 + relay/channel/zhipu/adaptor.go | 34 ++++++++++++++++++++ relay/channel/zhipu/constants.go | 1 + web/berry/src/views/Channel/type/Config.js | 2 +- web/default/src/pages/Channel/EditChannel.js | 2 +- 9 files changed, 59 insertions(+), 14 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 3be9118d..1594b534 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -94,14 +94,18 @@ var ModelRatio = map[string]float64{ "claude-2.0": 5.51, // $11.02 / 1M tokens "claude-2.1": 5.51, // $11.02 / 1M tokens // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8k": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens + "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens + "ERNIE-Bot-8k": 0.024 * RMB, + "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + // https://open.bigmodel.cn/pricing + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 6afe2b2f..27d0fc27 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -76,7 +76,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string - err, responseText = StreamHandler(c, resp, meta.Mode) + err, responseText, _ = StreamHandler(c, resp, meta.Mode) usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } else { err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index fbe55cf9..d47cd164 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -14,7 +14,7 @@ import ( "strings" ) -func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -31,6 +31,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E }) dataChan := make(chan string) stopChan := make(chan bool) + var usage *model.Usage go func() { for scanner.Scan() { data := scanner.Text() @@ -54,6 +55,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E for _, choice := range streamResponse.Choices { responseText += choice.Delta.Content } + if streamResponse.Usage != nil { + usage = streamResponse.Usage + } case constant.RelayModeCompletions: var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) @@ -86,9 +90,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E }) err := resp.Body.Close() if err != nil { - return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil } - return nil, responseText + return nil, responseText, usage } func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index b24485a8..6c0b2c53 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -132,6 +132,7 @@ type ChatCompletionsStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"` + Usage *model.Usage `json:"usage"` } type CompletionsStreamResponse struct { diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index 05edac20..fa26651b 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -81,6 +81,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "tencent-hunyuan", diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 7a822853..90cc79d3 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -5,20 +5,35 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" + "strings" ) type Adaptor struct { + APIVersion string } func (a *Adaptor) Init(meta *util.RelayMeta) { } +func (a *Adaptor) SetVersionByModeName(modelName string) { + if strings.HasPrefix(modelName, "glm-") { + a.APIVersion = "v4" + } else { + a.APIVersion = "v3" + } +} + func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + a.SetVersionByModeName(meta.ActualModelName) + if a.APIVersion == "v4" { + return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil + } method := "invoke" if meta.IsStream { method = "sse-invoke" @@ -37,6 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } + if request.TopP >= 1 { + request.TopP = 0.99 + } + a.SetVersionByModeName(request.Model) + if a.APIVersion == "v4" { + return request, nil + } return ConvertRequest(*request), nil } @@ -44,7 +66,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io return channel.DoRequestHelper(a, c, meta, requestBody) } +func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, _, usage = openai.StreamHandler(c, resp, meta.Mode) + } else { + err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if a.APIVersion == "v4" { + return a.DoResponseV4(c, resp, meta) + } if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go index f0367b82..1655a59d 100644 --- a/relay/channel/zhipu/constants.go +++ b/relay/channel/zhipu/constants.go @@ -2,4 +2,5 @@ package zhipu var ModelList = []string{ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", + "glm-4", "glm-4v", "glm-3-turbo", } diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 0e89868b..4dec33de 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -67,7 +67,7 @@ const typeConfig = { }, 16: { input: { - models: ["chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"], + models: ["glm-4", "glm-4v", "glm-3-turbo", "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"], }, modelGroup: "zhipu", }, diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index b9214fd8..693242f9 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -79,7 +79,7 @@ const EditChannel = () => { localModels = [...localModels, ...withInternetVersion]; break; case 16: - localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; + localModels = ["glm-4", "glm-4v", "glm-3-turbo",'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; break; case 18: localModels = [ From de18d6fe16e41a18465db908b9851426e6340721 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 19:30:11 +0800 Subject: [PATCH 08/11] refactor: refactor image relay (close #1068) --- common/model-ratio.go | 23 ------- relay/constant/image.go | 24 +++++++ relay/controller/helper.go | 59 +++++++++++++++++ relay/controller/image.go | 132 +++++++++++-------------------------- 4 files changed, 122 insertions(+), 116 deletions(-) create mode 100644 relay/constant/image.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 1594b534..ab0ad748 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -7,29 +7,6 @@ import ( "time" ) -var DalleSizeRatios = map[string]map[string]float64{ - "dall-e-2": { - "256x256": 1, - "512x512": 1.125, - "1024x1024": 1.25, - }, - "dall-e-3": { - "1024x1024": 1, - "1024x1792": 2, - "1792x1024": 2, - }, -} - -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. -} - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} - const ( USD2RMB = 7 USD = 500 // $0.002 = 1 -> $1 = 500 diff --git a/relay/constant/image.go b/relay/constant/image.go new file mode 100644 index 00000000..5e04895f --- /dev/null +++ b/relay/constant/image.go @@ -0,0 +1,24 @@ +package constant + +var DalleSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, +} + +var DalleGenerationImageAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. +} + +var DalleImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go index a06b2768..d5078304 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } +func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { + imageRequest := &openai.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { + // model validation + _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + if !hasValidSize { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + // channel not azure + if meta.ChannelType != common.ChannelTypeAzure { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + } + return nil +} + +func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + if !hasValidSize { + return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) + } + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil +} + func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case constant.RelayModeChatCompletions: diff --git a/relay/controller/image.go b/relay/controller/image.go index 6ec368f5..3ce3809b 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -20,120 +21,65 @@ import ( ) func isWithinRange(element string, value int) bool { - if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] - max := common.DalleGenerationImageAmounts[element][1] + min := constant.DalleGenerationImageAmounts[element][0] + max := constant.DalleGenerationImageAmounts[element][1] return value >= min && value <= max } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { - imageModel := "dall-e-2" - imageSize := "1024x1024" - - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - - var imageRequest openai.ImageRequest - err := common.UnmarshalBodyReusable(c, &imageRequest) + ctx := c.Request.Context() + meta := util.GetRelayMeta(c) + imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { - return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } - - // Size validation - if imageRequest.Size != "" { - imageSize = imageRequest.Size - } - - // Model validation - if imageRequest.Model != "" { - imageModel = imageRequest.Model - } - - imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] - - // Check if model is supported - if hasValidSize { - if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { - if imageSize == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - } else { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - - // Prompt validation - if imageRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - - // Check prompt length - if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - - // Number of generated images validation - if !isWithinRange(imageModel, imageRequest.N) { - // channel not azure - if channelType != common.ChannelTypeAzure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } + logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } + var isModelMapped bool + meta.OriginModelName = imageRequest.Model + imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) + meta.ActualModelName = imageRequest.Model + + // model validation + bizErr := validateImageRequest(imageRequest, meta) + if bizErr != nil { + return bizErr } - baseURL := common.ChannelBaseURLs[channelType] + + imageCostRatio, err := getImageCostRatio(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) + } + requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure { + fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) + if meta.ChannelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api apiVersion := util.GetAzureAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) } var requestBody io.Reader - if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body + if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio(group) + modelRatio := common.GetModelRatio(imageRequest.Model) + groupRatio := common.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(meta.UserId) quota := int(ratio*imageCostRatio*1000) * imageRequest.N @@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } token := c.Request.Header.Get("Authorization") - if channelType == common.ChannelTypeAzure { // Azure authentication + if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication token = strings.TrimPrefix(token, "Bearer ") req.Header.Set("api-key", token) } else { @@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var textResponse openai.ImageResponse + var imageResponse openai.ImageResponse defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { return } - err := model.PostConsumeTokenQuota(tokenId, quota) + err := model.PostConsumeTokenQuota(meta.TokenId, quota) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } @@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - err = json.Unmarshal(responseBody, &textResponse) + err = json.Unmarshal(responseBody, &imageResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } From 82e916b5ff9c6c3f8325c91c04003fb6751f024c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 20:51:28 +0800 Subject: [PATCH 09/11] fix: fix azure test (close #1069) --- controller/channel-test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/controller/channel-test.go b/controller/channel-test.go index b498f4f1..485d7702 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -51,6 +52,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) + middleware.SetupContextForSelectedChannel(c, channel, "") meta := util.GetRelayMeta(c) apiType := constant.ChannelType2APIType(channel.Type) adaptor := helper.GetAdaptor(apiType) From b35f3523d305de758021286fff9f62bece08af02 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 21:03:04 +0800 Subject: [PATCH 10/11] feat: add gemini model alias (close #1064) --- relay/channel/gemini/constants.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go index 5bb0c168..4e7c57f9 100644 --- a/relay/channel/gemini/constants.go +++ b/relay/channel/gemini/constants.go @@ -1,6 +1,6 @@ package gemini var ModelList = []string{ - "gemini-pro", - "gemini-pro-vision", + "gemini-pro", "gemini-1.0-pro-001", + "gemini-pro-vision", "gemini-1.0-pro-vision-001", } From 9d8967f7d325999ab28616aef5b9eeaa1cccf6dc Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 21:46:45 +0800 Subject: [PATCH 11/11] feat: support Mistral's models now (close #1051) --- README.md | 1 + common/constants.go | 2 ++ common/model-ratio.go | 23 ++++++++++++++++--- controller/model.go | 12 ++++++++++ relay/channel/mistral/constants.go | 10 ++++++++ relay/channel/openai/adaptor.go | 5 ++++ relay/controller/helper.go | 10 ++++---- web/berry/src/constants/ChannelConstants.js | 6 +++++ .../src/constants/channel.constants.js | 1 + 9 files changed, 61 insertions(+), 9 deletions(-) create mode 100644 relay/channel/mistral/constants.go diff --git a/README.md b/README.md index a92142ae..69bb10ef 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + + [x] [Mistral 系列模型](https://mistral.ai/) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) diff --git a/common/constants.go b/common/constants.go index f67dc146..ac901139 100644 --- a/common/constants.go +++ b/common/constants.go @@ -66,6 +66,7 @@ const ( ChannelTypeMoonshot = 25 ChannelTypeBaichuan = 26 ChannelTypeMinimax = 27 + ChannelTypeMistral = 28 ) var ChannelBaseURLs = []string{ @@ -97,6 +98,7 @@ var ChannelBaseURLs = []string{ "https://api.moonshot.cn", // 25 "https://api.baichuan-ai.com", // 26 "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index ab0ad748..2e66ac0d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -17,7 +17,6 @@ const ( // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://openai.com/pricing -// TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ @@ -116,15 +115,29 @@ var ModelRatio = map[string]float64{ "abab6-chat": 0.1 * RMB, "abab5.5-chat": 0.015 * RMB, "abab5.5s-chat": 0.005 * RMB, + // https://docs.mistral.ai/platform/pricing/ + "open-mistral-7b": 0.25 / 1000 * USD, + "open-mixtral-8x7b": 0.7 / 1000 * USD, + "mistral-small-latest": 2.0 / 1000 * USD, + "mistral-medium-latest": 2.7 / 1000 * USD, + "mistral-large-latest": 8.0 / 1000 * USD, + "mistral-embed": 0.1 / 1000 * USD, } +var CompletionRatio = map[string]float64{} + var DefaultModelRatio map[string]float64 +var DefaultCompletionRatio map[string]float64 func init() { DefaultModelRatio = make(map[string]float64) for k, v := range ModelRatio { DefaultModelRatio[k] = v } + DefaultCompletionRatio = make(map[string]float64) + for k, v := range CompletionRatio { + DefaultCompletionRatio[k] = v + } } func ModelRatio2JSONString() string { @@ -155,8 +168,6 @@ func GetModelRatio(name string) float64 { return ratio } -var CompletionRatio = map[string]float64{} - func CompletionRatio2JSONString() string { jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { @@ -174,6 +185,9 @@ func GetCompletionRatio(name string) float64 { if ratio, ok := CompletionRatio[name]; ok { return ratio } + if ratio, ok := DefaultCompletionRatio[name]; ok { + return ratio + } if strings.HasPrefix(name, "gpt-3.5") { if strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates @@ -206,5 +220,8 @@ func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "claude-2") { return 2.965517 } + if strings.HasPrefix(name, "mistral-") { + return 3 + } return 1 } diff --git a/controller/model.go b/controller/model.go index 0f33f919..0d0d2658 100644 --- a/controller/model.go +++ b/controller/model.go @@ -6,6 +6,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -122,6 +123,17 @@ func init() { Parent: nil, }) } + for _, modelName := range mistral.ModelList { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "mistralai", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/relay/channel/mistral/constants.go b/relay/channel/mistral/constants.go new file mode 100644 index 00000000..cdb157f5 --- /dev/null +++ b/relay/channel/mistral/constants.go @@ -0,0 +1,10 @@ +package mistral + +var ModelList = []string{ + "open-mistral-7b", + "open-mixtral-8x7b", + "mistral-small-latest", + "mistral-medium-latest", + "mistral-large-latest", + "mistral-embed", +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 27d0fc27..5a04a768 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -94,6 +95,8 @@ func (a *Adaptor) GetModelList() []string { return baichuan.ModelList case common.ChannelTypeMinimax: return minimax.ModelList + case common.ChannelTypeMistral: + return mistral.ModelList default: return ModelList } @@ -111,6 +114,8 @@ func (a *Adaptor) GetChannelName() string { return "baichuan" case common.ChannelTypeMinimax: return "minimax" + case common.ChannelTypeMistral: + return "mistralai" default: return "openai" } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index d5078304..89fc69ce 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -172,10 +172,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) - model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) - model.UpdateChannelUsedQuota(meta.ChannelId, quota) - } + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + model.UpdateChannelUsedQuota(meta.ChannelId, quota) } diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 98ceaebf..31c45048 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -29,6 +29,12 @@ export const CHANNEL_OPTIONS = { value: 24, color: 'orange' }, + 28: { + key: 28, + text: 'Mistral AI', + value: 28, + color: 'orange' + }, 15: { key: 15, text: '百度文心千帆', diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index beb0adb1..b21bb15d 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [ { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, + { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },