ai-gateway/relay/channel/anthropic/main.go

273 lines
7.9 KiB
Go
Raw Normal View History

2024-01-14 11:21:03 +00:00
package anthropic
2023-07-22 08:18:03 +00:00
2023-07-22 09:12:13 +00:00
import (
2023-07-22 09:36:40 +00:00
"bufio"
"encoding/json"
2023-07-22 09:12:13 +00:00
"fmt"
2023-07-22 09:36:40 +00:00
"github.com/gin-gonic/gin"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
2023-07-22 09:36:40 +00:00
"io"
"net/http"
2023-07-22 09:12:13 +00:00
"strings"
)
// stopReasonClaude2OpenAI maps an Anthropic stop reason onto the OpenAI
// finish_reason vocabulary.
//
// A nil reason yields the empty string. "end_turn" and "stop_sequence" both
// become "stop", "max_tokens" becomes "length", and any other value is
// returned unchanged.
func stopReasonClaude2OpenAI(reason *string) string {
	if reason == nil {
		return ""
	}
	mapped := *reason
	switch mapped {
	case "end_turn", "stop_sequence":
		mapped = "stop"
	case "max_tokens":
		mapped = "length"
	}
	return mapped
}
2023-07-22 09:12:13 +00:00
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
2024-01-14 11:21:03 +00:00
claudeRequest := Request{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
// legacy model name mapping
if claudeRequest.Model == "claude-instant-1" {
claudeRequest.Model = "claude-instant-1.1"
} else if claudeRequest.Model == "claude-2" {
claudeRequest.Model = "claude-2.1"
}
2023-07-22 09:12:13 +00:00
for _, message := range textRequest.Messages {
if message.Role == "system" && claudeRequest.System == "" {
claudeRequest.System = message.StringContent()
continue
}
claudeMessage := Message{
Role: message.Role,
}
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
claudeMessage.Content = append(claudeMessage.Content, content)
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &ImageSource{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
2023-11-24 13:39:44 +00:00
}
contents = append(contents, content)
2023-07-22 09:12:13 +00:00
}
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
2023-07-22 09:12:13 +00:00
}
return &claudeRequest
}
// https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
2024-01-14 11:21:03 +00:00
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
2023-07-22 09:12:13 +00:00
}
2024-01-14 11:21:03 +00:00
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
2024-01-14 11:21:03 +00:00
choice := openai.TextResponseChoice{
2023-07-22 09:12:13 +00:00
Index: 0,
Message: model.Message{
2023-07-22 09:12:13 +00:00
Role: "assistant",
Content: responseText,
2023-07-22 09:12:13 +00:00
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
2024-01-14 11:21:03 +00:00
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
2023-07-22 09:12:13 +00:00
Object: "chat.completion",
Created: helper.GetTimestamp(),
2024-01-14 11:21:03 +00:00
Choices: []openai.TextResponseChoice{choice},
2023-07-22 09:12:13 +00:00
}
return &fullTextResponse
}
2023-07-22 09:36:40 +00:00
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
2023-07-22 09:36:40 +00:00
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
2023-07-22 09:36:40 +00:00
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
2023-07-22 09:36:40 +00:00
continue
}
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
2023-07-22 09:36:40 +00:00
dataChan <- data
}
stopChan <- true
}()
2024-01-14 11:21:03 +00:00
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
2023-07-22 09:36:40 +00:00
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse StreamResponse
2023-07-22 09:36:40 +00:00
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
2023-07-22 09:36:40 +00:00
return true
}
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
2023-07-22 09:36:40 +00:00
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
2023-07-22 09:36:40 +00:00
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
2023-07-22 09:36:40 +00:00
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
2023-07-22 09:36:40 +00:00
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
2023-07-22 09:36:40 +00:00
}
err = resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
2023-07-22 09:36:40 +00:00
}
2024-01-14 11:21:03 +00:00
var claudeResponse Response
2023-07-22 09:36:40 +00:00
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
2023-07-22 09:36:40 +00:00
}
if claudeResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
2023-07-22 09:36:40 +00:00
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
2023-07-22 09:36:40 +00:00
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
2023-07-22 09:36:40 +00:00
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}