diff --git a/README.md b/README.md index 0ab35893..2f81c10d 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [零一万物](https://platform.lingyiwanwu.com/) + [x] [阶跃星辰](https://platform.stepfun.com/) + [x] [Coze](https://www.coze.com/) + + [x] [Cohere](https://cohere.com/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 diff --git a/relay/adaptor.go b/relay/adaptor.go index 24db9e89..293b6d79 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/aws" "github.com/songquanpeng/one-api/relay/adaptor/baidu" + "github.com/songquanpeng/one-api/relay/adaptor/cohere" "github.com/songquanpeng/one-api/relay/adaptor/coze" "github.com/songquanpeng/one-api/relay/adaptor/gemini" "github.com/songquanpeng/one-api/relay/adaptor/ollama" @@ -46,6 +47,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &ollama.Adaptor{} case apitype.Coze: return &coze.Adaptor{} + case apitype.Cohere: + return &cohere.Adaptor{} } return nil } diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go new file mode 100644 index 00000000..6fdb1b04 --- /dev/null +++ b/relay/adaptor/cohere/adaptor.go @@ -0,0 +1,64 @@ +package cohere + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct{} + +// ConvertImageRequest implements adaptor.Adaptor. +func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements adaptor.Adaptor. 
+ +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "Cohere" +} diff --git a/relay/adaptor/cohere/constant.go b/relay/adaptor/cohere/constant.go new file mode 100644 index 00000000..3ff4d655 --- /dev/null +++ b/relay/adaptor/cohere/constant.go @@ -0,0 +1,7 @@ +package cohere + +var ModelList = []string{ + "command", "command-nightly", + "command-light", "command-light-nightly", + "command-r", "command-r-plus", +} diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go new file mode 100644 index 00000000..81277b07 --- /dev/null +++ b/relay/adaptor/cohere/main.go @@ -0,0 +1,233 @@ +package cohere + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + 
"github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +func stopReasonCohere2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "COMPLETE": + return "stop" + default: + return *reason + } +} + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + cohereRequest := Request{ + Model: textRequest.Model, + Message: "", + MaxTokens: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + P: textRequest.TopP, + K: textRequest.TopK, + Stream: textRequest.Stream, + FrequencyPenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.PresencePenalty, + Seed: int(textRequest.Seed), + } + if cohereRequest.Model == "" { + cohereRequest.Model = "command-r" + } + for _, message := range textRequest.Messages { + if message.Role == "user" { + cohereRequest.Message = message.Content.(string) + } else { + var role string + if message.Role == "assistant" { + role = "CHATBOT" + } else if message.Role == "system" { + role = "SYSTEM" + } else { + role = "USER" + } + cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{ + Role: role, + Message: message.Content.(string), + }) + } + } + return &cohereRequest +} + +func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var finishReason string + + switch cohereResponse.EventType { + case "stream-start": + return nil, nil + case "text-generation": + responseText += cohereResponse.Text + case "stream-end": + usage := cohereResponse.Response.Meta.Tokens + response = &Response{ + Meta: Meta{ + Tokens: Usage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + }, + }, + } + finishReason = *cohereResponse.Response.FinishReason + default: + return nil, nil + } + 
+ var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = responseText + choice.Delta.Role = "assistant" + if finishReason != "" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response +} + +func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: cohereResponse.Text, + Name: nil, + }, + FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason), + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID), + Model: "model", + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + var usage model.Usage + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var cohereResponse StreamResponse + err := json.Unmarshal([]byte(data), &cohereResponse) + if err != nil { + 
logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, meta := StreamResponseCohere2OpenAI(&cohereResponse) + if meta != nil { + usage.PromptTokens += meta.Meta.Tokens.InputTokens + usage.CompletionTokens += meta.Meta.Tokens.OutputTokens + return true + } + if response == nil { + return true + } + response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) + response.Model = c.GetString("original_model") + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + _ = resp.Body.Close() + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cohereResponse Response + err = json.Unmarshal(responseBody, &cohereResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if cohereResponse.ResponseID == "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: cohereResponse.Message, + Type: cohereResponse.Message, + Param: "", + Code: resp.StatusCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseCohere2OpenAI(&cohereResponse) + fullTextResponse.Model = modelName + usage := model.Usage{ + PromptTokens: cohereResponse.Meta.Tokens.InputTokens, + CompletionTokens: 
cohereResponse.Meta.Tokens.OutputTokens, + TotalTokens: cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go new file mode 100644 index 00000000..64fa9c94 --- /dev/null +++ b/relay/adaptor/cohere/model.go @@ -0,0 +1,147 @@ +package cohere + +type Request struct { + Message string `json:"message" required:"true"` + Model string `json:"model,omitempty"` // 默认值为"command-r" + Stream bool `json:"stream,omitempty"` // 默认值为false + Preamble string `json:"preamble,omitempty"` + ChatHistory []ChatMessage `json:"chat_history,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" + Connectors []Connector `json:"connectors,omitempty"` + Documents []Document `json:"documents,omitempty"` + Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3 + MaxTokens int `json:"max_tokens,omitempty"` + MaxInputTokens int `json:"max_input_tokens,omitempty"` + K int `json:"k,omitempty"` // 默认值为0 + P float64 `json:"p,omitempty"` // 默认值为0.75 + Seed int `json:"seed,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + Tools []Tool `json:"tools,omitempty"` + ToolResults []ToolResult `json:"tool_results,omitempty"` +} + +type ChatMessage struct { + Role string `json:"role" required:"true"` + Message string `json:"message" required:"true"` +} + +type Tool struct 
{ + Name string `json:"name" required:"true"` + Description string `json:"description" required:"true"` + ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"` +} + +type ParameterSpec struct { + Description string `json:"description"` + Type string `json:"type" required:"true"` + Required bool `json:"required"` +} + +type ToolResult struct { + Call ToolCall `json:"call"` + Outputs []map[string]interface{} `json:"outputs"` +} + +type ToolCall struct { + Name string `json:"name" required:"true"` + Parameters map[string]interface{} `json:"parameters" required:"true"` +} + +type StreamResponse struct { + IsFinished bool `json:"is_finished"` + EventType string `json:"event_type"` + GenerationID string `json:"generation_id,omitempty"` + SearchQueries []*SearchQuery `json:"search_queries,omitempty"` + SearchResults []*SearchResult `json:"search_results,omitempty"` + Documents []*Document `json:"documents,omitempty"` + Text string `json:"text,omitempty"` + Citations []*Citation `json:"citations,omitempty"` + Response *Response `json:"response,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` +} + +type SearchQuery struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` +} + +type SearchResult struct { + SearchQuery *SearchQuery `json:"search_query"` + DocumentIDs []string `json:"document_ids"` + Connector *Connector `json:"connector"` +} + +type Connector struct { + ID string `json:"id"` +} + +type Document struct { + ID string `json:"id"` + Snippet string `json:"snippet"` + Timestamp string `json:"timestamp"` + Title string `json:"title"` + URL string `json:"url"` +} + +type Citation struct { + Start int `json:"start"` + End int `json:"end"` + Text string `json:"text"` + DocumentIDs []string `json:"document_ids"` +} + +type Response struct { + ResponseID string `json:"response_id"` + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ChatHistory []*Message `json:"chat_history"` + 
FinishReason *string `json:"finish_reason"` + Meta Meta `json:"meta"` + Citations []*Citation `json:"citations"` + Documents []*Document `json:"documents"` + SearchResults []*SearchResult `json:"search_results"` + SearchQueries []*SearchQuery `json:"search_queries"` + Message string `json:"message"` +} + +type Message struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type Version struct { + Version string `json:"version"` +} + +type Units struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type ChatEntry struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type Meta struct { + APIVersion APIVersion `json:"api_version"` + BilledUnits BilledUnits `json:"billed_units"` + Tokens Usage `json:"tokens"` +} + +type APIVersion struct { + Version string `json:"version"` +} + +type BilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go index a3f2b98c..a1c8e6e1 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -14,6 +14,7 @@ const ( Ollama AwsClaude Coze + Cohere Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index b410df94..923d9c4f 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -2,8 +2,9 @@ package ratio import ( "encoding/json" - "github.com/songquanpeng/one-api/common/logger" "strings" + + "github.com/songquanpeng/one-api/common/logger" ) const ( @@ -162,6 +163,13 @@ var ModelRatio = map[string]float64{ "step-1v-32k": 0.024 * RMB, "step-1-32k": 0.024 * RMB, "step-1-200k": 0.15 * RMB, + // https://cohere.com/pricing + "command": 0.5, + "command-nightly": 0.5, + "command-light": 0.5, + "command-light-nightly": 
0.5, + "command-r": 0.5 / 1000 * USD, + "command-r-plus": 3.0 / 1000 * USD, +} var CompletionRatio = map[string]float64{} @@ -284,6 +292,12 @@ func GetCompletionRatio(name string) float64 { return 2 case "llama3-70b-8192": return 0.79 / 0.59 + case "command", "command-light", "command-nightly", "command-light-nightly": + return 2 + case "command-r": + return 3 + case "command-r-plus": + return 5 } return 1 } diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index 6975e492..4b37e566 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -36,6 +36,7 @@ const ( StepFun AwsClaude Coze + Cohere Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index d249e208..42b77891 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -29,6 +29,8 @@ func ToAPIType(channelType int) int { apiType = apitype.AwsClaude case Coze: apiType = apitype.Coze + case Cohere: + apiType = apitype.Cohere } return apiType diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 1f15dfe3..64fdcd0a 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -36,6 +36,7 @@ var ChannelBaseURLs = []string{ "https://api.stepfun.com", // 32 "", // 33 "https://api.coze.com", // 34 + "https://api.cohere.ai", // 35 } func init() { diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index 0c1f4822..c21e19ed 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -20,6 +20,7 @@ export const CHANNEL_OPTIONS = [ { key: 31, text: '零一万物', value: 31, color: 'green' }, { key: 32, text: '阶跃星辰', value: 32, color: 'blue' }, { key: 34, text: 'Coze', value: 34, color: 'blue' }, + { key: 35, text: 'Cohere', value: 35, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, 
color: 'purple' },