feat: add cohere support (#1355)

* support cohere

* chore: tiny improvements

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
Ghostz 2024-04-24 21:50:01 +08:00 committed by GitHub
parent cb33e8aad5
commit 24f026d18e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 476 additions and 1 deletions

View File

@ -83,6 +83,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
+ [x] [阶跃星辰](https://platform.stepfun.com/)
+ [x] [Coze](https://www.coze.com/)
+ [x] [Cohere](https://cohere.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。

View File

@ -7,6 +7,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws"
"github.com/songquanpeng/one-api/relay/adaptor/baidu"
"github.com/songquanpeng/one-api/relay/adaptor/cohere"
"github.com/songquanpeng/one-api/relay/adaptor/coze"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
@ -46,6 +47,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &ollama.Adaptor{}
case apitype.Coze:
return &coze.Adaptor{}
case apitype.Cohere:
return &cohere.Adaptor{}
}
return nil
}

View File

@ -0,0 +1,64 @@
package cohere
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type Adaptor struct{}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertImageRequest implements adaptor.Adaptor.
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "Cohere"
}

View File

@ -0,0 +1,7 @@
package cohere
var ModelList = []string{
"command", "command-nightly",
"command-light", "command-light-nightly",
"command-r", "command-r-plus",
}

View File

@ -0,0 +1,233 @@
package cohere
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
func stopReasonCohere2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "COMPLETE":
return "stop"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
cohereRequest := Request{
Model: textRequest.Model,
Message: "",
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
P: textRequest.TopP,
K: textRequest.TopK,
Stream: textRequest.Stream,
FrequencyPenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.FrequencyPenalty,
Seed: int(textRequest.Seed),
}
if cohereRequest.Model == "" {
cohereRequest.Model = "command-r"
}
for _, message := range textRequest.Messages {
if message.Role == "user" {
cohereRequest.Message = message.Content.(string)
} else {
var role string
if message.Role == "assistant" {
role = "CHATBOT"
} else if message.Role == "system" {
role = "SYSTEM"
} else {
role = "USER"
}
cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{
Role: role,
Message: message.Content.(string),
})
}
}
return &cohereRequest
}
func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var finishReason string
switch cohereResponse.EventType {
case "stream-start":
return nil, nil
case "text-generation":
responseText += cohereResponse.Text
case "stream-end":
usage := cohereResponse.Response.Meta.Tokens
response = &Response{
Meta: Meta{
Tokens: Usage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
},
},
}
finishReason = *cohereResponse.Response.FinishReason
default:
return nil, nil
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
if finishReason != "" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cohereResponse.Text,
Name: nil,
},
FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID),
Model: "model",
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cohereResponse StreamResponse
err := json.Unmarshal([]byte(data), &cohereResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
return true
}
if response == nil {
return true
}
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model")
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResponse Response
err = json.Unmarshal(responseBody, &cohereResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if cohereResponse.ResponseID == "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: cohereResponse.Message,
Type: cohereResponse.Message,
Param: "",
Code: resp.StatusCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := ResponseCohere2OpenAI(&cohereResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: cohereResponse.Meta.Tokens.InputTokens,
CompletionTokens: cohereResponse.Meta.Tokens.OutputTokens,
TotalTokens: cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@ -0,0 +1,147 @@
package cohere
type Request struct {
Message string `json:"message" required:"true"`
Model string `json:"model,omitempty"` // 默认值为"command-r"
Stream bool `json:"stream,omitempty"` // 默认值为false
Preamble string `json:"preamble,omitempty"`
ChatHistory []ChatMessage `json:"chat_history,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
Connectors []Connector `json:"connectors,omitempty"`
Documents []Document `json:"documents,omitempty"`
Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3
MaxTokens int `json:"max_tokens,omitempty"`
MaxInputTokens int `json:"max_input_tokens,omitempty"`
K int `json:"k,omitempty"` // 默认值为0
P float64 `json:"p,omitempty"` // 默认值为0.75
Seed int `json:"seed,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
Tools []Tool `json:"tools,omitempty"`
ToolResults []ToolResult `json:"tool_results,omitempty"`
}
type ChatMessage struct {
Role string `json:"role" required:"true"`
Message string `json:"message" required:"true"`
}
type Tool struct {
Name string `json:"name" required:"true"`
Description string `json:"description" required:"true"`
ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"`
}
type ParameterSpec struct {
Description string `json:"description"`
Type string `json:"type" required:"true"`
Required bool `json:"required"`
}
type ToolResult struct {
Call ToolCall `json:"call"`
Outputs []map[string]interface{} `json:"outputs"`
}
type ToolCall struct {
Name string `json:"name" required:"true"`
Parameters map[string]interface{} `json:"parameters" required:"true"`
}
type StreamResponse struct {
IsFinished bool `json:"is_finished"`
EventType string `json:"event_type"`
GenerationID string `json:"generation_id,omitempty"`
SearchQueries []*SearchQuery `json:"search_queries,omitempty"`
SearchResults []*SearchResult `json:"search_results,omitempty"`
Documents []*Document `json:"documents,omitempty"`
Text string `json:"text,omitempty"`
Citations []*Citation `json:"citations,omitempty"`
Response *Response `json:"response,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
type SearchQuery struct {
Text string `json:"text"`
GenerationID string `json:"generation_id"`
}
type SearchResult struct {
SearchQuery *SearchQuery `json:"search_query"`
DocumentIDs []string `json:"document_ids"`
Connector *Connector `json:"connector"`
}
type Connector struct {
ID string `json:"id"`
}
type Document struct {
ID string `json:"id"`
Snippet string `json:"snippet"`
Timestamp string `json:"timestamp"`
Title string `json:"title"`
URL string `json:"url"`
}
type Citation struct {
Start int `json:"start"`
End int `json:"end"`
Text string `json:"text"`
DocumentIDs []string `json:"document_ids"`
}
type Response struct {
ResponseID string `json:"response_id"`
Text string `json:"text"`
GenerationID string `json:"generation_id"`
ChatHistory []*Message `json:"chat_history"`
FinishReason *string `json:"finish_reason"`
Meta Meta `json:"meta"`
Citations []*Citation `json:"citations"`
Documents []*Document `json:"documents"`
SearchResults []*SearchResult `json:"search_results"`
SearchQueries []*SearchQuery `json:"search_queries"`
Message string `json:"message"`
}
type Message struct {
Role string `json:"role"`
Message string `json:"message"`
}
type Version struct {
Version string `json:"version"`
}
type Units struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type ChatEntry struct {
Role string `json:"role"`
Message string `json:"message"`
}
type Meta struct {
APIVersion APIVersion `json:"api_version"`
BilledUnits BilledUnits `json:"billed_units"`
Tokens Usage `json:"tokens"`
}
type APIVersion struct {
Version string `json:"version"`
}
type BilledUnits struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@ -14,6 +14,7 @@ const (
Ollama
AwsClaude
Coze
Cohere
Dummy // this one is only for count, do not add any channel after this
)

View File

@ -2,8 +2,9 @@ package ratio
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
"strings"
"github.com/songquanpeng/one-api/common/logger"
)
const (
@ -162,6 +163,13 @@ var ModelRatio = map[string]float64{
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
// https://cohere.com/pricing
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.5 / 1000 * USD,
"command-r-plus ": 3.0 / 1000 * USD,
}
var CompletionRatio = map[string]float64{}
@ -284,6 +292,12 @@ func GetCompletionRatio(name string) float64 {
return 2
case "llama3-70b-8192":
return 0.79 / 0.59
case "command", "command-light", "command-nightly", "command-light-nightly":
return 2
case "command-r":
return 3
case "command-r-plus":
return 5
}
return 1
}

View File

@ -36,6 +36,7 @@ const (
StepFun
AwsClaude
Coze
Cohere
Dummy
)

View File

@ -29,6 +29,8 @@ func ToAPIType(channelType int) int {
apiType = apitype.AwsClaude
case Coze:
apiType = apitype.Coze
case Cohere:
apiType = apitype.Cohere
}
return apiType

View File

@ -36,6 +36,7 @@ var ChannelBaseURLs = []string{
"https://api.stepfun.com", // 32
"", // 33
"https://api.coze.com", // 34
"https://api.cohere.ai", //35
}
func init() {

View File

@ -20,6 +20,7 @@ export const CHANNEL_OPTIONS = [
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 32, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 34, text: 'Coze', value: 34, color: 'blue' },
{ key: 35, text: 'Cohere', value: 35, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },