✨ feat: Support Cohere
This commit is contained in:
parent
7e206c1e7e
commit
95854f5912
@ -173,6 +173,7 @@ const (
|
||||
ChannelTypeLingyi = 33
|
||||
ChannelTypeMidjourney = 34
|
||||
ChannelTypeCloudflareAI = 35
|
||||
ChannelTypeCohere = 36
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@ -212,6 +213,7 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.lingyiwanwu.com", //33
|
||||
"", //34
|
||||
"", //35
|
||||
"https://api.cohere.ai/v1", //36
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -295,6 +295,10 @@ func GetDefaultPrice() []*Price {
|
||||
"@hf/google/gemma-7b-it": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
|
||||
"@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
|
||||
"@cf/openai/whisper": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
|
||||
//$0.50 /1M TOKENS $1.50/1M TOKENS
|
||||
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
|
||||
//$3 /1M TOKENS $15/1M TOKENS
|
||||
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
|
||||
}
|
||||
|
||||
var prices []*Price
|
||||
|
84
providers/cohere/base.go
Normal file
84
providers/cohere/base.go
Normal file
@ -0,0 +1,84 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CohereProviderFactory struct{}
|
||||
|
||||
// 创建 CohereProvider
|
||||
func (f CohereProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &CohereProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type CohereProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.cohere.ai/v1",
|
||||
ChatCompletions: "/chat",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
CohereError := &CohereError{}
|
||||
err := json.NewDecoder(resp.Body).Decode(CohereError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(CohereError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(CohereError *CohereError) *types.OpenAIError {
|
||||
if CohereError.Message == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: CohereError.Message,
|
||||
Type: "Cohere error",
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *CohereProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (p *CohereProvider) GetFullRequestURL(requestURL string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
}
|
||||
|
||||
func convertRole(role string) string {
|
||||
switch role {
|
||||
case types.ChatMessageRoleSystem:
|
||||
return "SYSTEM"
|
||||
case types.ChatMessageRoleAssistant:
|
||||
return "CHATBOT"
|
||||
default:
|
||||
return "USER"
|
||||
}
|
||||
}
|
208
providers/cohere/chat.go
Normal file
208
providers/cohere/chat.go
Normal file
@ -0,0 +1,208 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||
)
|
||||
|
||||
type CohereStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *CohereProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
cohereResponse := &CohereResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, cohereResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return ConvertToChatOpenai(p, cohereResponse, request)
|
||||
}
|
||||
|
||||
func (p *CohereProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &CohereStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
eventstream.NewDecoder()
|
||||
|
||||
return requester.RequestStream(p.Requester, resp, chatHandler.HandlerStream)
|
||||
}
|
||||
|
||||
func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_cohere_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
cohereRequest, errWithCode := ConvertFromChatOpenai(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(cohereRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*CohereRequest, *types.OpenAIErrorWithStatusCode) {
|
||||
cohereRequest := CohereRequest{
|
||||
Model: request.Model,
|
||||
MaxTokens: request.MaxTokens,
|
||||
Temperature: request.Temperature,
|
||||
Stream: request.Stream,
|
||||
P: request.TopP,
|
||||
K: request.N,
|
||||
Seed: request.Seed,
|
||||
StopSequences: request.Stop,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
}
|
||||
|
||||
msgLen := len(request.Messages) - 1
|
||||
|
||||
for index, message := range request.Messages {
|
||||
if index == msgLen {
|
||||
cohereRequest.Message = message.StringContent()
|
||||
} else {
|
||||
cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatHistory{
|
||||
Role: convertRole(message.Role),
|
||||
Message: message.StringContent(),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return &cohereRequest, nil
|
||||
}
|
||||
|
||||
func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.CohereError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: types.ChatMessageRoleAssistant,
|
||||
Content: response.Text,
|
||||
},
|
||||
FinishReason: types.FinishReasonStop,
|
||||
}
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: response.GenerationID,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Model: request.Model,
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: response.Meta.BilledUnits.InputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
openaiResponse.Usage.CompletionTokens = response.Meta.BilledUnits.OutputTokens + response.Meta.Tokens.SearchUnits + response.Meta.Tokens.Classifications
|
||||
openaiResponse.Usage.TotalTokens = openaiResponse.Usage.PromptTokens + openaiResponse.Usage.CompletionTokens
|
||||
|
||||
usage := provider.GetUsage()
|
||||
*usage = *openaiResponse.Usage
|
||||
|
||||
return openaiResponse, nil
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *CohereStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "{") {
|
||||
*rawLine = nil
|
||||
return
|
||||
}
|
||||
|
||||
var cohereResponse CohereStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &cohereResponse)
|
||||
if err != nil {
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if cohereResponse.EventType != "text-generation" && cohereResponse.EventType != "stream-end" {
|
||||
*rawLine = nil
|
||||
return
|
||||
}
|
||||
|
||||
h.convertToOpenaiStream(&cohereResponse, dataChan)
|
||||
}
|
||||
|
||||
func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStreamResponse, dataChan chan string) {
|
||||
choice := types.ChatCompletionStreamChoice{
|
||||
Index: 0,
|
||||
}
|
||||
|
||||
if cohereResponse.EventType == "stream-end" {
|
||||
choice.FinishReason = types.FinishReasonStop
|
||||
} else {
|
||||
choice.Delta = types.ChatCompletionStreamChoiceDelta{
|
||||
Role: types.ChatMessageRoleAssistant,
|
||||
Content: cohereResponse.Text,
|
||||
}
|
||||
}
|
||||
|
||||
chatCompletion := types.ChatCompletionStreamResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
|
||||
responseBody, _ := json.Marshal(chatCompletion)
|
||||
dataChan <- string(responseBody)
|
||||
}
|
85
providers/cohere/type.go
Normal file
85
providers/cohere/type.go
Normal file
@ -0,0 +1,85 @@
|
||||
package cohere
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type ChatHistory struct {
|
||||
Role string `json:"role"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type CohereConnector struct {
|
||||
ID string `json:"id"`
|
||||
UserAccessToken string `json:"user_access_token,omitempty"`
|
||||
ContinueOnFailure bool `json:"continue_on_failure,omitempty"`
|
||||
Options any `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type CohereRequest struct {
|
||||
Message string `json:"message"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Preamble string `json:"preamble,omitempty"`
|
||||
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
|
||||
ConversationId string `json:"conversation_id,omitempty"`
|
||||
PromptTruncation string `json:"prompt_truncation,omitempty"`
|
||||
Connectors []CohereConnector `json:"connectors,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxInputTokens int `json:"max_input_tokens,omitempty"`
|
||||
K int `json:"k,omitempty"`
|
||||
P float64 `json:"p,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
StopSequences any `json:"stop_sequences,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Tools []*types.ChatCompletionFunction `json:"tools,omitempty"`
|
||||
ToolResults any `json:"tool_results,omitempty"`
|
||||
// SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
|
||||
}
|
||||
|
||||
type APIVersion struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type Tokens struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
SearchUnits int `json:"search_units,omitempty"`
|
||||
Classifications int `json:"classifications,omitempty"`
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
APIVersion APIVersion `json:"api_version"`
|
||||
BilledUnits Tokens `json:"billed_units"`
|
||||
Tokens Tokens `json:"tokens"`
|
||||
}
|
||||
|
||||
type CohereToolCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type CohereResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ResponseID string `json:"response_id,omitempty"`
|
||||
GenerationID string `json:"generation_id,omitempty"`
|
||||
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
|
||||
Meta Meta `json:"meta,omitempty"`
|
||||
CohereError
|
||||
}
|
||||
|
||||
type CohereError struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type CohereStreamResponse struct {
|
||||
IsFinished bool `json:"is_finished"`
|
||||
EventType string `json:"event_type"`
|
||||
GenerationID string `json:"generation_id,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Response CohereResponse `json:"response,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
@ -12,6 +12,7 @@ import (
|
||||
"one-api/providers/bedrock"
|
||||
"one-api/providers/claude"
|
||||
"one-api/providers/cloudflareAI"
|
||||
"one-api/providers/cohere"
|
||||
"one-api/providers/deepseek"
|
||||
"one-api/providers/gemini"
|
||||
"one-api/providers/groq"
|
||||
@ -56,6 +57,7 @@ func init() {
|
||||
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
||||
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
|
@ -26,5 +26,6 @@ func init() {
|
||||
common.ChannelTypeLingyi: "Lingyiwanwu",
|
||||
common.ChannelTypeMidjourney: "Midjourney",
|
||||
common.ChannelTypeCloudflareAI: "Cloudflare AI",
|
||||
common.ChannelTypeCohere: "Cohere",
|
||||
}
|
||||
}
|
||||
|
@ -146,6 +146,13 @@ export const CHANNEL_OPTIONS = {
|
||||
color: 'orange',
|
||||
url: ''
|
||||
},
|
||||
36: {
|
||||
key: 36,
|
||||
text: 'Cohere',
|
||||
value: 36,
|
||||
color: 'default',
|
||||
url: ''
|
||||
},
|
||||
24: {
|
||||
key: 24,
|
||||
text: 'Azure Speech',
|
||||
|
@ -283,6 +283,13 @@ const typeConfig = {
|
||||
base_url: ''
|
||||
},
|
||||
modelGroup: 'Cloudflare AI'
|
||||
},
|
||||
36: {
|
||||
input: {
|
||||
models: ['command-r', 'command-r-plus'],
|
||||
test_model: 'command-r'
|
||||
},
|
||||
modelGroup: 'Cohere'
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user