From e64e7707a0193f764eefbcd498069c80c8f3a51f Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 27 Apr 2024 00:06:43 +0800 Subject: [PATCH] feat: support cohere's web search --- relay/adaptor/cohere/constant.go | 7 +++++++ relay/adaptor/cohere/main.go | 8 ++++++++ relay/billing/ratio/model.go | 3 +++ 3 files changed, 18 insertions(+) diff --git a/relay/adaptor/cohere/constant.go b/relay/adaptor/cohere/constant.go index 3ff4d655..9e70652c 100644 --- a/relay/adaptor/cohere/constant.go +++ b/relay/adaptor/cohere/constant.go @@ -5,3 +5,10 @@ var ModelList = []string{ "command-light", "command-light-nightly", "command-r", "command-r-plus", } + +func init() { + num := len(ModelList) + for i := 0; i < num; i++ { + ModelList = append(ModelList, ModelList[i]+"-internet") + } +} diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 81277b07..4bc3fa8d 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -17,6 +17,10 @@ import ( "github.com/songquanpeng/one-api/relay/model" ) +var ( + WebSearchConnector = Connector{ID: "web-search"} +) + func stopReasonCohere2OpenAI(reason *string) string { if reason == nil { return "" @@ -45,6 +49,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { if cohereRequest.Model == "" { cohereRequest.Model = "command-r" } + if strings.HasSuffix(cohereRequest.Model, "-internet") { + cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet") + cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector) + } for _, message := range textRequest.Messages { if message.Role == "user" { cohereRequest.Message = message.Content.(string) diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index c6fdf4b4..f6cc233a 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -228,6 +228,9 @@ func GetModelRatio(name string) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } + if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } ratio, ok := ModelRatio[name] if !ok { ratio, ok = DefaultModelRatio[name]