feat: Support Claude3 (#86)

* Claude改用messages API,支持Claude3

* 删除新API不支持的模型

* 忘了改请求地址

* 🐛 fix: correct the response format and retrieve the completion token count, which was previously not obtained

---------

Co-authored-by: MartialBE <me@xiao5.info>
This commit is contained in:
moondie 2024-03-07 01:21:07 +08:00 committed by GitHub
parent 07c18dfb91
commit e1fcfae928
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 131 additions and 75 deletions

View File

@ -87,11 +87,12 @@ func init() {
"dall-e-3": {[]float64{20, 20}, ChannelTypeOpenAI}, "dall-e-3": {[]float64{20, 20}, ChannelTypeOpenAI},
// $0.80/million tokens $2.40/million tokens // $0.80/million tokens $2.40/million tokens
"claude-instant-1": {[]float64{0.4, 1.2}, ChannelTypeAnthropic}, "claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
// $8.00/million tokens $24.00/million tokens // $8.00/million tokens $24.00/million tokens
"claude-2": {[]float64{4, 12}, ChannelTypeAnthropic}, "claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic}, "claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic}, "claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
// ¥0.012 / 1k tokens ¥0.012 / 1k tokens // ¥0.012 / 1k tokens ¥0.012 / 1k tokens
"ERNIE-Bot": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu}, "ERNIE-Bot": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu},
@ -291,7 +292,7 @@ func GetCompletionRatio(name string) float64 {
} }
return 2 return 2
} }
if strings.HasPrefix(name, "claude-instant-1") { if strings.HasPrefix(name, "claude-instant-1.2") {
return 3.38 return 3.38
} }
if strings.HasPrefix(name, "claude-2") { if strings.HasPrefix(name, "claude-2") {

View File

@ -43,10 +43,11 @@
"text-moderation-latest": [0.1, 0.1], "text-moderation-latest": [0.1, 0.1],
"dall-e-2": [8, 8], "dall-e-2": [8, 8],
"dall-e-3": [20, 20], "dall-e-3": [20, 20],
"claude-instant-1": [0.4, 1.2], "claude-instant-1.2": [0.4, 1.2],
"claude-2": [4, 12],
"claude-2.0": [4, 12], "claude-2.0": [4, 12],
"claude-2.1": [4, 12], "claude-2.1": [4, 12],
"claude-3-opus-20240229": [7.5, 22.5],
"claude-3-sonnet-20240229": [1.3, 3.9],
"ERNIE-Bot": [0.8572, 0.8572], "ERNIE-Bot": [0.8572, 0.8572],
"ERNIE-Bot-8k": [1.7143, 3.4286], "ERNIE-Bot-8k": [1.7143, 3.4286],
"ERNIE-Bot-turbo": [0.5715, 0.5715], "ERNIE-Bot-turbo": [0.5715, 0.5715],

View File

@ -29,13 +29,13 @@ type ClaudeProvider struct {
func getConfig() base.ProviderConfig { func getConfig() base.ProviderConfig {
return base.ProviderConfig{ return base.ProviderConfig{
BaseURL: "https://api.anthropic.com", BaseURL: "https://api.anthropic.com",
ChatCompletions: "/v1/complete", ChatCompletions: "/v1/messages",
} }
} }
// 请求错误处理 // 请求错误处理
func requestErrorHandle(resp *http.Response) *types.OpenAIError { func requestErrorHandle(resp *http.Response) *types.OpenAIError {
claudeError := &ClaudeResponseError{} claudeError := &ClaudeError{}
err := json.NewDecoder(resp.Body).Decode(claudeError) err := json.NewDecoder(resp.Body).Decode(claudeError)
if err != nil { if err != nil {
return nil return nil
@ -45,14 +45,14 @@ func requestErrorHandle(resp *http.Response) *types.OpenAIError {
} }
// 错误处理 // 错误处理
func errorHandle(claudeError *ClaudeResponseError) *types.OpenAIError { func errorHandle(claudeError *ClaudeError) *types.OpenAIError {
if claudeError.Error.Type == "" { if claudeError.Type == "" {
return nil return nil
} }
return &types.OpenAIError{ return &types.OpenAIError{
Message: claudeError.Error.Message, Message: claudeError.Message,
Type: claudeError.Error.Type, Type: claudeError.Type,
Code: claudeError.Error.Type, Code: claudeError.Type,
} }
} }
@ -73,7 +73,7 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
switch reason { switch reason {
case "stop_sequence": case "end_turn":
return types.FinishReasonStop return types.FinishReasonStop
case "max_tokens": case "max_tokens":
return types.FinishReasonLength return types.FinishReasonLength

View File

@ -83,36 +83,36 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest { func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{ claudeRequest := ClaudeRequest{
Model: request.Model, Model: request.Model,
Prompt: "", Messages: nil,
MaxTokensToSample: request.MaxTokens, System: "",
StopSequences: nil, MaxTokens: request.MaxTokens,
Temperature: request.Temperature, StopSequences: nil,
TopP: request.TopP, Temperature: request.Temperature,
Stream: request.Stream, TopP: request.TopP,
Stream: request.Stream,
} }
if claudeRequest.MaxTokensToSample == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokensToSample = 1000000 claudeRequest.MaxTokens = 4096
} }
prompt := "" var messages []Message
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "user" { if message.Role != "system" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) messages = append(messages, Message{
} else if message.Role == "assistant" { Role: message.Role,
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) Content: message.Content.(string),
} else if message.Role == "system" { })
if prompt == "" { claudeRequest.Messages = messages
prompt = message.StringContent() } else {
} claudeRequest.System = message.Content.(string)
} }
} }
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest return &claudeRequest
} }
func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(&response.ClaudeResponseError) error := errorHandle(&response.Error)
if error != nil { if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{ errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *error, OpenAIError: *error,
@ -125,26 +125,33 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
Index: 0, Index: 0,
Message: types.ChatCompletionMessage{ Message: types.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: strings.TrimPrefix(response.Completion, " "), Content: strings.TrimPrefix(response.Content[0].Text, " "),
Name: nil, Name: nil,
}, },
FinishReason: stopReasonClaude2OpenAI(response.StopReason), FinishReason: stopReasonClaude2OpenAI(response.StopReason),
} }
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: response.Id,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice}, Choices: []types.ChatCompletionChoice{choice},
Model: response.Model, Model: response.Model,
Usage: &types.Usage{
CompletionTokens: 0,
PromptTokens: 0,
TotalTokens: 0,
},
} }
completionTokens := common.CountTokenText(response.Completion, response.Model) completionTokens := response.Usage.OutputTokens
response.Usage.CompletionTokens = completionTokens
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
openaiResponse.Usage = response.Usage promptTokens := response.Usage.InputTokens
*p.Usage = *response.Usage openaiResponse.Usage.PromptTokens = promptTokens
openaiResponse.Usage.CompletionTokens = completionTokens
openaiResponse.Usage.TotalTokens = promptTokens + completionTokens
*p.Usage = *openaiResponse.Usage
return openaiResponse, nil return openaiResponse, nil
} }
@ -152,7 +159,7 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) { if !strings.HasPrefix(string(*rawLine), `data: {"type"`) {
*rawLine = nil *rawLine = nil
return return
} }
@ -160,43 +167,61 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
// 去除前缀 // 去除前缀
*rawLine = (*rawLine)[6:] *rawLine = (*rawLine)[6:]
var claudeResponse *ClaudeResponse var claudeResponse ClaudeStreamResponse
err := json.Unmarshal(*rawLine, claudeResponse) err := json.Unmarshal(*rawLine, &claudeResponse)
if err != nil { if err != nil {
errChan <- common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return return
} }
error := errorHandle(&claudeResponse.ClaudeResponseError) error := errorHandle(&claudeResponse.Error)
if error != nil { if error != nil {
errChan <- error errChan <- error
return return
} }
if claudeResponse.StopReason == "stop_sequence" { switch claudeResponse.Type {
case "message_start":
h.Usage.PromptTokens = claudeResponse.Message.InputTokens
case "message_delta":
h.convertToOpenaiStream(&claudeResponse, dataChan, errChan)
h.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
case "content_block_delta":
h.convertToOpenaiStream(&claudeResponse, dataChan, errChan)
case "message_stop":
errChan <- io.EOF errChan <- io.EOF
*rawLine = requester.StreamClosed *rawLine = requester.StreamClosed
default:
return return
} }
h.convertToOpenaiStream(claudeResponse, dataChan, errChan)
} }
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, dataChan chan string, errChan chan error) { func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string, errChan chan error) {
var choice types.ChatCompletionStreamChoice choice := types.ChatCompletionStreamChoice{
choice.Delta.Content = claudeResponse.Completion Index: claudeResponse.Index,
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) }
if finishReason != "null" {
if claudeResponse.Delta.Text != "" {
choice.Delta.Content = claudeResponse.Delta.Text
}
finishReason := stopReasonClaude2OpenAI(claudeResponse.Delta.StopReason)
if finishReason != "" {
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
chatCompletion := types.ChatCompletionStreamResponse{ chatCompletion := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }
responseBody, _ := json.Marshal(chatCompletion) responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody) dataChan <- string(responseBody)
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
} }

View File

@ -1,7 +1,5 @@
package claude package claude
import "one-api/types"
type ClaudeError struct { type ClaudeError struct {
Type string `json:"type"` Type string `json:"type"`
Message string `json:"message"` Message string `json:"message"`
@ -11,25 +9,56 @@ type ClaudeMetadata struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
} }
type ResContent struct {
Text string `json:"text"`
Type string `json:"type"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ClaudeRequest struct { type ClaudeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` System string `json:"system,omitempty"`
MaxTokensToSample int `json:"max_tokens_to_sample"` Messages []Message `json:"messages"`
StopSequences []string `json:"stop_sequences,omitempty"` MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
TopP float64 `json:"top_p,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
} }
type ClaudeResponseError struct { type Usage struct {
Error ClaudeError `json:"error,omitempty"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens,omitempty"`
} }
type ClaudeResponse struct { type ClaudeResponse struct {
Completion string `json:"completion"` Content []ResContent `json:"content"`
StopReason string `json:"stop_reason"` Id string `json:"id"`
Model string `json:"model"` Role string `json:"role"`
Usage *types.Usage `json:"usage,omitempty"` StopReason string `json:"stop_reason"`
ClaudeResponseError StopSequence string `json:"stop_sequence,omitempty"`
Model string `json:"model"`
Usage `json:"usage,omitempty"`
Error ClaudeError `json:"error,omitempty"`
}
type Delta struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
type ClaudeStreamResponse struct {
Type string `json:"type"`
Message ClaudeResponse `json:"message,omitempty"`
Index int `json:"index,omitempty"`
Delta Delta `json:"delta,omitempty"`
Usage Usage `json:"usage,omitempty"`
Error ClaudeError `json:"error,omitempty"`
} }

View File

@ -59,8 +59,8 @@ const typeConfig = {
}, },
14: { 14: {
input: { input: {
models: ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'], models: ['claude-instant-1.2', 'claude-2.0', 'claude-2.1','claude-3-opus-20240229','claude-3-sonnet-20240229'],
test_model: 'claude-2' test_model: 'claude-3-sonnet-20240229'
}, },
modelGroup: 'Anthropic' modelGroup: 'Anthropic'
}, },