diff --git a/providers/claude/base.go b/providers/claude/base.go index 58e1ba8c..58f9120c 100644 --- a/providers/claude/base.go +++ b/providers/claude/base.go @@ -81,3 +81,12 @@ func stopReasonClaude2OpenAI(reason string) string { return reason } } + +func convertRole(role string) string { + switch role { + case "user": + return types.ChatMessageRoleUser + default: + return types.ChatMessageRoleAssistant + } +} diff --git a/providers/claude/chat.go b/providers/claude/chat.go index ace6dafc..29bee010 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/image" "one-api/common/requester" "one-api/types" "strings" @@ -71,7 +72,11 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (* headers["Accept"] = "text/event-stream" } - claudeRequest := convertFromChatOpenai(request) + claudeRequest, errWithCode := convertFromChatOpenai(request) + if errWithCode != nil { + return nil, errWithCode + } + // 创建请求 req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(claudeRequest), p.Requester.WithHeader(headers)) if err != nil { @@ -81,10 +86,10 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (* return req, nil } -func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest { +func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) { claudeRequest := ClaudeRequest{ Model: request.Model, - Messages: nil, + Messages: []Message{}, System: "", MaxTokens: request.MaxTokens, StopSequences: nil, @@ -95,20 +100,46 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 } - var messages []Message + for _, message := range request.Messages { - if message.Role != "system" { - messages = append(messages, Message{ - Role: message.Role, - Content: message.Content.(string), - }) - claudeRequest.Messages = messages - } else { + if message.Role == "system" { claudeRequest.System = message.Content.(string) + continue } + content := Message{ + Role: convertRole(message.Role), + Content: []MessageContent{}, + } + + openaiContent := message.ParseContent() + for _, part := range openaiContent { + if part.Type == types.ContentTypeText { + content.Content = append(content.Content, MessageContent{ + Type: "text", + Text: part.Text, + }) + continue + } + + if part.Type == types.ContentTypeImageURL { + mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL) + if err != nil { + return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest) + } + content.Content = append(content.Content, MessageContent{ + Type: "image", + Source: &ContentSource{ + Type: "base64", + MediaType: mimeType, + Data: data, + }, + }) + } + } + claudeRequest.Messages = append(claudeRequest.Messages, content) } - return &claudeRequest + return &claudeRequest, nil } func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { @@ -124,7 +155,7 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request * choice := types.ChatCompletionChoice{ Index: 0, Message: types.ChatCompletionMessage{ - Role: "assistant", + Role: response.Role, Content: strings.TrimPrefix(response.Content[0].Text, " "), Name: nil, }, @@ -135,7 +166,7 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request * Object: "chat.completion", Created: common.GetTimestamp(), Choices: []types.ChatCompletionChoice{choice}, - Model: response.Model, + Model: request.Model, Usage: &types.Usage{ CompletionTokens: 0, PromptTokens: 0, @@ -180,32 +211,39 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin return } + if claudeResponse.Type == "message_stop" { + errChan <- io.EOF + *rawLine = requester.StreamClosed + return + } + switch claudeResponse.Type { case "message_start": - h.Usage.PromptTokens = claudeResponse.Message.InputTokens + h.convertToOpenaiStream(&claudeResponse, dataChan) + h.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens case "message_delta": - h.convertToOpenaiStream(&claudeResponse, dataChan, errChan) + h.convertToOpenaiStream(&claudeResponse, dataChan) h.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens case "content_block_delta": - h.convertToOpenaiStream(&claudeResponse, dataChan, errChan) - - case "message_stop": - errChan <- io.EOF - *rawLine = requester.StreamClosed + h.convertToOpenaiStream(&claudeResponse, dataChan) default: return } } -func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string, errChan chan error) { +func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) { choice := types.ChatCompletionStreamChoice{ Index: claudeResponse.Index, } + if claudeResponse.Message.Role != "" { + choice.Delta.Role = claudeResponse.Message.Role + } + if claudeResponse.Delta.Text != "" { choice.Delta.Content = claudeResponse.Delta.Text } diff --git a/providers/claude/type.go b/providers/claude/type.go index 8676c178..810641b7 100644 --- a/providers/claude/type.go +++ b/providers/claude/type.go @@ -14,9 +14,21 @@ type ResContent struct { Type string `json:"type"` } +type ContentSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type MessageContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ContentSource `json:"source,omitempty"` +} + type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content []MessageContent `json:"content"` } type ClaudeRequest struct { @@ -33,18 +45,19 @@ type ClaudeRequest struct { } type Usage struct { - InputTokens int `json:"input_tokens"` + InputTokens int `json:"input_tokens,omitempty"` OutputTokens int `json:"output_tokens,omitempty"` } type ClaudeResponse struct { - Content []ResContent `json:"content"` Id string `json:"id"` + Type string `json:"type"` Role string `json:"role"` - StopReason string `json:"stop_reason"` - StopSequence string `json:"stop_sequence,omitempty"` + Content []ResContent `json:"content"` Model string `json:"model"` - Usage `json:"usage,omitempty"` - Error ClaudeError `json:"error,omitempty"` + StopReason string `json:"stop_reason,omitempty"` + StopSequence string `json:"stop_sequence,omitempty"` + Usage Usage `json:"usage,omitempty"` + Error ClaudeError `json:"error,omitempty"` } type Delta struct {