refactor: refactor claude related code

This commit is contained in:
JustSong 2023-07-22 17:12:13 +08:00
parent 2ff15baf66
commit 675847bf98
2 changed files with 71 additions and 50 deletions

View File

@ -1,5 +1,11 @@
package controller
import (
"fmt"
"one-api/common"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
@ -38,3 +44,61 @@ func stopReasonClaude2OpenAI(reason string) string {
return reason
}
}
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else {
// ignore other roles
}
prompt += "\n\nAssistant:"
}
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}

View File

@ -159,30 +159,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
switch apiType {
case APITypeClaude:
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else {
// ignore other roles
}
prompt += "\n\nAssistant:"
}
claudeRequest.Prompt = prompt
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@ -441,15 +418,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return true
}
streamResponseText += claudeResponse.Completion
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
var response ChatCompletionsStreamResponse
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
response.Object = "chat.completion.chunk"
response.Model = textRequest.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
@ -492,26 +463,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
StatusCode: resp.StatusCode,
}
}
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model)
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + promptTokens,
},
fullTextResponse.Usage = Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
textResponse.Usage = fullTextResponse.Usage
jsonResponse, err := json.Marshal(fullTextResponse)