feat: support claude-3 (close #1080, close #1094)

This commit is contained in:
JustSong 2024-03-09 01:12:47 +08:00
parent 4fb22ad4ce
commit bf2e26a48f
5 changed files with 204 additions and 81 deletions

View File

@ -63,12 +63,15 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image "dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image "dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens // https://www.anthropic.com/api#pricing
"claude-2": 5.51, // $11.02 / 1M tokens "claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 5.51, // $11.02 / 1M tokens "claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 5.51, // $11.02 / 1M tokens "claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240229": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
@ -214,11 +217,11 @@ func GetCompletionRatio(name string) float64 {
} }
return 2 return 2
} }
if strings.HasPrefix(name, "claude-instant-1") { if strings.HasPrefix(name, "claude-3") {
return 3.38 return 5
} }
if strings.HasPrefix(name, "claude-2") { if strings.HasPrefix(name, "claude-") {
return 2.965517 return 3
} }
if strings.HasPrefix(name, "mistral-") { if strings.HasPrefix(name, "mistral-") {
return 3 return 3

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
@ -20,7 +19,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
@ -31,6 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
anthropicVersion = "2023-06-01" anthropicVersion = "2023-06-01"
} }
req.Header.Set("anthropic-version", anthropicVersion) req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15")
return nil return nil
} }
@ -47,9 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string err, usage = StreamHandler(c, resp)
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else { } else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
} }

View File

@ -1,5 +1,8 @@
package anthropic package anthropic
var ModelList = []string{ var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240229",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
@ -15,73 +16,135 @@ import (
"strings" "strings"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason *string) string {
switch reason { if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence": case "stop_sequence":
return "stop" return "stop"
case "max_tokens": case "max_tokens":
return "length" return "length"
default: default:
return reason return *reason
} }
} }
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{ claudeRequest := Request{
Model: textRequest.Model, Model: textRequest.Model,
Prompt: "", MaxTokens: textRequest.MaxTokens,
MaxTokensToSample: textRequest.MaxTokens, Temperature: textRequest.Temperature,
StopSequences: nil, TopP: textRequest.TopP,
Temperature: textRequest.Temperature, Stream: textRequest.Stream,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
} }
if claudeRequest.MaxTokensToSample == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokensToSample = 1000000 claudeRequest.MaxTokens = 4096
}
// legacy model name mapping
if claudeRequest.Model == "claude-instant-1" {
claudeRequest.Model = "claude-instant-1.1"
} else if claudeRequest.Model == "claude-2" {
claudeRequest.Model = "claude-2.1"
} }
prompt := ""
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
if message.Role == "user" { if message.Role == "system" && claudeRequest.System == "" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) claudeRequest.System = message.StringContent()
} else if message.Role == "assistant" { continue
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
} }
claudeMessage := Message{
Role: message.Role,
}
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
claudeMessage.Content = append(claudeMessage.Content, content)
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &ImageSource{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
}
contents = append(contents, content)
}
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
} }
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest return &claudeRequest
} }
func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { // https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion choice.Delta.Content = responseText
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" { if finishReason != "null" {
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
var response openai.ChatCompletionsStreamResponse var openaiResponse openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" openaiResponse.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} return &openaiResponse, response
return &response
} }
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
Index: 0, Index: 0,
Message: model.Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "), Content: responseText,
Name: nil, Name: nil,
}, },
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
@ -89,17 +152,15 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
return &fullTextResponse return &fullTextResponse
} }
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 4, data[0:i], nil return i + 1, data[0:i], nil
} }
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
@ -111,29 +172,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") { if len(data) < 6 {
continue continue
} }
data = strings.TrimPrefix(data, "event: completion\r\ndata: ") if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data dataChan <- data
} }
stopChan <- true stopChan <- true
}() }()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
// some implementations may add \r at the end of data // some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var claudeResponse Response var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse) err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
responseText += claudeResponse.Completion response, meta := streamResponseClaude2OpenAI(&claudeResponse)
response := streamResponseClaude2OpenAI(&claudeResponse) if meta != nil {
response.Id = responseId usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime response.Created = createdTime
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
@ -147,11 +224,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return false return false
} }
}) })
err := resp.Body.Close() _ = resp.Body.Close()
if err != nil { return nil, &usage
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
} }
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
@ -181,11 +255,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
} }
fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
usage := model.Usage{ usage := model.Usage{
PromptTokens: promptTokens, PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: completionTokens, CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: promptTokens + completionTokens, TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
} }
fullTextResponse.Usage = usage fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)

View File

@ -1,19 +1,44 @@
package anthropic package anthropic
// https://docs.anthropic.com/claude/reference/messages_post
type Metadata struct { type Metadata struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
} }
type ImageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Request struct { type Request struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Messages []Message `json:"messages"`
MaxTokensToSample int `json:"max_tokens_to_sample"` System string `json:"system,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
TopP float64 `json:"top_p,omitempty"` Stream bool `json:"stream,omitempty"`
TopK int `json:"top_k,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"` //Metadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"` }
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} }
type Error struct { type Error struct {
@ -22,8 +47,29 @@ type Error struct {
} }
type Response struct { type Response struct {
Completion string `json:"completion"` Id string `json:"id"`
StopReason string `json:"stop_reason"` Type string `json:"type"`
Model string `json:"model"` Role string `json:"role"`
Error Error `json:"error"` Content []Content `json:"content"`
Model string `json:"model"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Usage Usage `json:"usage"`
Error Error `json:"error"`
}
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
type StreamResponse struct {
Type string `json:"type"`
Message *Response `json:"message"`
Index int `json:"index"`
ContentBlock *Content `json:"content_block"`
Delta *Delta `json:"delta"`
Usage *Usage `json:"usage"`
} }