From 675847bf98588138138318272924876188413d86 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sat, 22 Jul 2023 17:12:13 +0800
Subject: [PATCH] refactor: refactor claude related code

---
 controller/relay-claude.go | 64 ++++++++++++++++++++++++++++++++++++++
 controller/relay-text.go   | 57 +++++----------------
 2 files changed, 71 insertions(+), 50 deletions(-)

diff --git a/controller/relay-claude.go b/controller/relay-claude.go
index 2b4f0c87..99f472e4 100644
--- a/controller/relay-claude.go
+++ b/controller/relay-claude.go
@@ -1,5 +1,11 @@
 package controller
 
+import (
+	"fmt"
+	"one-api/common"
+	"strings"
+)
+
 type ClaudeMetadata struct {
 	UserId string `json:"user_id"`
 }
@@ -38,3 +44,61 @@ func stopReasonClaude2OpenAI(reason string) string {
 		return reason
 	}
 }
+
+func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
+	claudeRequest := ClaudeRequest{
+		Model:             textRequest.Model,
+		Prompt:            "",
+		MaxTokensToSample: textRequest.MaxTokens,
+		StopSequences:     nil,
+		Temperature:       textRequest.Temperature,
+		TopP:              textRequest.TopP,
+		Stream:            textRequest.Stream,
+	}
+	if claudeRequest.MaxTokensToSample == 0 {
+		claudeRequest.MaxTokensToSample = 1000000
+	}
+	prompt := ""
+	for _, message := range textRequest.Messages {
+		if message.Role == "user" {
+			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
+		} else if message.Role == "assistant" {
+			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
+		} else {
+			// ignore other roles
+		}
+		prompt += "\n\nAssistant:"
+	}
+	claudeRequest.Prompt = prompt
+	return &claudeRequest
+}
+
+func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = claudeResponse.Completion
+	choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
+	var response ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = claudeResponse.Model
+	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
+			Name:    nil,
+		},
+		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+	}
+	return &fullTextResponse
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 18f50966..a3d0c801 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -159,30 +159,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	}
 	switch apiType {
 	case APITypeClaude:
-		claudeRequest := ClaudeRequest{
-			Model:             textRequest.Model,
-			Prompt:            "",
-			MaxTokensToSample: textRequest.MaxTokens,
-			StopSequences:     nil,
-			Temperature:       textRequest.Temperature,
-			TopP:              textRequest.TopP,
-			Stream:            textRequest.Stream,
-		}
-		if claudeRequest.MaxTokensToSample == 0 {
-			claudeRequest.MaxTokensToSample = 1000000
-		}
-		prompt := ""
-		for _, message := range textRequest.Messages {
-			if message.Role == "user" {
-				prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
-			} else if message.Role == "assistant" {
-				prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
-			} else {
-				// ignore other roles
-			}
-			prompt += "\n\nAssistant:"
-		}
-		claudeRequest.Prompt = prompt
+		claudeRequest := requestOpenAI2Claude(textRequest)
 		jsonStr, err := json.Marshal(claudeRequest)
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -441,15 +418,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 					return true
 				}
 				streamResponseText += claudeResponse.Completion
-				var choice ChatCompletionsStreamResponseChoice
-				choice.Delta.Content = claudeResponse.Completion
-				choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
-				var response ChatCompletionsStreamResponse
+				response := streamResponseClaude2OpenAI(&claudeResponse)
 				response.Id = responseId
 				response.Created = createdTime
-				response.Object = "chat.completion.chunk"
-				response.Model = textRequest.Model
-				response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 				jsonStr, err := json.Marshal(response)
 				if err != nil {
 					common.SysError("error marshalling stream response: " + err.Error())
@@ -492,26 +463,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				StatusCode: resp.StatusCode,
 			}
 		}
-		choice := OpenAITextResponseChoice{
-			Index: 0,
-			Message: Message{
-				Role:    "assistant",
-				Content: strings.TrimPrefix(claudeResponse.Completion, " "),
-				Name:    nil,
-			},
-			FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
-		}
+		fullTextResponse := responseClaude2OpenAI(&claudeResponse)
 		completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model)
-		fullTextResponse := OpenAITextResponse{
-			Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
-			Object:  "chat.completion",
-			Created: common.GetTimestamp(),
-			Choices: []OpenAITextResponseChoice{choice},
-			Usage: Usage{
-				PromptTokens:     promptTokens,
-				CompletionTokens: completionTokens,
-				TotalTokens:      promptTokens + promptTokens,
-			},
+		fullTextResponse.Usage = Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
 		}
 		textResponse.Usage = fullTextResponse.Usage
 		jsonResponse, err := json.Marshal(fullTextResponse)