diff --git a/common/gin.go b/common/gin.go index f5012688..bed2c2b1 100644 --- a/common/gin.go +++ b/common/gin.go @@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } + +func SetEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/controller/billing.go b/controller/billing.go index 42e86aea..e27fd614 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "one-api/common" "one-api/model" + "one-api/relay/channel/openai" ) func GetSubscription(c *gin.Context) { @@ -27,12 +28,12 @@ func GetSubscription(c *gin.Context) { expiredTime = 0 } if err != nil { - openAIError := OpenAIError{ + Error := openai.Error{ Message: err.Error(), Type: "upstream_error", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) return } @@ -69,12 +70,12 @@ func GetUsage(c *gin.Context) { quota, err = model.GetUserUsedQuota(userId) } if err != nil { - openAIError := OpenAIError{ + Error := openai.Error{ Message: err.Error(), Type: "one_api_error", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) return } diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 6ddad7ea..29346cde 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/relay/util" "strconv" "time" @@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := httpClient.Do(req) + res, err := util.HTTPClient.Do(req) if err != nil { return nil, err } diff --git a/controller/channel-test.go b/controller/channel-test.go index 3aaa4897..f64f0ee3 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -9,6 +9,8 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/util" "strconv" "sync" "time" @@ -16,7 +18,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { +func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -46,13 +48,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) + requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { requestURL = baseURL } - requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) + requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) } jsonData, err := json.Marshal(request) if err != nil { @@ -68,12 +70,12 @@ func 
testChannel(channel *model.Channel, request ChatRequest) (err error, openai req.Header.Set("Authorization", "Bearer "+channel.Key) } req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { return err, nil } defer resp.Body.Close() - var response TextResponse + var response openai.SlimTextResponse body, err := io.ReadAll(resp.Body) if err != nil { return err, nil @@ -91,12 +93,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai return nil, nil } -func buildTestRequest() *ChatRequest { - testRequest := &ChatRequest{ +func buildTestRequest() *openai.ChatRequest { + testRequest := &openai.ChatRequest{ Model: "", // this will be set later MaxTokens: 1, } - testMessage := Message{ + testMessage := openai.Message{ Role: "user", Content: "hi", } @@ -204,10 +206,10 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } - if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { + if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { enableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/model.go b/controller/model.go index c12ccf34..b7ec1b6a 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,8 +2,8 @@ package controller import ( "fmt" - "github.com/gin-gonic/gin" + "one-api/relay/channel/openai" ) // https://platform.openai.com/docs/api-reference/models/list @@ -613,14 +613,14 @@ func RetrieveModel(c *gin.Context) { if model, ok := openAIModelsMap[modelId]; ok { c.JSON(200, model) } else { - openAIError := OpenAIError{ + Error := openai.Error{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", Code: "model_not_found", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) } } diff --git a/controller/relay.go b/controller/relay.go index e45fd3eb..198d6c9a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,349 +4,53 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/controller" + "one-api/relay/util" "strconv" "strings" "github.com/gin-gonic/gin" ) -type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` -} - -type ImageURL struct { - Url string `json:"url,omitempty"` - Detail string `json:"detail,omitempty"` -} - -type TextContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` -} - -type ImageContent struct { - Type string `json:"type,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -const ( - ContentTypeText = "text" - ContentTypeImageURL = "image_url" -) - -type OpenAIMessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -func (m Message) IsStringContent() bool { - _, ok := m.Content.(string) - return ok -} - -func (m Message) StringContent() string { - content, ok := m.Content.(string) - if ok { - return content - } - contentList, ok := m.Content.([]any) - if ok { - var contentStr string - 
for _, contentItem := range contentList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - if contentMap["type"] == ContentTypeText { - if subStr, ok := contentMap["text"].(string); ok { - contentStr += subStr - } - } - } - return contentStr - } - return "" -} - -func (m Message) ParseContent() []OpenAIMessageContent { - var contentList []OpenAIMessageContent - content, ok := m.Content.(string) - if ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: content, - }) - return contentList - } - anyList, ok := m.Content.([]any) - if ok { - for _, contentItem := range anyList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - switch contentMap["type"] { - case ContentTypeText: - if subStr, ok := contentMap["text"].(string); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: subStr, - }) - } - case ContentTypeImageURL: - if subObj, ok := contentMap["image_url"].(map[string]any); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeImageURL, - ImageURL: &ImageURL{ - Url: subObj["url"].(string), - }, - }) - } - } - } - return contentList - } - return nil -} - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - // https://platform.openai.com/docs/api-reference/chat -type ResponseFormat struct { - Type string `json:"type,omitempty"` -} - -type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` -} - -func (r GeneralOpenAIRequest) ParseInput() []string { - if r.Input == nil { - return nil - } - var input []string - switch r.Input.(type) { - case string: - input = []string{r.Input.(string)} - case []any: - input = make([]string, 0, len(r.Input.([]any))) - for _, item := range r.Input.([]any) { - if str, ok := item.(string); ok { - input = append(input, str) - } - } - } - return input -} - -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` -} - -type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` - MaxTokens int `json:"max_tokens"` - //Stream bool `json:"stream"` -} - -// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string 
`json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` -} - -type WhisperJSONResponse struct { - Text string `json:"text,omitempty"` -} - -type WhisperVerboseJSONResponse struct { - Task string `json:"task,omitempty"` - Language string `json:"language,omitempty"` - Duration float64 `json:"duration,omitempty"` - Text string `json:"text,omitempty"` - Segments []Segment `json:"segments,omitempty"` -} - -type Segment struct { - Id int `json:"id"` - Seek int `json:"seek"` - Start float64 `json:"start"` - End float64 `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` - Temperature float64 `json:"temperature"` - AvgLogprob float64 `json:"avg_logprob"` - CompressionRatio float64 `json:"compression_ratio"` - NoSpeechProb float64 `json:"no_speech_prob"` -} - -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type OpenAIError struct { - Message string `json:"message"` - Type string `json:"type"` - Param string `json:"param"` - Code any `json:"code"` -} - -type OpenAIErrorWithStatusCode struct { - OpenAIError - StatusCode int `json:"status_code"` -} - -type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` -} - -type OpenAITextResponseChoice struct { - Index int `json:"index"` - Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type OpenAITextResponse struct { - Id string `json:"id"` - Model string `json:"model,omitempty"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` -} - -type OpenAIEmbeddingResponseItem struct { - Object string `json:"object"` - Index int `json:"index"` - Embedding []float64 `json:"embedding"` -} - -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` -} - -type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } -} - -type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` -} - -type ChatCompletionsStreamResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - -type CompletionsStreamResponse struct { - Choices []struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` -} - func Relay(c *gin.Context) { - relayMode := RelayModeUnknown + relayMode := constant.RelayModeUnknown if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions + relayMode = constant.RelayModeChatCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - relayMode = RelayModeCompletions + relayMode = constant.RelayModeCompletions } else if 
strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
-		relayMode = RelayModeEmbeddings
+		relayMode = constant.RelayModeEmbeddings
 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-		relayMode = RelayModeEmbeddings
+		relayMode = constant.RelayModeEmbeddings
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-		relayMode = RelayModeModerations
+		relayMode = constant.RelayModeModerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-		relayMode = RelayModeImagesGenerations
+		relayMode = constant.RelayModeImagesGenerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
-		relayMode = RelayModeEdits
+		relayMode = constant.RelayModeEdits
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
-		relayMode = RelayModeAudioSpeech
+		relayMode = constant.RelayModeAudioSpeech
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-		relayMode = RelayModeAudioTranscription
+		relayMode = constant.RelayModeAudioTranscription
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-		relayMode = RelayModeAudioTranslation
+		relayMode = constant.RelayModeAudioTranslation
 	}
-	var err *OpenAIErrorWithStatusCode
+	var err *openai.ErrorWithStatusCode
 	switch relayMode {
-	case RelayModeImagesGenerations:
-		err = relayImageHelper(c, relayMode)
-	case RelayModeAudioSpeech:
+	case constant.RelayModeImagesGenerations:
+		err = controller.RelayImageHelper(c, relayMode)
+	case constant.RelayModeAudioSpeech:
 		fallthrough
-	case RelayModeAudioTranslation:
+	case constant.RelayModeAudioTranslation:
 		fallthrough
-	case RelayModeAudioTranscription:
-		err = relayAudioHelper(c, relayMode)
+	case constant.RelayModeAudioTranscription:
+		err = controller.RelayAudioHelper(c, relayMode)
 	default:
-		err = relayTextHelper(c, relayMode)
+		err = controller.RelayTextHelper(c, relayMode)
 	}
 	if err != nil {
 		requestId := c.GetString(common.RequestIdKey)
@@ -359,17 +63,17 @@ func Relay(c *gin.Context) {
 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
 		} else {
 			if err.StatusCode == http.StatusTooManyRequests {
-				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
+				err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 			}
-			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
+			err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
 			c.JSON(err.StatusCode, gin.H{
-				"error": err.OpenAIError,
+				"error": err.Error,
 			})
 		}
 		channelId := c.GetInt("channel_id")
 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
+		if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
 			disableChannel(channelId, channelName, err.Message)
@@ -378,7 +82,7 @@ func Relay(c *gin.Context) {
 }
 
 func RelayNotImplemented(c *gin.Context) {
-	err := OpenAIError{
+	err := openai.Error{
 		Message: "API not implemented",
 		Type:    "one_api_error",
 		Param:   "",
@@ -390,7 +94,7 @@ func RelayNotImplemented(c *gin.Context) {
 }
 
 func RelayNotFound(c *gin.Context) {
-	err := OpenAIError{
+	err := openai.Error{
 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 		Type:    "invalid_request_error",
 		Param:   "",
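For context: after this refactor the relay-mode values live in relay/constant and Relay() reduces to the prefix dispatch above plus calls into relay/controller. A minimal, runnable sketch of that dispatch, assuming only the mode names visible in the diff (the standalone main package and the relayModeFromPath helper are illustrative, not part of the change):

package main

import (
	"fmt"
	"strings"
)

// Illustrative mirror of a few of the relay/constant mode values from the diff.
const (
	RelayModeUnknown = iota
	RelayModeChatCompletions
	RelayModeCompletions
	RelayModeEmbeddings
)

// relayModeFromPath reproduces the prefix/suffix matching Relay() performs.
func relayModeFromPath(path string) int {
	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		return RelayModeChatCompletions
	case strings.HasPrefix(path, "/v1/completions"):
		return RelayModeCompletions
	case strings.HasPrefix(path, "/v1/embeddings"), strings.HasSuffix(path, "embeddings"):
		return RelayModeEmbeddings
	default:
		return RelayModeUnknown
	}
}

func main() {
	fmt.Println(relayModeFromPath("/v1/chat/completions"))        // 1
	fmt.Println(relayModeFromPath("/v1/engines/foo/embeddings")) // 3
}

Moving these constants out of package controller is presumably what lets the new channel packages reference relay modes without importing controller and creating an import cycle.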
diff --git a/main.go b/main.go
index d871a548..28a41287 100644
--- a/main.go
+++ b/main.go
@@ -10,6 +10,7 @@ import (
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/model"
+	"one-api/relay/channel/openai"
 	"one-api/router"
 	"os"
 	"strconv"
@@ -80,7 +81,7 @@ func main() {
 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 		model.InitBatchUpdater()
 	}
-	controller.InitTokenEncoders()
+	openai.InitTokenEncoders()
 
 	// Initialize HTTP server
 	server := gin.New()
diff --git a/model/main.go b/model/main.go
index bfd6888b..9723e638 100644
--- a/model/main.go
+++ b/model/main.go
@@ -16,7 +16,7 @@ var DB *gorm.DB
 
 func createRootAccountIfNeed() error {
 	var user User
-	//if user.Status != common.UserStatusEnabled {
+	//if user.Status != util.UserStatusEnabled {
 	if err := DB.First(&user).Error; err != nil {
 		common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
 		hashedPassword, err := common.Password2Hash("123456")
diff --git a/model/user.go b/model/user.go
index f08acd23..1c2c0a75 100644
--- a/model/user.go
+++ b/model/user.go
@@ -15,7 +15,7 @@ type User struct {
 	Username    string `json:"username" gorm:"unique;index" validate:"max=12"`
 	Password    string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
 	DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
-	Role        int    `json:"role" gorm:"type:int;default:1"`   // admin, common
+	Role        int    `json:"role" gorm:"type:int;default:1"`   // admin, util
 	Status      int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled
 	Email       string `json:"email" gorm:"index" validate:"max=50"`
 	GitHubId    string `json:"github_id" gorm:"column:github_id;index"`
diff --git a/controller/relay-aiproxy.go b/relay/channel/aiproxy/main.go
similarity index 57%
rename from controller/relay-aiproxy.go
rename to relay/channel/aiproxy/main.go
index 543954f7..bee4d9d3 100644
--- a/controller/relay-aiproxy.go
+++ b/relay/channel/aiproxy/main.go
@@ -1,4 +1,4 @@
-package controller
+package aiproxy
 
 import (
 	"bufio"
@@ -8,56 +8,27 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/relay/channel/openai"
+	"one-api/relay/constant"
 	"strconv"
 	"strings"
 )
 
 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 
-type AIProxyLibraryRequest struct {
-	Model     string `json:"model"`
-	Query     string `json:"query"`
-	LibraryId string `json:"libraryId"`
-	Stream    bool   `json:"stream"`
-}
-
-type AIProxyLibraryError struct {
-	ErrCode int    `json:"errCode"`
-	Message string `json:"message"`
-}
-
-type AIProxyLibraryDocument struct {
-	Title string `json:"title"`
-	URL   string `json:"url"`
-}
-
-type AIProxyLibraryResponse struct {
-	Success   bool                     `json:"success"`
-	Answer    string                   `json:"answer"`
-	Documents []AIProxyLibraryDocument `json:"documents"`
-	AIProxyLibraryError
-}
-
-type AIProxyLibraryStreamResponse struct {
-	Content   string                   `json:"content"`
-	Finish    bool                     `json:"finish"`
-	Model     string                   `json:"model"`
-	Documents []AIProxyLibraryDocument `json:"documents"`
-}
-
-func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
+func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
 	query := ""
 	if len(request.Messages) != 0 {
 		query = request.Messages[len(request.Messages)-1].StringContent()
 	}
-	return &AIProxyLibraryRequest{
+	return &LibraryRequest{
 		Model:  request.Model,
 		Stream: request.Stream,
 		Query:  query,
 	}
 }
 
-func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
+func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
 	if len(documents) == 0 {
 		return ""
 	}
@@ -68,52 +39,52 @@ func 
aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { return content } -func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { +func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse { content := response.Answer + aiProxyDocuments2Markdown(response.Documents) - choice := OpenAITextResponseChoice{ + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: content, }, FinishReason: "stop", } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := openai.TextResponse{ Id: common.GetUUID(), Object: "chat.completion", Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, + Choices: []openai.TextResponseChoice{choice}, } return &fullTextResponse } -func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = aiProxyDocuments2Markdown(documents) - choice.FinishReason = &stopFinishReason - return &ChatCompletionsStreamResponse{ + choice.FinishReason = &constant.StopFinishReason + return &openai.ChatCompletionsStreamResponse{ Id: common.GetUUID(), Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } } -func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = response.Content - return &ChatCompletionsStreamResponse{ + return &openai.ChatCompletionsStreamResponse{ Id: common.GetUUID(), Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: response.Model, - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } } -func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -143,12 +114,12 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr } stopChan <- true }() - setEventStreamHeaders(c) - var documents []AIProxyLibraryDocument + common.SetEventStreamHeaders(c) + var documents []LibraryDocument c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var AIProxyLibraryResponse AIProxyLibraryStreamResponse + var AIProxyLibraryResponse LibraryStreamResponse err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -179,28 +150,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return 
openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var AIProxyLibraryResponse AIProxyLibraryResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var AIProxyLibraryResponse LibraryResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if AIProxyLibraryResponse.ErrCode != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: AIProxyLibraryResponse.Message, Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), Code: AIProxyLibraryResponse.ErrCode, @@ -211,7 +182,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/aiproxy/model.go b/relay/channel/aiproxy/model.go new file mode 100644 index 00000000..39689b3d --- /dev/null +++ b/relay/channel/aiproxy/model.go @@ -0,0 +1,32 @@ +package aiproxy + +type LibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type LibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type LibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type LibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []LibraryDocument `json:"documents"` + LibraryError +} + +type LibraryStreamResponse struct { + Content string `json:"content"` + Finish bool `json:"finish"` + Model string `json:"model"` + Documents []LibraryDocument `json:"documents"` +} diff --git a/controller/relay-ali.go b/relay/channel/ali/main.go similarity index 53% rename from controller/relay-ali.go rename to relay/channel/ali/main.go index df1cc084..f45a515a 100644 --- a/controller/relay-ali.go +++ b/relay/channel/ali/main.go @@ -1,4 +1,4 @@ -package controller +package ali import ( "bufio" @@ -7,112 +7,43 @@ import ( "io" "net/http" "one-api/common" + "one-api/relay/channel/openai" "strings" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r -type AliMessage struct { - Content string `json:"content"` - Role string `json:"role"` 
-} +const EnableSearchModelSuffix = "-internet" -type AliInput struct { - //Prompt string `json:"prompt"` - Messages []AliMessage `json:"messages"` -} - -type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` -} - -type AliChatRequest struct { - Model string `json:"model"` - Input AliInput `json:"input"` - Parameters AliParameters `json:"parameters,omitempty"` -} - -type AliEmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -type AliEmbedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type AliEmbeddingResponse struct { - Output struct { - Embeddings []AliEmbedding `json:"embeddings"` - } `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -type AliError struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type AliUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type AliChatResponse struct { - Output AliOutput `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -const AliEnableSearchModelSuffix = "-internet" - -func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { - messages := make([]AliMessage, 0, len(request.Messages)) +func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - messages = append(messages, AliMessage{ + messages = append(messages, Message{ Content: message.StringContent(), Role: strings.ToLower(message.Role), }) } enableSearch := false aliModel := request.Model - if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { + if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { enableSearch = true - aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) + aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } - return &AliChatRequest{ + return &ChatRequest{ Model: aliModel, - Input: AliInput{ + Input: Input{ Messages: messages, }, - Parameters: AliParameters{ + Parameters: Parameters{ EnableSearch: enableSearch, IncrementalOutput: request.Stream, }, } } -func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { - return &AliEmbeddingRequest{ +func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ Model: "text-embedding-v1", Input: struct { Texts []string `json:"texts"` @@ -122,21 +53,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque } } -func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var aliResponse AliEmbeddingResponse +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", 
http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -149,7 +80,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -157,16 +88,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS return nil, &fullTextResponse.Usage } -func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ +func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)), Model: "text-embedding-v1", - Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens}, } for _, item := range response.Output.Embeddings { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ Object: `embedding`, Index: item.TextIndex, Embedding: item.Embedding, @@ -175,21 +106,21 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin return &openAIEmbeddingResponse } -func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ +func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: response.Output.Text, }, FinishReason: response.Output.FinishReason, } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - Usage: Usage{ + Choices: []openai.TextResponseChoice{choice}, + Usage: openai.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, @@ -198,25 +129,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseAli2OpenAI(aliResponse *ChatResponse) 
*openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = aliResponse.Output.Text if aliResponse.Output.FinishReason != "null" { finishReason := aliResponse.Output.FinishReason choice.FinishReason = &finishReason } - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Id: aliResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "qwen", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -246,12 +177,12 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) //lastResponseText := "" c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var aliResponse AliChatResponse + var aliResponse ChatResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -279,28 +210,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var aliResponse AliChatResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -313,7 +244,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode fullTextResponse.Model = "qwen" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } 
c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go
new file mode 100644
index 00000000..54f13041
--- /dev/null
+++ b/relay/channel/ali/model.go
@@ -0,0 +1,71 @@
+package ali
+
+type Message struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+}
+
+type Input struct {
+	//Prompt   string    `json:"prompt"`
+	Messages []Message `json:"messages"`
+}
+
+type Parameters struct {
+	TopP              float64 `json:"top_p,omitempty"`
+	TopK              int     `json:"top_k,omitempty"`
+	Seed              uint64  `json:"seed,omitempty"`
+	EnableSearch      bool    `json:"enable_search,omitempty"`
+	IncrementalOutput bool    `json:"incremental_output,omitempty"`
+}
+
+type ChatRequest struct {
+	Model      string     `json:"model"`
+	Input      Input      `json:"input"`
+	Parameters Parameters `json:"parameters,omitempty"`
+}
+
+type EmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type Embedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type EmbeddingResponse struct {
+	Output struct {
+		Embeddings []Embedding `json:"embeddings"`
+	} `json:"output"`
+	Usage Usage `json:"usage"`
+	Error
+}
+
+type Error struct {
+	Code      string `json:"code"`
+	Message   string `json:"message"`
+	RequestId string `json:"request_id"`
+}
+
+type Usage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
+}
+
+type Output struct {
+	Text         string `json:"text"`
+	FinishReason string `json:"finish_reason"`
+}
+
+type ChatResponse struct {
+	Output Output `json:"output"`
+	Usage  Usage  `json:"usage"`
+	Error
+}
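The new ali/model.go leans on struct embedding for error handling: ChatResponse and EmbeddingResponse embed Error, so a single json.Unmarshal fills either the payload or the error envelope, and Handler only needs the aliResponse.Code != "" check. A self-contained sketch of that pattern (the types mirror the diff above; the demo main is illustrative only):

package main

import (
	"encoding/json"
	"fmt"
)

// Error mirrors ali.Error from the new model.go.
type Error struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
}

// ChatResponse embeds Error, so one decode handles both the success fields
// and the error envelope, whichever the upstream actually returned.
type ChatResponse struct {
	Output struct {
		Text         string `json:"text"`
		FinishReason string `json:"finish_reason"`
	} `json:"output"`
	Error
}

func main() {
	for _, body := range []string{
		`{"output":{"text":"hi","finish_reason":"stop"}}`,
		`{"code":"Throttling","message":"rate limited","request_id":"abc"}`,
	} {
		var r ChatResponse
		if err := json.Unmarshal([]byte(body), &r); err != nil {
			panic(err)
		}
		if r.Code != "" { // same error check as Handler above
			fmt.Println("upstream error:", r.Message)
			continue
		}
		fmt.Println("text:", r.Output.Text)
	}
}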
diff --git a/controller/relay-claude.go b/relay/channel/anthropic/main.go
similarity index 64%
rename from controller/relay-claude.go
rename to relay/channel/anthropic/main.go
index ca7a701a..a4272d7b 100644
--- a/controller/relay-claude.go
+++ b/relay/channel/anthropic/main.go
@@ -1,4 +1,4 @@
-package controller
+package anthropic
 
 import (
 	"bufio"
@@ -8,37 +8,10 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/relay/channel/openai"
 	"strings"
 )
 
-type ClaudeMetadata struct {
-	UserId string `json:"user_id"`
-}
-
-type ClaudeRequest struct {
-	Model             string   `json:"model"`
-	Prompt            string   `json:"prompt"`
-	MaxTokensToSample int      `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences,omitempty"`
-	Temperature       float64  `json:"temperature,omitempty"`
-	TopP              float64  `json:"top_p,omitempty"`
-	TopK              int      `json:"top_k,omitempty"`
-	//ClaudeMetadata    `json:"metadata,omitempty"`
-	Stream bool `json:"stream,omitempty"`
-}
-
-type ClaudeError struct {
-	Type    string `json:"type"`
-	Message string `json:"message"`
-}
-
-type ClaudeResponse struct {
-	Completion string      `json:"completion"`
-	StopReason string      `json:"stop_reason"`
-	Model      string      `json:"model"`
-	Error      ClaudeError `json:"error"`
-}
-
 func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
@@ -50,8 +23,8 @@ func stopReasonClaude2OpenAI(reason string) string {
 	}
 }
 
-func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
-	claudeRequest := ClaudeRequest{
+func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
+	claudeRequest := Request{
 		Model:             textRequest.Model,
 		Prompt:            "",
 		MaxTokensToSample: textRequest.MaxTokens,
@@ -80,40 +53,40 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 	return &claudeRequest
 }
 
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
+	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = claudeResponse.Completion
 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 	if finishReason != "null" {
 		choice.FinishReason = &finishReason
 	}
-	var response ChatCompletionsStreamResponse
+	var response openai.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }
 
-func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
-	choice := OpenAITextResponseChoice{
+func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
+	choice := openai.TextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: openai.Message{
 			Role:    "assistant",
 			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
 			Name:    nil,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := openai.TextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []openai.TextResponseChoice{choice},
 	}
 	return &fullTextResponse
 }
 
-func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
 	responseText := ""
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createdTime := common.GetTimestamp()
@@ -143,13 +116,13 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	common.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
 			// some implementations may add \r at the end of data
 			data = strings.TrimSuffix(data, "\r")
-			var claudeResponse ClaudeResponse
+			var claudeResponse Response
 			err := json.Unmarshal([]byte(data), &claudeResponse)
 			if err != nil {
 				common.SysError("error unmarshalling stream response: " + err.Error())
@@ -173,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, 
"close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var claudeResponse ClaudeResponse + var claudeResponse Response err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if claudeResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: claudeResponse.Error.Message, Type: claudeResponse.Error.Type, Param: "", @@ -205,8 +178,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model } fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse.Model = model - completionTokens := countTokenText(claudeResponse.Completion, model) - usage := Usage{ + completionTokens := openai.CountTokenText(claudeResponse.Completion, model) + usage := openai.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -214,7 +187,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/anthropic/model.go b/relay/channel/anthropic/model.go new file mode 100644 index 00000000..70fc9430 --- /dev/null +++ b/relay/channel/anthropic/model.go @@ -0,0 +1,29 @@ +package anthropic + +type Metadata struct { + UserId string `json:"user_id"` +} + +type Request struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokensToSample int `json:"max_tokens_to_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //Metadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type Error struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type Response struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error Error `json:"error"` +} diff --git a/controller/relay-baidu.go b/relay/channel/baidu/main.go similarity index 61% rename from controller/relay-baidu.go rename to relay/channel/baidu/main.go index dca30da1..47969492 100644 --- a/controller/relay-baidu.go +++ b/relay/channel/baidu/main.go @@ -1,4 +1,4 @@ -package controller +package baidu import ( "bufio" @@ -9,6 +9,9 @@ import ( "io" "net/http" "one-api/common" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/util" "strings" "sync" "time" @@ -37,53 +40,9 @@ type BaiduError struct { ErrorMsg string `json:"error_msg"` } -type BaiduChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage Usage `json:"usage"` - 
BaiduError -} - -type BaiduChatStreamResponse struct { - BaiduChatResponse - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` -} - -type BaiduEmbeddingRequest struct { - Input []string `json:"input"` -} - -type BaiduEmbeddingData struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -type BaiduEmbeddingResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Data []BaiduEmbeddingData `json:"data"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduAccessToken struct { - AccessToken string `json:"access_token"` - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - ExpiresAt time.Time `json:"-"` -} - var baiduTokenStore sync.Map -func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { +func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -108,56 +67,56 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } } -func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ +func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: response.Result, }, FinishReason: "stop", } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := openai.TextResponse{ Id: response.Id, Object: "chat.completion", Created: response.Created, - Choices: []OpenAITextResponseChoice{choice}, + Choices: []openai.TextResponseChoice{choice}, Usage: response.Usage, } return &fullTextResponse } -func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = baiduResponse.Result if baiduResponse.IsEnd { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &constant.StopFinishReason } - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Id: baiduResponse.Id, Object: "chat.completion.chunk", Created: baiduResponse.Created, Model: "ernie-bot", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - return &BaiduEmbeddingRequest{ +func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ Input: request.ParseInput(), } } -func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ +func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), Model: "baidu-embedding", Usage: response.Usage, } for _, item := range 
response.Data { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ Object: item.Object, Index: item.Index, Embedding: item.Embedding, @@ -166,8 +125,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe return &openAIEmbeddingResponse } -func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -194,11 +153,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var baiduResponse BaiduChatStreamResponse + var baiduResponse ChatStreamResponse err := json.Unmarshal([]byte(data), &baiduResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -224,28 +183,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var baiduResponse BaiduChatResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var baiduResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -258,7 +217,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo fullTextResponse.Model = "ernie-bot" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -266,23 +225,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo return nil, &fullTextResponse.Usage 
 }
 
-func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var baiduResponse BaiduEmbeddingResponse
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+	var baiduResponse EmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &openai.ErrorWithStatusCode{
+			Error: openai.Error{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -294,7 +253,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -302,10 +261,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 	return nil, &fullTextResponse.Usage
 }
 
-func getBaiduAccessToken(apiKey string) (string, error) {
+func GetAccessToken(apiKey string) (string, error) {
 	if val, ok := baiduTokenStore.Load(apiKey); ok {
-		var accessToken BaiduAccessToken
-		if accessToken, ok = val.(BaiduAccessToken); ok {
+		var accessToken AccessToken
+		if accessToken, ok = val.(AccessToken); ok {
 			// soon this will expire
 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
 				go func() {
@@ -320,12 +279,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
 		return "", err
 	}
 	if accessToken == nil {
-		return "", errors.New("getBaiduAccessToken return a nil token")
+		return "", errors.New("GetAccessToken returned a nil token")
 	}
 	return (*accessToken).AccessToken, nil
 }
 
-func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
+func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
 	parts := strings.Split(apiKey, "|")
 	if len(parts) != 2 {
 		return nil, errors.New("invalid baidu apikey")
@@ -337,13 +296,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 	}
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Accept", "application/json")
-	res, err := impatientHTTPClient.Do(req)
+	res, err := util.ImpatientHTTPClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
 	defer res.Body.Close()
-	var accessToken BaiduAccessToken
+	var accessToken AccessToken
 	err = json.NewDecoder(res.Body).Decode(&accessToken)
 	if err != nil {
 		return nil, err
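GetAccessToken above caches Baidu tokens in a sync.Map and refreshes them in the background once they are within an hour of expiry. A minimal sketch of that cache-then-refresh pattern (fetchToken is a hypothetical stand-in for getBaiduAccessTokenHelper, which really makes an HTTP call):

package main

import (
	"fmt"
	"sync"
	"time"
)

// token mirrors the baidu.AccessToken fields that matter for caching.
type token struct {
	Value     string
	ExpiresAt time.Time
}

var store sync.Map // apiKey -> token

// fetchToken is a hypothetical stand-in for getBaiduAccessTokenHelper.
func fetchToken(apiKey string) token {
	return token{Value: apiKey + "-token", ExpiresAt: time.Now().Add(30 * 24 * time.Hour)}
}

// getToken returns a cached token, refreshing it in the background when it
// is within an hour of expiry — the same policy as GetAccessToken above.
func getToken(apiKey string) string {
	if val, ok := store.Load(apiKey); ok {
		t := val.(token)
		if time.Now().Add(time.Hour).After(t.ExpiresAt) {
			go func() { store.Store(apiKey, fetchToken(apiKey)) }()
		}
		return t.Value
	}
	t := fetchToken(apiKey) // cold cache: fetch synchronously
	store.Store(apiKey, t)
	return t.Value
}

func main() {
	fmt.Println(getToken("demo-key"))
}

Serving the soon-to-expire token while a goroutine refreshes it keeps the hot path non-blocking; only a cold-cache miss pays the synchronous fetch.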
diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go
new file mode 100644
index 00000000..caaebafb
--- /dev/null
+++ b/relay/channel/baidu/model.go
@@ -0,0 +1,50 @@
+package baidu
+
+import (
+    "one-api/relay/channel/openai"
+    "time"
+)
+
+type ChatResponse struct {
+    Id string `json:"id"`
+    Object string `json:"object"`
+    Created int64 `json:"created"`
+    Result string `json:"result"`
+    IsTruncated bool `json:"is_truncated"`
+    NeedClearHistory bool `json:"need_clear_history"`
+    Usage openai.Usage `json:"usage"`
+    BaiduError
+}
+
+type ChatStreamResponse struct {
+    ChatResponse
+    SentenceId int `json:"sentence_id"`
+    IsEnd bool `json:"is_end"`
+}
+
+type EmbeddingRequest struct {
+    Input []string `json:"input"`
+}
+
+type EmbeddingData struct {
+    Object string `json:"object"`
+    Embedding []float64 `json:"embedding"`
+    Index int `json:"index"`
+}
+
+type EmbeddingResponse struct {
+    Id string `json:"id"`
+    Object string `json:"object"`
+    Created int64 `json:"created"`
+    Data []EmbeddingData `json:"data"`
+    Usage openai.Usage `json:"usage"`
+    BaiduError
+}
+
+type AccessToken struct {
+    AccessToken string `json:"access_token"`
+    Error string `json:"error,omitempty"`
+    ErrorDescription string `json:"error_description,omitempty"`
+    ExpiresIn int64 `json:"expires_in,omitempty"`
+    ExpiresAt time.Time `json:"-"`
+}
diff --git a/controller/relay-gemini.go b/relay/channel/google/gemini.go
similarity index 69%
rename from controller/relay-gemini.go
rename to relay/channel/google/gemini.go
index d8ab58d6..f49caadf 100644
--- a/controller/relay-gemini.go
+++ b/relay/channel/google/gemini.go
@@ -1,4 +1,4 @@
-package controller
+package google

 import (
     "bufio"
@@ -8,6 +8,8 @@ import (
     "net/http"
     "one-api/common"
     "one-api/common/image"
+    "one-api/relay/channel/openai"
+    "one-api/relay/constant"
     "strings"

     "github.com/gin-gonic/gin"
@@ -19,48 +21,8 @@ const (
     GeminiVisionMaxImageNum = 16
 )

-type GeminiChatRequest struct {
-    Contents []GeminiChatContent `json:"contents"`
-    SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
-    GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
-    Tools []GeminiChatTools `json:"tools,omitempty"`
-}
-
-type GeminiInlineData struct {
-    MimeType string `json:"mimeType"`
-    Data string `json:"data"`
-}
-
-type GeminiPart struct {
-    Text string `json:"text,omitempty"`
-    InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-
-type GeminiChatContent struct {
-    Role string `json:"role,omitempty"`
-    Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
-    Category string `json:"category"`
-    Threshold string `json:"threshold"`
-}
-
-type GeminiChatTools struct {
-    FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
-    Temperature float64 `json:"temperature,omitempty"`
-    TopP float64 `json:"topP,omitempty"`
-    TopK float64 `json:"topK,omitempty"`
-    MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
-    CandidateCount int `json:"candidateCount,omitempty"`
-    StopSequences []string `json:"stopSequences,omitempty"`
-}
-
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
     geminiRequest := GeminiChatRequest{
         Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
         SafetySettings: []GeminiChatSafetySettings{
@@ -108,11 +70,11 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
             var parts []GeminiPart
             imageNum := 0
             for _, part := range openaiContent {
-                if part.Type == ContentTypeText {
+                if part.Type == openai.ContentTypeText {
                     parts = append(parts, GeminiPart{
                         Text: part.Text,
                     })
-                } else if part.Type == ContentTypeImageURL {
+                } else if part.Type == openai.ContentTypeImageURL {
                     imageNum += 1
                     if imageNum > GeminiVisionMaxImageNum {
                         continue
@@ -187,21 +149,21 @@ type GeminiChatPromptFeedback struct {
     SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 }

-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
-    fullTextResponse := OpenAITextResponse{
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
+    fullTextResponse := openai.TextResponse{
         Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
         Object: "chat.completion",
         Created: common.GetTimestamp(),
-        Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+        Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
     }
     for i, candidate := range response.Candidates {
-        choice := OpenAITextResponseChoice{
+        choice := openai.TextResponseChoice{
             Index: i,
-            Message: Message{
+            Message: openai.Message{
                 Role: "assistant",
                 Content: "",
             },
-            FinishReason: stopFinishReason,
+            FinishReason: constant.StopFinishReason,
         }
         if len(candidate.Content.Parts) > 0 {
             choice.Message.Content = candidate.Content.Parts[0].Text
@@ -211,18 +173,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
     return &fullTextResponse
 }

-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
-    var choice ChatCompletionsStreamResponseChoice
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
+    var choice openai.ChatCompletionsStreamResponseChoice
     choice.Delta.Content = geminiResponse.GetResponseText()
-    choice.FinishReason = &stopFinishReason
-    var response ChatCompletionsStreamResponse
+    choice.FinishReason = &constant.StopFinishReason
+    var response openai.ChatCompletionsStreamResponse
     response.Object = "chat.completion.chunk"
     response.Model = "gemini"
-    response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+    response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
     return &response
 }

-func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     responseText := ""
     dataChan := make(chan string)
     stopChan := make(chan bool)
@@ -252,7 +214,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
         }
         stopChan <- true
     }()
-    setEventStreamHeaders(c)
+    common.SetEventStreamHeaders(c)
     c.Stream(func(w io.Writer) bool {
         select {
         case data := <-dataChan:
@@ -264,14 +226,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
             var dummy dummyStruct
             err := json.Unmarshal([]byte(data), &dummy)
             responseText += dummy.Content
-            var choice ChatCompletionsStreamResponseChoice
+            var choice openai.ChatCompletionsStreamResponseChoice
             choice.Delta.Content = dummy.Content
-            response := ChatCompletionsStreamResponse{
+            response := openai.ChatCompletionsStreamResponse{
                 Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
                 Object: "chat.completion.chunk",
                 Created: common.GetTimestamp(),
                 Model: "gemini-pro",
-                Choices: []ChatCompletionsStreamResponseChoice{choice},
+                Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
             }
             jsonResponse, err := json.Marshal(response)
             if err != nil {
@@ -287,28 +249,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
     })
     err := resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
     return nil, responseText
 }

-func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
     responseBody, err := io.ReadAll(resp.Body)
     if err != nil {
-        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
     }
     err = resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
     var geminiResponse GeminiChatResponse
     err = json.Unmarshal(responseBody, &geminiResponse)
     if err != nil {
-        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
     }
     if len(geminiResponse.Candidates) == 0 {
-        return &OpenAIErrorWithStatusCode{
-            OpenAIError: OpenAIError{
+        return &openai.ErrorWithStatusCode{
+            Error: openai.Error{
                 Message: "No candidates returned",
                 Type: "server_error",
                 Param: "",
@@ -319,8 +281,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
     }
     fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
     fullTextResponse.Model = model
-    completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
-    usage := Usage{
+    completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
+    usage := openai.Usage{
         PromptTokens: promptTokens,
         CompletionTokens: completionTokens,
         TotalTokens: promptTokens + completionTokens,
@@ -328,7 +290,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
     fullTextResponse.Usage = usage
     jsonResponse, err := json.Marshal(fullTextResponse)
     if err != nil {
-        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
     }
     c.Writer.Header().Set("Content-Type", "application/json")
     c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/google/model.go b/relay/channel/google/model.go
new file mode 100644
index 00000000..694c2dd1
--- /dev/null
+++ b/relay/channel/google/model.go
@@ -0,0 +1,80 @@
+package google
+
+import (
+    "one-api/relay/channel/openai"
+)
+
+type GeminiChatRequest struct {
+    Contents []GeminiChatContent `json:"contents"`
+    SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
+    GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+    Tools []GeminiChatTools `json:"tools,omitempty"`
+}
+
+type GeminiInlineData struct {
+    MimeType string `json:"mimeType"`
+    Data string `json:"data"`
+}
+
+type GeminiPart struct {
+    Text string `json:"text,omitempty"`
+    InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+type GeminiChatContent struct {
+    Role string `json:"role,omitempty"`
+    Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+    Category string `json:"category"`
+    Threshold string `json:"threshold"`
+}
+
+type GeminiChatTools struct {
+    FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type GeminiChatGenerationConfig struct {
+    Temperature float64 `json:"temperature,omitempty"`
+    TopP float64 `json:"topP,omitempty"`
+    TopK float64 `json:"topK,omitempty"`
+    MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
+    CandidateCount int `json:"candidateCount,omitempty"`
+    StopSequences []string `json:"stopSequences,omitempty"`
+}
+
+type PaLMChatMessage struct {
+    Author string `json:"author"`
+    Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+    Reason string `json:"reason"`
+    Message string `json:"message"`
+}
+
+type PaLMPrompt struct {
+    Messages []PaLMChatMessage `json:"messages"`
+}
+
+type PaLMChatRequest struct {
+    Prompt PaLMPrompt `json:"prompt"`
+    Temperature float64 `json:"temperature,omitempty"`
+    CandidateCount int `json:"candidateCount,omitempty"`
+    TopP float64 `json:"topP,omitempty"`
+    TopK int `json:"topK,omitempty"`
+}
+
+type PaLMError struct {
+    Code int `json:"code"`
+    Message string `json:"message"`
+    Status string `json:"status"`
+}
+
+type PaLMChatResponse struct {
+    Candidates []PaLMChatMessage `json:"candidates"`
+    Messages []openai.Message `json:"messages"`
+    Filters []PaLMFilter `json:"filters"`
+    Error PaLMError `json:"error"`
+}
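ConvertGeminiRequest above fans each OpenAI message out into Gemini parts. Below is a condensed sketch of just that content fan-out, assumed to live alongside the types in relay/channel/google/model.go; the real converter additionally downloads image data into InlineData via one-api/common/image and fills roles, safety settings, and generation config.

```go
// Condensed sketch of the per-message content fan-out performed by
// ConvertGeminiRequest; a sketch under the assumption it sits in the
// google package next to the types above, not the actual implementation.
package google

import "one-api/relay/channel/openai"

func partsFromContent(items []openai.OpenAIMessageContent) []GeminiPart {
	var parts []GeminiPart
	imageNum := 0
	for _, item := range items {
		switch item.Type {
		case openai.ContentTypeText:
			parts = append(parts, GeminiPart{Text: item.Text})
		case openai.ContentTypeImageURL:
			imageNum += 1
			if imageNum > GeminiVisionMaxImageNum {
				continue // Gemini Vision accepts at most 16 images per request
			}
			// The real code fetches the image and fills InlineData with its
			// MIME type and base64 payload.
		}
	}
	return parts
}
```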
diff --git a/controller/relay-palm.go b/relay/channel/google/palm.go
similarity index 62%
rename from controller/relay-palm.go
rename to relay/channel/google/palm.go
index 0c1c8af6..77d8cbd6 100644
--- a/controller/relay-palm.go
+++ b/relay/channel/google/palm.go
@@ -1,4 +1,4 @@
-package controller
+package google

 import (
     "encoding/json"
@@ -7,47 +7,14 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/relay/channel/openai"
+    "one-api/relay/constant"
 )

 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

-type PaLMChatMessage struct {
-    Author string `json:"author"`
-    Content string `json:"content"`
-}
-
-type PaLMFilter struct {
-    Reason string `json:"reason"`
-    Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
-    Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
-    Prompt PaLMPrompt `json:"prompt"`
-    Temperature float64 `json:"temperature,omitempty"`
-    CandidateCount int `json:"candidateCount,omitempty"`
-    TopP float64 `json:"topP,omitempty"`
-    TopK int `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
-    Code int `json:"code"`
-    Message string `json:"message"`
-    Status string `json:"status"`
-}
-
-type PaLMChatResponse struct {
-    Candidates []PaLMChatMessage `json:"candidates"`
-    Messages []Message `json:"messages"`
-    Filters []PaLMFilter `json:"filters"`
-    Error PaLMError `json:"error"`
-}
-
-func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
+func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
     palmRequest := PaLMChatRequest{
         Prompt: PaLMPrompt{
             Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@@ -71,14 +38,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
     return &palmRequest
 }

-func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
-    fullTextResponse := OpenAITextResponse{
-        Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
+    fullTextResponse := openai.TextResponse{
+        Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
     }
     for i, candidate := range response.Candidates {
-        choice := OpenAITextResponseChoice{
+        choice := openai.TextResponseChoice{
             Index: i,
-            Message: Message{
+            Message: openai.Message{
                 Role: "assistant",
                 Content: candidate.Content,
             },
@@ -89,20 +56,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
     return &fullTextResponse
 }

-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
-    var choice ChatCompletionsStreamResponseChoice
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
+    var choice openai.ChatCompletionsStreamResponseChoice
     if len(palmResponse.Candidates) > 0 {
         choice.Delta.Content = palmResponse.Candidates[0].Content
     }
-    choice.FinishReason = &stopFinishReason
-    var response ChatCompletionsStreamResponse
+    choice.FinishReason = &constant.StopFinishReason
+    var response openai.ChatCompletionsStreamResponse
     response.Object = "chat.completion.chunk"
     response.Model = "palm2"
-    response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+    response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
     return &response
 }

-func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     responseText := ""
     responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
     createdTime := common.GetTimestamp()
@@ -143,7 +110,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
         dataChan <- string(jsonResponse)
         stopChan <- true
     }()
-    setEventStreamHeaders(c)
+    common.SetEventStreamHeaders(c)
     c.Stream(func(w io.Writer) bool {
         select {
         case data := <-dataChan:
@@ -156,28 +123,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
     })
     err := resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
     return nil, responseText
 }

-func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
     responseBody, err := io.ReadAll(resp.Body)
     if err != nil {
-        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
     }
     err = resp.Body.Close()
    if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
     var palmResponse PaLMChatResponse
     err = json.Unmarshal(responseBody, &palmResponse)
     if err != nil {
-        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
     }
     if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
-        return &OpenAIErrorWithStatusCode{
-            OpenAIError: OpenAIError{
+        return &openai.ErrorWithStatusCode{
+            Error: openai.Error{
                 Message: palmResponse.Error.Message,
                 Type: palmResponse.Error.Status,
                 Param: "",
@@ -188,8 +155,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
     }
     fullTextResponse := responsePaLM2OpenAI(&palmResponse)
     fullTextResponse.Model = model
-    completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
-    usage := Usage{
+    completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
+    usage := openai.Usage{
         PromptTokens: promptTokens,
         CompletionTokens: completionTokens,
         TotalTokens: promptTokens + completionTokens,
@@ -197,7 +164,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
     fullTextResponse.Usage = usage
     jsonResponse, err := json.Marshal(fullTextResponse)
     if err != nil {
-        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
     }
     c.Writer.Header().Set("Content-Type", "application/json")
     c.Writer.WriteHeader(resp.StatusCode)
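Every stream handler in this change follows the same shape: a reader goroutine pumps upstream chunks into dataChan and signals stopChan when done, while the gin side drains both inside c.Stream after common.SetEventStreamHeaders has marked the response as an event stream. A minimal standalone sketch of that loop; c.SSEvent stands in here for the raw `data:` writes the real handlers perform.

```go
// Minimal sketch of the shared dataChan/stopChan relay loop; assumes
// common.SetEventStreamHeaders(c) has already set the SSE headers.
package example

import (
	"io"

	"github.com/gin-gonic/gin"
)

func relayStream(c *gin.Context, dataChan <-chan string, stopChan <-chan bool) {
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			c.SSEvent("message", data) // forward one upstream chunk
			return true                // keep the connection open
		case <-stopChan:
			return false // upstream finished: end the stream
		}
	})
}
```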
diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go
new file mode 100644
index 00000000..000f72ee
--- /dev/null
+++ b/relay/channel/openai/constant.go
@@ -0,0 +1,6 @@
+package openai
+
+const (
+    ContentTypeText = "text"
+    ContentTypeImageURL = "image_url"
+)
diff --git a/controller/relay-openai.go b/relay/channel/openai/main.go
similarity index 76%
rename from controller/relay-openai.go
rename to relay/channel/openai/main.go
index 37867843..848a6fa4 100644
--- a/controller/relay-openai.go
+++ b/relay/channel/openai/main.go
@@ -1,4 +1,4 @@
-package controller
+package openai

 import (
     "bufio"
@@ -8,10 +8,11 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/relay/constant"
     "strings"
 )

-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
     responseText := ""
     scanner := bufio.NewScanner(resp.Body)
     scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -41,7 +42,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
             data = data[6:]
             if !strings.HasPrefix(data, "[DONE]") {
                 switch relayMode {
-                case RelayModeChatCompletions:
+                case constant.RelayModeChatCompletions:
                     var streamResponse ChatCompletionsStreamResponse
                     err := json.Unmarshal([]byte(data), &streamResponse)
                     if err != nil {
@@ -51,7 +52,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
                     for _, choice := range streamResponse.Choices {
                         responseText += choice.Delta.Content
                     }
-                case RelayModeCompletions:
+                case constant.RelayModeCompletions:
                     var streamResponse CompletionsStreamResponse
                     err := json.Unmarshal([]byte(data), &streamResponse)
                     if err != nil {
@@ -66,7 +67,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
         }
         stopChan <- true
     }()
-    setEventStreamHeaders(c)
+    common.SetEventStreamHeaders(c)
     c.Stream(func(w io.Writer) bool {
         select {
         case data := <-dataChan:
@@ -83,29 +84,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
     })
     err := resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+        return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
     return nil, responseText
 }

-func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
-    var textResponse TextResponse
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
+    var textResponse SlimTextResponse
     responseBody, err := io.ReadAll(resp.Body)
     if err != nil {
-        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+        return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
     }
     err = resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+        return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
     err = json.Unmarshal(responseBody, &textResponse)
     if err != nil {
-        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+        return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
     }
     if textResponse.Error.Type != "" {
-        return &OpenAIErrorWithStatusCode{
-            OpenAIError: textResponse.Error,
-            StatusCode: resp.StatusCode,
+        return &ErrorWithStatusCode{
+            Error: textResponse.Error,
+            StatusCode: resp.StatusCode,
         }, nil
     }
     // Reset response body
@@ -113,7 +114,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
     // We shouldn't set the header before we parse the response body, because the parse part may fail.
     // And then we will have to send an error response, but in this case, the header has already been set.
-    // So the httpClient will be confused by the response.
+    // So the HTTPClient will be confused by the response.
     // For example, Postman will report error, and we cannot check the response at all.
     for k, v := range resp.Header {
         c.Writer.Header().Set(k, v[0])
@@ -121,17 +122,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
     c.Writer.WriteHeader(resp.StatusCode)
     _, err = io.Copy(c.Writer, resp.Body)
     if err != nil {
-        return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+        return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
     }
     err = resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+        return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
     if textResponse.Usage.TotalTokens == 0 {
         completionTokens := 0
         for _, choice := range textResponse.Choices {
-            completionTokens += countTokenText(choice.Message.StringContent(), model)
+            completionTokens += CountTokenText(choice.Message.StringContent(), model)
         }
         textResponse.Usage = Usage{
             PromptTokens: promptTokens,
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
new file mode 100644
index 00000000..c831ce19
--- /dev/null
+++ b/relay/channel/openai/model.go
@@ -0,0 +1,283 @@
+package openai
+
+type Message struct {
+    Role string `json:"role"`
+    Content any `json:"content"`
+    Name *string `json:"name,omitempty"`
+}
+
+type ImageURL struct {
+    Url string `json:"url,omitempty"`
+    Detail string `json:"detail,omitempty"`
+}
+
+type TextContent struct {
+    Type string `json:"type,omitempty"`
+    Text string `json:"text,omitempty"`
+}
+
+type ImageContent struct {
+    Type string `json:"type,omitempty"`
+    ImageURL *ImageURL `json:"image_url,omitempty"`
+}
+
+type OpenAIMessageContent struct {
+    Type string `json:"type,omitempty"`
+    Text string `json:"text"`
+    ImageURL *ImageURL `json:"image_url,omitempty"`
+}
+
+func (m Message) IsStringContent() bool {
+    _, ok := m.Content.(string)
+    return ok
+}
+
+func (m Message) StringContent() string {
+    content, ok := m.Content.(string)
+    if ok {
+        return content
+    }
+    contentList, ok := m.Content.([]any)
+    if ok {
+        var contentStr string
+        for _, contentItem := range contentList {
+            contentMap, ok := contentItem.(map[string]any)
+            if !ok {
+                continue
+            }
+            if contentMap["type"] == ContentTypeText {
+                if subStr, ok := contentMap["text"].(string); ok {
+                    contentStr += subStr
+                }
+            }
+        }
+        return contentStr
+    }
+    return ""
+}
+
+func (m Message) ParseContent() []OpenAIMessageContent {
+    var contentList []OpenAIMessageContent
+    content, ok := m.Content.(string)
+    if ok {
+        contentList = append(contentList, OpenAIMessageContent{
+            Type: ContentTypeText,
+            Text: content,
+        })
+        return contentList
+    }
+    anyList, ok := m.Content.([]any)
+    if ok {
+        for _, contentItem := range anyList {
+            contentMap, ok := contentItem.(map[string]any)
+            if !ok {
+                continue
+            }
+            switch contentMap["type"] {
+            case ContentTypeText:
+                if subStr, ok := contentMap["text"].(string); ok {
+                    contentList = append(contentList, OpenAIMessageContent{
+                        Type: ContentTypeText,
+                        Text: subStr,
+                    })
+                }
+            case ContentTypeImageURL:
+                if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+                    contentList = append(contentList, OpenAIMessageContent{
+                        Type: ContentTypeImageURL,
+                        ImageURL: &ImageURL{
+                            Url: subObj["url"].(string),
+                        },
+                    })
+                }
+            }
+        }
+        return contentList
+    }
+    return nil
+}
+
+type ResponseFormat struct {
+    Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+    Model string `json:"model,omitempty"`
+    Messages []Message `json:"messages,omitempty"`
+    Prompt any `json:"prompt,omitempty"`
+    Stream bool `json:"stream,omitempty"`
+    MaxTokens int `json:"max_tokens,omitempty"`
+    Temperature float64 `json:"temperature,omitempty"`
+    TopP float64 `json:"top_p,omitempty"`
+    N int `json:"n,omitempty"`
+    Input any `json:"input,omitempty"`
+    Instruction string `json:"instruction,omitempty"`
+    Size string `json:"size,omitempty"`
+    Functions any `json:"functions,omitempty"`
+    FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+    PresencePenalty float64 `json:"presence_penalty,omitempty"`
+    ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+    Seed float64 `json:"seed,omitempty"`
+    Tools any `json:"tools,omitempty"`
+    ToolChoice any `json:"tool_choice,omitempty"`
+    User string `json:"user,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+    if r.Input == nil {
+        return nil
+    }
+    var input []string
+    switch r.Input.(type) {
+    case string:
+        input = []string{r.Input.(string)}
+    case []any:
+        input = make([]string, 0, len(r.Input.([]any)))
+        for _, item := range r.Input.([]any) {
+            if str, ok := item.(string); ok {
+                input = append(input, str)
+            }
+        }
+    }
+    return input
+}
+
+type ChatRequest struct {
+    Model string `json:"model"`
+    Messages []Message `json:"messages"`
+    MaxTokens int `json:"max_tokens"`
+}
+
+type TextRequest struct {
+    Model string `json:"model"`
+    Messages []Message `json:"messages"`
+    Prompt string `json:"prompt"`
+    MaxTokens int `json:"max_tokens"`
+    //Stream bool `json:"stream"`
+}
+
+// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
+type ImageRequest struct {
+    Model string `json:"model"`
+    Prompt string `json:"prompt" binding:"required"`
+    N int `json:"n,omitempty"`
+    Size string `json:"size,omitempty"`
+    Quality string `json:"quality,omitempty"`
+    ResponseFormat string `json:"response_format,omitempty"`
+    Style string `json:"style,omitempty"`
+    User string `json:"user,omitempty"`
+}
+
+type WhisperJSONResponse struct {
+    Text string `json:"text,omitempty"`
+}
+
+type WhisperVerboseJSONResponse struct {
+    Task string `json:"task,omitempty"`
+    Language string `json:"language,omitempty"`
+    Duration float64 `json:"duration,omitempty"`
+    Text string `json:"text,omitempty"`
+    Segments []Segment `json:"segments,omitempty"`
+}
+
+type Segment struct {
+    Id int `json:"id"`
+    Seek int `json:"seek"`
+    Start float64 `json:"start"`
+    End float64 `json:"end"`
+    Text string `json:"text"`
+    Tokens []int `json:"tokens"`
+    Temperature float64 `json:"temperature"`
+    AvgLogprob float64 `json:"avg_logprob"`
+    CompressionRatio float64 `json:"compression_ratio"`
+    NoSpeechProb float64 `json:"no_speech_prob"`
+}
+
+type TextToSpeechRequest struct {
+    Model string `json:"model" binding:"required"`
+    Input string `json:"input" binding:"required"`
+    Voice string `json:"voice" binding:"required"`
+    Speed float64 `json:"speed"`
+    ResponseFormat string `json:"response_format"`
+}
+
+type Usage struct {
+    PromptTokens int `json:"prompt_tokens"`
+    CompletionTokens int `json:"completion_tokens"`
+    TotalTokens int `json:"total_tokens"`
+}
+
+type Error struct {
+    Message string `json:"message"`
+    Type string `json:"type"`
+    Param string `json:"param"`
+    Code any `json:"code"`
+}
+
+type ErrorWithStatusCode struct {
+    Error
+    StatusCode int `json:"status_code"`
+}
+
+type SlimTextResponse struct {
+    Choices []TextResponseChoice `json:"choices"`
+    Usage `json:"usage"`
+    Error Error `json:"error"`
+}
+
+type TextResponseChoice struct {
+    Index int `json:"index"`
+    Message `json:"message"`
+    FinishReason string `json:"finish_reason"`
+}
+
+type TextResponse struct {
+    Id string `json:"id"`
+    Model string `json:"model,omitempty"`
+    Object string `json:"object"`
+    Created int64 `json:"created"`
+    Choices []TextResponseChoice `json:"choices"`
+    Usage `json:"usage"`
+}
+
+type EmbeddingResponseItem struct {
+    Object string `json:"object"`
+    Index int `json:"index"`
+    Embedding []float64 `json:"embedding"`
+}
+
+type EmbeddingResponse struct {
+    Object string `json:"object"`
+    Data []EmbeddingResponseItem `json:"data"`
+    Model string `json:"model"`
+    Usage `json:"usage"`
+}
+
+type ImageResponse struct {
+    Created int `json:"created"`
+    Data []struct {
+        Url string `json:"url"`
+    }
+}
+
+type ChatCompletionsStreamResponseChoice struct {
+    Delta struct {
+        Content string `json:"content"`
+    } `json:"delta"`
+    FinishReason *string `json:"finish_reason,omitempty"`
+}
+
+type ChatCompletionsStreamResponse struct {
+    Id string `json:"id"`
+    Object string `json:"object"`
+    Created int64 `json:"created"`
+    Model string `json:"model"`
+    Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+type CompletionsStreamResponse struct {
+    Choices []struct {
+        Text string `json:"text"`
+        FinishReason string `json:"finish_reason"`
+    } `json:"choices"`
+}
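The Message helpers above accept either a plain string or the array-of-parts content format. A small usage sketch showing how the accessors differ on multimodal input; the URL and prompt text are made up:

```go
// Demo of Message content accessors; values are illustrative only.
package main

import (
	"fmt"

	"one-api/relay/channel/openai"
)

func main() {
	msg := openai.Message{
		Role: "user",
		Content: []any{
			map[string]any{"type": "text", "text": "What is in this image?"},
			map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/cat.png"}},
		},
	}
	fmt.Println(msg.IsStringContent())   // false: content is a parts array
	fmt.Println(msg.StringContent())     // "What is in this image?" (image parts dropped)
	fmt.Println(len(msg.ParseContent())) // 2: one text part, one image_url part
}
```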
diff --git a/controller/relay-utils.go b/relay/channel/openai/token.go
similarity index 54%
rename from controller/relay-utils.go
rename to relay/channel/openai/token.go
index a6a1f0f6..4b40b228 100644
--- a/controller/relay-utils.go
+++ b/relay/channel/openai/token.go
@@ -1,25 +1,15 @@
-package controller
+package openai

 import (
-    "context"
-    "encoding/json"
     "errors"
     "fmt"
-    "io"
+    "github.com/pkoukk/tiktoken-go"
     "math"
-    "net/http"
     "one-api/common"
     "one-api/common/image"
-    "one-api/model"
-    "strconv"
     "strings"
-
-    "github.com/gin-gonic/gin"
-    "github.com/pkoukk/tiktoken-go"
 )

-var stopFinishReason = "stop"
-
 // tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
@@ -71,7 +61,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
     return len(tokenEncoder.Encode(text, nil, nil))
 }

-func countTokenMessages(messages []Message, model string) int {
+func CountTokenMessages(messages []Message, model string) int {
     tokenEncoder := getTokenEncoder(model)
     // Reference:
     // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -195,191 +185,21 @@ func countImageTokens(url string, detail string) (_ int, err error) {
     }
 }

-func countTokenInput(input any, model string) int {
+func CountTokenInput(input any, model string) int {
     switch v := input.(type) {
     case string:
-        return countTokenText(v, model)
+        return CountTokenText(v, model)
     case []string:
         text := ""
         for _, s := range v {
             text += s
         }
-        return countTokenText(text, model)
+        return CountTokenText(text, model)
     }
     return 0
 }

-func countTokenText(text string, model string) int {
+func CountTokenText(text string, model string) int {
     tokenEncoder := getTokenEncoder(model)
     return getTokenNum(tokenEncoder, text)
 }
-
-func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
-    openAIError := OpenAIError{
-        Message: err.Error(),
-        Type: "one_api_error",
-        Code: code,
-    }
-    return &OpenAIErrorWithStatusCode{
-        OpenAIError: openAIError,
-        StatusCode: statusCode,
-    }
-}
-
-func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
-    if !common.AutomaticDisableChannelEnabled {
-        return false
-    }
-    if err == nil {
-        return false
-    }
-    if statusCode == http.StatusUnauthorized {
-        return true
-    }
-    if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
-        return true
-    }
-    return false
-}
-
-func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
-    if !common.AutomaticEnableChannelEnabled {
-        return false
-    }
-    if err != nil {
-        return false
-    }
-    if openAIErr != nil {
-        return false
-    }
-    return true
-}
-
-func setEventStreamHeaders(c *gin.Context) {
-    c.Writer.Header().Set("Content-Type", "text/event-stream")
-    c.Writer.Header().Set("Cache-Control", "no-cache")
-    c.Writer.Header().Set("Connection", "keep-alive")
-    c.Writer.Header().Set("Transfer-Encoding", "chunked")
-    c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
-
-type GeneralErrorResponse struct {
-    Error OpenAIError `json:"error"`
-    Message string `json:"message"`
-    Msg string `json:"msg"`
-    Err string `json:"err"`
-    ErrorMsg string `json:"error_msg"`
-    Header struct {
-        Message string `json:"message"`
-    } `json:"header"`
-    Response struct {
-        Error struct {
-            Message string `json:"message"`
-        } `json:"error"`
-    } `json:"response"`
-}
-
-func (e GeneralErrorResponse) ToMessage() string {
-    if e.Error.Message != "" {
-        return e.Error.Message
-    }
-    if e.Message != "" {
-        return e.Message
-    }
-    if e.Msg != "" {
-        return e.Msg
-    }
-    if e.Err != "" {
-        return e.Err
-    }
-    if e.ErrorMsg != "" {
-        return e.ErrorMsg
-    }
-    if e.Header.Message != "" {
-        return e.Header.Message
-    }
-    if e.Response.Error.Message != "" {
-        return e.Response.Error.Message
-    }
-    return ""
-}
-
-func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
-    openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
-        StatusCode: resp.StatusCode,
-        OpenAIError: OpenAIError{
-            Message: "",
-            Type: "upstream_error",
-            Code: "bad_response_status_code",
-            Param: strconv.Itoa(resp.StatusCode),
-        },
-    }
-    responseBody, err := io.ReadAll(resp.Body)
-    if err != nil {
-        return
-    }
-    err = resp.Body.Close()
-    if err != nil {
-        return
-    }
-    var errResponse GeneralErrorResponse
-    err = json.Unmarshal(responseBody, &errResponse)
-    if err != nil {
-        return
-    }
-    if errResponse.Error.Message != "" {
-        // OpenAI format error, so we override the default one
-        openAIErrorWithStatusCode.OpenAIError = errResponse.Error
-    } else {
-        openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage()
-    }
-    if openAIErrorWithStatusCode.OpenAIError.Message == "" {
-        openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
-    }
-    return
-}
-
-func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
-    fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
-    if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-        switch channelType {
-        case common.ChannelTypeOpenAI:
-            fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-        case common.ChannelTypeAzure:
-            fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
-        }
-    }
-    return fullRequestURL
-}
-
-func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
-    // quotaDelta is remaining quota to be consumed
-    err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
-    if err != nil {
-        common.SysError("error consuming token remain quota: " + err.Error())
-    }
-    err = model.CacheUpdateUserQuota(userId)
-    if err != nil {
-        common.SysError("error update user quota cache: " + err.Error())
-    }
-    // totalQuota is total quota consumed
-    if totalQuota != 0 {
-        logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-        model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
-        model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
-        model.UpdateChannelUsedQuota(channelId, totalQuota)
-    }
-    if totalQuota <= 0 {
-        common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
-    }
-}
-
-func GetAPIVersion(c *gin.Context) string {
-    query := c.Request.URL.Query()
-    apiVersion := query.Get("api-version")
-    if apiVersion == "" {
-        apiVersion = c.GetString("api_version")
-    }
-    return apiVersion
-}
diff --git a/relay/channel/openai/util.go b/relay/channel/openai/util.go
new file mode 100644
index 00000000..69ece6b3
--- /dev/null
+++ b/relay/channel/openai/util.go
@@ -0,0 +1,13 @@
+package openai
+
+func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
+    openAIError := Error{
+        Message: err.Error(),
+        Type: "one_api_error",
+        Code: code,
+    }
+    return &ErrorWithStatusCode{
+        Error: openAIError,
+        StatusCode: statusCode,
+    }
+}
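ErrorWrapper is the single constructor the relay code now uses for internal failures. A sketch of the shape a wrapped error takes: the embedded Error is what gets serialized under the "error" key for the client, while StatusCode travels out of band to the HTTP layer. The "boom" error is made up for illustration.

```go
// Sketch of what ErrorWrapper hands back to the relay layer.
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"

	"one-api/relay/channel/openai"
)

func main() {
	wrapped := openai.ErrorWrapper(errors.New("boom"), "close_response_body_failed", http.StatusInternalServerError)
	body, _ := json.Marshal(map[string]any{"error": wrapped.Error})
	fmt.Println(wrapped.StatusCode, string(body))
	// 500 {"error":{"message":"boom","type":"one_api_error","param":"","code":"close_response_body_failed"}}
}
```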
diff --git a/controller/relay-tencent.go b/relay/channel/tencent/main.go
similarity index 51%
rename from controller/relay-tencent.go
rename to relay/channel/tencent/main.go
index 5930ae89..60e275a9 100644
--- a/controller/relay-tencent.go
+++ b/relay/channel/tencent/main.go
@@ -1,4 +1,4 @@
-package controller
+package tencent

 import (
     "bufio"
@@ -12,6 +12,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/relay/channel/openai"
+    "one-api/relay/constant"
     "sort"
     "strconv"
     "strings"
@@ -19,80 +21,22 @@ import (

 // https://cloud.tencent.com/document/product/1729/97732

-type TencentMessage struct {
-    Role string `json:"role"`
-    Content string `json:"content"`
-}
-
-type TencentChatRequest struct {
-    AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
-    SecretId string `json:"secret_id"` // 官网 SecretId
-    // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
-    // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
-    Timestamp int64 `json:"timestamp"`
-    // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
-    // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
-    Expired int64 `json:"expired"`
-    QueryID string `json:"query_id"` //请求 Id,用于问题排查
-    // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
-    // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
-    // 建议该参数和 top_p 只设置1个,不要同时更改 top_p
-    Temperature float64 `json:"temperature"`
-    // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
-    // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
-    // 建议该参数和 temperature 只设置1个,不要同时更改
-    TopP float64 `json:"top_p"`
-    // Stream 0:同步,1:流式 (默认,协议:SSE)
-    // 同步请求超时:60s,如果内容较长建议使用流式
-    Stream int `json:"stream"`
-    // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
-    // 输入 content 总数最大支持 3000 token。
-    Messages []TencentMessage `json:"messages"`
-}
-
-type TencentError struct {
-    Code int `json:"code"`
-    Message string `json:"message"`
-}
-
-type TencentUsage struct {
-    InputTokens int `json:"input_tokens"`
-    OutputTokens int `json:"output_tokens"`
-    TotalTokens int `json:"total_tokens"`
-}
-
-type TencentResponseChoices struct {
-    FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
-    Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
-    Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
-}
-
-type TencentChatResponse struct {
-    Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
-    Created string `json:"created,omitempty"` // unix 时间戳的字符串
-    Id string `json:"id,omitempty"` // 会话 id
-    Usage Usage `json:"usage,omitempty"` // token 数量
-    Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
-    Note string `json:"note,omitempty"` // 注释
-    ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
-}
-
-func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
-    messages := make([]TencentMessage, 0, len(request.Messages))
+func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+    messages := make([]Message, 0, len(request.Messages))
     for i := 0; i < len(request.Messages); i++ {
         message := request.Messages[i]
         if message.Role == "system" {
-            messages = append(messages, TencentMessage{
+            messages = append(messages, Message{
                 Role: "user",
                 Content: message.StringContent(),
             })
-            messages = append(messages, TencentMessage{
+            messages = append(messages, Message{
                 Role: "assistant",
                 Content: "Okay",
             })
             continue
         }
-        messages = append(messages, TencentMessage{
+        messages = append(messages, Message{
             Content: message.StringContent(),
             Role: message.Role,
         })
@@ -101,7 +45,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
     if request.Stream {
         stream = 1
     }
-    return &TencentChatRequest{
+    return &ChatRequest{
         Timestamp: common.GetTimestamp(),
         Expired: common.GetTimestamp() + 24*60*60,
         QueryID: common.GetUUID(),
@@ -112,16 +56,16 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
     }
 }

-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
-    fullTextResponse := OpenAITextResponse{
+func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
+    fullTextResponse := openai.TextResponse{
         Object: "chat.completion",
         Created: common.GetTimestamp(),
         Usage: response.Usage,
     }
     if len(response.Choices) > 0 {
-        choice := OpenAITextResponseChoice{
+        choice := openai.TextResponseChoice{
             Index: 0,
-            Message: Message{
+            Message: openai.Message{
                 Role: "assistant",
                 Content: response.Choices[0].Messages.Content,
             },
@@ -132,24 +76,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
     return &fullTextResponse
 }

-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
-    response := ChatCompletionsStreamResponse{
+func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+    response := openai.ChatCompletionsStreamResponse{
         Object: "chat.completion.chunk",
         Created: common.GetTimestamp(),
         Model: "tencent-hunyuan",
     }
     if len(TencentResponse.Choices) > 0 {
-        var choice ChatCompletionsStreamResponseChoice
+        var choice openai.ChatCompletionsStreamResponseChoice
         choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
         if TencentResponse.Choices[0].FinishReason == "stop" {
-            choice.FinishReason = &stopFinishReason
+            choice.FinishReason = &constant.StopFinishReason
         }
         response.Choices = append(response.Choices, choice)
     }
     return &response
 }

-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     var responseText string
     scanner := bufio.NewScanner(resp.Body)
     scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -180,11 +124,11 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
         }
         stopChan <- true
     }()
-    setEventStreamHeaders(c)
+    common.SetEventStreamHeaders(c)
     c.Stream(func(w io.Writer) bool {
         select {
         case data := <-dataChan:
-            var TencentResponse TencentChatResponse
+            var TencentResponse ChatResponse
             err := json.Unmarshal([]byte(data), &TencentResponse)
             if err != nil {
                 common.SysError("error unmarshalling stream response: " + err.Error())
@@ -208,28 +152,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
     })
     err := resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
     return nil, responseText
 }

-func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-    var TencentResponse TencentChatResponse
+func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+    var TencentResponse ChatResponse
     responseBody, err := io.ReadAll(resp.Body)
     if err != nil {
-        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
     }
     err = resp.Body.Close()
     if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
     err = json.Unmarshal(responseBody, &TencentResponse)
     if err != nil {
-        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
     }
     if TencentResponse.Error.Code != 0 {
-        return &OpenAIErrorWithStatusCode{
-            OpenAIError: OpenAIError{
+        return &openai.ErrorWithStatusCode{
+            Error: openai.Error{
                 Message: TencentResponse.Error.Message,
                 Code: TencentResponse.Error.Code,
             },
@@ -240,7 +184,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
     fullTextResponse.Model = "hunyuan"
     jsonResponse, err := json.Marshal(fullTextResponse)
     if err != nil {
-        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
     }
     c.Writer.Header().Set("Content-Type", "application/json")
     c.Writer.WriteHeader(resp.StatusCode)
@@ -248,7 +192,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
     return nil, &fullTextResponse.Usage
 }

-func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
+func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
     parts := strings.Split(config, "|")
     if len(parts) != 3 {
         err = errors.New("invalid tencent config")
@@ -260,7 +204,7 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
     return
 }

-func getTencentSign(req TencentChatRequest, secretKey string) string {
+func GetSign(req ChatRequest, secretKey string) string {
     params := make([]string, 0)
     params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
     params = append(params, "secret_id="+req.SecretId)
diff --git a/relay/channel/tencent/model.go b/relay/channel/tencent/model.go
new file mode 100644
index 00000000..511f3d97
--- /dev/null
+++ b/relay/channel/tencent/model.go
@@ -0,0 +1,63 @@
+package tencent
+
+import (
+    "one-api/relay/channel/openai"
+)
+
+type Message struct {
+    Role string `json:"role"`
+    Content string `json:"content"`
+}
+
+type ChatRequest struct {
+    AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
+    SecretId string `json:"secret_id"` // SecretId from the console
+    // Timestamp is the current UNIX timestamp in seconds, recording when the
+    // API request was made, e.g. 1529223702; a large skew from the current
+    // time causes a signature-expired error.
+    Timestamp int64 `json:"timestamp"`
+    // Expired is the signature expiry, a UNIX Epoch timestamp in seconds;
+    // it must be greater than Timestamp, and Expired-Timestamp must be
+    // less than 90 days.
+    Expired int64 `json:"expired"`
+    QueryID string `json:"query_id"` // request id, used for troubleshooting
+    // Temperature: higher values make output more random, lower values more
+    // focused and deterministic. Default 1.0, range [0.0,2.0]; avoid unless
+    // necessary, and prefer tuning either this or top_p, not both.
+    Temperature float64 `json:"temperature"`
+    // TopP controls output diversity: the larger the value, the more diverse
+    // the generated text. Default 1.0, range [0.0, 1.0]; avoid unless
+    // necessary, and prefer tuning either this or temperature, not both.
+    TopP float64 `json:"top_p"`
+    // Stream: 0 for synchronous, 1 for streaming (the default, via SSE).
+    // Synchronous requests time out after 60s; prefer streaming for long content.
+    Stream int `json:"stream"`
+    // Messages is the conversation, at most 40 entries ordered from oldest
+    // to newest; the combined content supports at most 3000 tokens.
+    Messages []Message `json:"messages"`
+}
+
+type Error struct {
+    Code int `json:"code"`
+    Message string `json:"message"`
+}
+
+type Usage struct {
+    InputTokens int `json:"input_tokens"`
+    OutputTokens int `json:"output_tokens"`
+    TotalTokens int `json:"total_tokens"`
+}
+
+type ResponseChoices struct {
+    FinishReason string `json:"finish_reason,omitempty"` // stream end marker; "stop" marks the final chunk
+    Messages Message `json:"messages,omitempty"` // content in synchronous mode, null in stream mode; at most 1024 tokens
+    Delta Message `json:"delta,omitempty"` // content in stream mode, null in synchronous mode; at most 1024 tokens
+}
+
+type ChatResponse struct {
+    Choices []ResponseChoices `json:"choices,omitempty"` // results
+    Created string `json:"created,omitempty"` // unix timestamp as a string
+    Id string `json:"id,omitempty"` // session id
+    Usage openai.Usage `json:"usage,omitempty"` // token counts
+    Error Error `json:"error,omitempty"` // error info; note: this field may be null
+    Note string `json:"note,omitempty"` // remark
+    ReqID string `json:"req_id,omitempty"` // unique request id, returned with every request, for reporting issues
+}
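A Tencent channel key is stored as a single "appId|secretId|secretKey" string; ParseConfig splits that triple before GetSign computes the request signature. A usage sketch with made-up credentials:

```go
// Usage sketch for tencent.ParseConfig; credential values are placeholders.
package main

import (
	"fmt"

	"one-api/relay/channel/tencent"
)

func main() {
	appId, secretId, secretKey, err := tencent.ParseConfig("1300000000|AKIDexample|examplesecret")
	if err != nil {
		panic(err)
	}
	fmt.Println(appId, secretId, secretKey) // 1300000000 AKIDexample examplesecret
}
```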
diff --git a/controller/relay-xunfei.go b/relay/channel/xunfei/main.go
similarity index 62%
rename from controller/relay-xunfei.go
rename to relay/channel/xunfei/main.go
index 904e6d14..1cc0b664 100644
--- a/controller/relay-xunfei.go
+++ b/relay/channel/xunfei/main.go
@@ -1,4 +1,4 @@
-package controller
+package xunfei

 import (
     "crypto/hmac"
@@ -12,6 +12,8 @@ import (
     "net/http"
     "net/url"
     "one-api/common"
+    "one-api/relay/channel/openai"
+    "one-api/relay/constant"
     "strings"
     "time"
 )
@@ -19,82 +21,26 @@ import (
 // https://console.xfyun.cn/services/cbm
 // https://www.xfyun.cn/doc/spark/Web.html

-type XunfeiMessage struct {
-    Role string `json:"role"`
-    Content string `json:"content"`
-}
-
-type XunfeiChatRequest struct {
-    Header struct {
-        AppId string `json:"app_id"`
-    } `json:"header"`
-    Parameter struct {
-        Chat struct {
-            Domain string `json:"domain,omitempty"`
-            Temperature float64 `json:"temperature,omitempty"`
-            TopK int `json:"top_k,omitempty"`
-            MaxTokens int `json:"max_tokens,omitempty"`
-            Auditing bool `json:"auditing,omitempty"`
-        } `json:"chat"`
-    } `json:"parameter"`
-    Payload struct {
-        Message struct {
-            Text []XunfeiMessage `json:"text"`
-        } `json:"message"`
-    } `json:"payload"`
-}
-
-type XunfeiChatResponseTextItem struct {
-    Content string `json:"content"`
-    Role string `json:"role"`
-    Index int `json:"index"`
-}
-
-type XunfeiChatResponse struct {
-    Header struct {
-        Code int `json:"code"`
-        Message string `json:"message"`
-        Sid string `json:"sid"`
-        Status int `json:"status"`
-    } `json:"header"`
-    Payload struct {
-        Choices struct {
-            Status int `json:"status"`
-            Seq int `json:"seq"`
-            Text []XunfeiChatResponseTextItem `json:"text"`
-        } `json:"choices"`
-        Usage struct {
-            //Text struct {
-            //    QuestionTokens string `json:"question_tokens"`
-            //    PromptTokens string `json:"prompt_tokens"`
-            //    CompletionTokens string `json:"completion_tokens"`
-            //    TotalTokens string `json:"total_tokens"`
-            //} `json:"text"`
-            Text Usage `json:"text"`
-        } `json:"usage"`
-    } `json:"payload"`
-}
-
-func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
-    messages := make([]XunfeiMessage, 0, len(request.Messages))
+func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
+    messages := make([]Message, 0, len(request.Messages))
     for _, message := range request.Messages {
         if message.Role == "system" {
-            messages = append(messages, XunfeiMessage{
+            messages = append(messages, Message{
                 Role: "user",
                 Content: message.StringContent(),
             })
-            messages = append(messages, XunfeiMessage{
+            messages = append(messages, Message{
                 Role: "assistant",
                 Content: "Okay",
             })
         } else {
-            messages = append(messages, XunfeiMessage{
+            messages = append(messages, Message{
                 Role: message.Role,
                 Content: message.StringContent(),
             })
         }
     }
-    xunfeiRequest := XunfeiChatRequest{}
+    xunfeiRequest := ChatRequest{}
     xunfeiRequest.Header.AppId = xunfeiAppId
     xunfeiRequest.Parameter.Chat.Domain = domain
     xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
@@ -104,49 +50,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
     return &xunfeiRequest
 }

-func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
+func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
     if len(response.Payload.Choices.Text) == 0 {
-        response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+        response.Payload.Choices.Text = []ChatResponseTextItem{
             {
                 Content: "",
             },
         }
     }
-    choice := OpenAITextResponseChoice{
+    choice := openai.TextResponseChoice{
         Index: 0,
-        Message: Message{
+        Message: openai.Message{
             Role: "assistant",
             Content: response.Payload.Choices.Text[0].Content,
         },
-        FinishReason: stopFinishReason,
+        FinishReason: constant.StopFinishReason,
     }
-    fullTextResponse := OpenAITextResponse{
+    fullTextResponse := openai.TextResponse{
         Object: "chat.completion",
         Created: common.GetTimestamp(),
-        Choices: []OpenAITextResponseChoice{choice},
+        Choices: []openai.TextResponseChoice{choice},
         Usage: response.Payload.Usage.Text,
     }
     return &fullTextResponse
 }

-func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
+func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
     if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-        xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+        xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
             {
                 Content: "",
             },
         }
     }
-    var choice ChatCompletionsStreamResponseChoice
+    var choice openai.ChatCompletionsStreamResponseChoice
     choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
     if xunfeiResponse.Payload.Choices.Status == 2 {
-        choice.FinishReason = &stopFinishReason
+        choice.FinishReason = &constant.StopFinishReason
     }
-    response := ChatCompletionsStreamResponse{
+    response := openai.ChatCompletionsStreamResponse{
         Object: "chat.completion.chunk",
         Created: common.GetTimestamp(),
         Model: "SparkDesk",
-        Choices: []ChatCompletionsStreamResponseChoice{choice},
+        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }
     return &response
 }
@@ -177,14 +123,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
     return callUrl
 }

-func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
     domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
     dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
     if err != nil {
-        return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "make_xunfei_request_failed", http.StatusInternalServerError), nil
     }
-    setEventStreamHeaders(c)
-    var usage Usage
+    common.SetEventStreamHeaders(c)
+    var usage openai.Usage
     c.Stream(func(w io.Writer) bool {
         select {
         case xunfeiResponse := <-dataChan:
@@ -207,15 +153,15 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
     return nil, &usage
 }

-func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
     domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
     dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
     if err != nil {
-        return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "make_xunfei_request_failed", http.StatusInternalServerError), nil
     }
-    var usage Usage
+    var usage openai.Usage
     var content string
-    var xunfeiResponse XunfeiChatResponse
+    var xunfeiResponse ChatResponse
     stop := false
     for !stop {
         select {
@@ -231,7 +177,7 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
         }
     }
     if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-        xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+        xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
             {
                 Content: "",
             },
@@ -242,14 +188,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
     response := responseXunfei2OpenAI(&xunfeiResponse)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
     }
     c.Writer.Header().Set("Content-Type", "application/json")
     _, _ = c.Writer.Write(jsonResponse)
     return nil, &usage
 }

-func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
     d := websocket.Dialer{
         HandshakeTimeout: 5 * time.Second,
     }
@@ -263,7 +209,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
         return nil, nil, err
     }
:= make(chan XunfeiChatResponse) + dataChan := make(chan ChatResponse) stopChan := make(chan bool) go func() { for { @@ -272,7 +218,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId common.SysError("error reading stream response: " + err.Error()) break } - var response XunfeiChatResponse + var response ChatResponse err = json.Unmarshal(msg, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go new file mode 100644 index 00000000..0ca42818 --- /dev/null +++ b/relay/channel/xunfei/model.go @@ -0,0 +1,61 @@ +package xunfei + +import ( + "one-api/relay/channel/openai" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatRequest struct { + Header struct { + AppId string `json:"app_id"` + } `json:"header"` + Parameter struct { + Chat struct { + Domain string `json:"domain,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` + } `json:"chat"` + } `json:"parameter"` + Payload struct { + Message struct { + Text []Message `json:"text"` + } `json:"message"` + } `json:"payload"` +} + +type ChatResponseTextItem struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +type ChatResponse struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []ChatResponseTextItem `json:"text"` + } `json:"choices"` + Usage struct { + //Text struct { + // QuestionTokens string `json:"question_tokens"` + // PromptTokens string `json:"prompt_tokens"` + // CompletionTokens string `json:"completion_tokens"` + // TotalTokens string `json:"total_tokens"` + //} `json:"text"` + Text openai.Usage `json:"text"` + } `json:"usage"` + } `json:"payload"` +} diff --git a/controller/relay-zhipu.go b/relay/channel/zhipu/main.go similarity index 62% rename from controller/relay-zhipu.go rename to relay/channel/zhipu/main.go index cb5a78cf..3dc613a4 100644 --- a/controller/relay-zhipu.go +++ b/relay/channel/zhipu/main.go @@ -1,4 +1,4 @@ -package controller +package zhipu import ( "bufio" @@ -8,6 +8,8 @@ import ( "io" "net/http" "one-api/common" + "one-api/relay/channel/openai" + "one-api/relay/constant" "strings" "sync" "time" @@ -18,53 +20,13 @@ import ( // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke -type ZhipuMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ZhipuRequest struct { - Prompt []ZhipuMessage `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - RequestId string `json:"request_id,omitempty"` - Incremental bool `json:"incremental,omitempty"` -} - -type ZhipuResponseData struct { - TaskId string `json:"task_id"` - RequestId string `json:"request_id"` - TaskStatus string `json:"task_status"` - Choices []ZhipuMessage `json:"choices"` - Usage `json:"usage"` -} - -type ZhipuResponse struct { - Code int `json:"code"` - Msg string `json:"msg"` - Success bool `json:"success"` - Data ZhipuResponseData `json:"data"` -} - -type ZhipuStreamMetaResponse struct { - 
RequestId string `json:"request_id"` - TaskId string `json:"task_id"` - TaskStatus string `json:"task_status"` - Usage `json:"usage"` -} - -type zhipuTokenData struct { - Token string - ExpiryTime time.Time -} - var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 -func getZhipuToken(apikey string) string { +func GetToken(apikey string) string { data, ok := zhipuTokens.Load(apikey) if ok { - tokenData := data.(zhipuTokenData) + tokenData := data.(tokenData) if time.Now().Before(tokenData.ExpiryTime) { return tokenData.Token } @@ -100,7 +62,7 @@ func getZhipuToken(apikey string) string { return "" } - zhipuTokens.Store(apikey, zhipuTokenData{ + zhipuTokens.Store(apikey, tokenData{ Token: tokenString, ExpiryTime: expiryTime, }) @@ -108,26 +70,26 @@ func getZhipuToken(apikey string) string { return tokenString } -func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { - messages := make([]ZhipuMessage, 0, len(request.Messages)) +func ConvertRequest(request openai.GeneralOpenAIRequest) *Request { + messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, ZhipuMessage{ + messages = append(messages, Message{ Role: "system", Content: message.StringContent(), }) - messages = append(messages, ZhipuMessage{ + messages = append(messages, Message{ Role: "user", Content: "Okay", }) } else { - messages = append(messages, ZhipuMessage{ + messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &ZhipuRequest{ + return &Request{ Prompt: messages, Temperature: request.Temperature, TopP: request.TopP, @@ -135,18 +97,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } } -func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ +func responseZhipu2OpenAI(response *Response) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ Id: response.Data.TaskId, Object: "chat.completion", Created: common.GetTimestamp(), - Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), + Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), Usage: response.Data.Usage, } for i, choice := range response.Data.Choices { - openaiChoice := OpenAITextResponseChoice{ + openaiChoice := openai.TextResponseChoice{ Index: i, - Message: Message{ + Message: openai.Message{ Role: choice.Role, Content: strings.Trim(choice.Content, "\""), }, @@ -160,34 +122,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = zhipuResponse - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { - var choice ChatCompletionsStreamResponseChoice +func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) 
(*openai.ChatCompletionsStreamResponse, *openai.Usage) { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = "" - choice.FinishReason = &stopFinishReason - response := ChatCompletionsStreamResponse{ + choice.FinishReason = &constant.StopFinishReason + response := openai.ChatCompletionsStreamResponse{ Id: zhipuResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response, &zhipuResponse.Usage } -func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage *Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage *openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -224,7 +186,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -237,7 +199,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case data := <-metaChan: - var zhipuResponse ZhipuStreamMetaResponse + var zhipuResponse StreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -259,28 +221,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, usage } -func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var zhipuResponse ZhipuResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var zhipuResponse Response responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if !zhipuResponse.Success { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: zhipuResponse.Msg, Type: "zhipu_error", Param: "", @@ -293,7 +255,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo fullTextResponse.Model = "chatglm" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", 
http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go new file mode 100644 index 00000000..08a5ec5f --- /dev/null +++ b/relay/channel/zhipu/model.go @@ -0,0 +1,46 @@ +package zhipu + +import ( + "one-api/relay/channel/openai" + "time" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Request struct { + Prompt []Message `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []Message `json:"choices"` + openai.Usage `json:"usage"` +} + +type Response struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ResponseData `json:"data"` +} + +type StreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + openai.Usage `json:"usage"` +} + +type tokenData struct { + Token string + ExpiryTime time.Time +} diff --git a/relay/constant/main.go b/relay/constant/main.go new file mode 100644 index 00000000..b3aeaaff --- /dev/null +++ b/relay/constant/main.go @@ -0,0 +1,16 @@ +package constant + +const ( + RelayModeUnknown = iota + RelayModeChatCompletions + RelayModeCompletions + RelayModeEmbeddings + RelayModeModerations + RelayModeImagesGenerations + RelayModeEdits + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation +) + +var StopFinishReason = "stop" diff --git a/controller/relay-audio.go b/relay/controller/audio.go similarity index 65% rename from controller/relay-audio.go rename to relay/controller/audio.go index 2247f4c7..08d9af2a 100644 --- a/controller/relay-audio.go +++ b/relay/controller/audio.go @@ -12,10 +12,13 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/util" "strings" ) -func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { audioModel := "whisper-1" tokenId := c.GetInt("token_id") @@ -25,18 +28,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode group := c.GetString("group") tokenName := c.GetString("token_name") - var ttsRequest TextToSpeechRequest - if relayMode == RelayModeAudioSpeech { + var ttsRequest openai.TextToSpeechRequest + if relayMode == constant.RelayModeAudioSpeech { // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid if err != nil { - return errorWrapper(err, "invalid_json", http.StatusBadRequest) + return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest) } audioModel = ttsRequest.Model // Check if text is too long 4096 if len(ttsRequest.Input) > 4096 { - return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) } } @@ -46,7 +49,7 @@ func 
relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var quota int var preConsumedQuota int switch relayMode { - case RelayModeAudioSpeech: + case constant.RelayModeAudioSpeech: preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: @@ -54,16 +57,16 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } userQuota, err := model.CacheGetUserQuota(userId) if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } // Check if user quota is enough if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // in this case, we do not pre-consume quota @@ -73,7 +76,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } @@ -83,7 +86,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[audioModel] != "" { audioModel = modelMap[audioModel] @@ -96,27 +99,27 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) + if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := GetAPIVersion(c) + apiVersion := util.GetAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } requestBody := &bytes.Buffer{} _, err = io.Copy(requestBody, c.Request.Body) if err != nil { - return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_failed", 
http.StatusInternalServerError) } - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -128,34 +131,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != RelayModeAudioSpeech { + if relayMode != constant.RelayModeAudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var openAIErr TextResponse + var openAIErr openai.SlimTextResponse if err = json.Unmarshal(responseBody, &openAIErr); err == nil { if openAIErr.Error.Message != "" { - return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) } } @@ -172,12 +175,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode case "vtt": text, err = getTextFromVTT(responseBody) default: - return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) } if err != nil { - return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - quota = countTokenText(text, audioModel) + quota = openai.CountTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -193,11 +196,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode }() }(c.Request.Context()) } - return relayErrorHandler(resp) + return util.RelayErrorHandler(resp) } quotaDelta := quota - preConsumedQuota defer func(ctx 
context.Context) { - go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) }(c.Request.Context()) for k, v := range resp.Header { @@ -207,11 +210,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } return nil } @@ -221,7 +224,7 @@ func getTextFromVTT(body []byte) (string, error) { } func getTextFromVerboseJSON(body []byte) (string, error) { - var whisperResponse WhisperVerboseJSONResponse + var whisperResponse openai.WhisperVerboseJSONResponse if err := json.Unmarshal(body, &whisperResponse); err != nil { return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) } @@ -254,7 +257,7 @@ func getTextFromText(body []byte) (string, error) { } func getTextFromJSON(body []byte) (string, error) { - var whisperResponse WhisperJSONResponse + var whisperResponse openai.WhisperJSONResponse if err := json.Unmarshal(body, &whisperResponse); err != nil { return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) } diff --git a/controller/relay-image.go b/relay/controller/image.go similarity index 71% rename from controller/relay-image.go rename to relay/controller/image.go index 14a2983b..be5fc3dd 100644 --- a/controller/relay-image.go +++ b/relay/controller/image.go @@ -10,6 +10,8 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/util" "strings" "github.com/gin-gonic/gin" @@ -25,7 +27,7 @@ func isWithinRange(element string, value int) bool { return value >= min && value <= max } -func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { imageModel := "dall-e-2" imageSize := "1024x1024" @@ -35,10 +37,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode userId := c.GetInt("id") group := c.GetString("group") - var imageRequest ImageRequest + var imageRequest openai.ImageRequest err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if imageRequest.N == 0 { @@ -67,24 +69,24 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } } else { - return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } // Check prompt length if len(imageRequest.Prompt) > 
common.DalleImagePromptLengthLimitations[imageModel] { - return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation if isWithinRange(imageModel, imageRequest.N) == false { // channel not azure if channelType != common.ChannelTypeAzure { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } } @@ -95,7 +97,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[imageModel] != "" { imageModel = modelMap[imageModel] @@ -107,10 +109,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := GetAPIVersion(c) + apiVersion := util.GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) } @@ -119,7 +121,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { @@ -134,12 +136,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*imageCostRatio*1000) * imageRequest.N if userQuota-quota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } token := c.Request.Header.Get("Authorization") if channelType == common.ChannelTypeAzure { // Azure authentication @@ -152,20 +154,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { - return errorWrapper(err, 
"do_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var textResponse ImageResponse + var textResponse openai.ImageResponse defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { @@ -192,15 +194,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) @@ -212,11 +214,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } return nil } diff --git a/controller/relay-text.go b/relay/controller/text.go similarity index 68% rename from controller/relay-text.go rename to relay/controller/text.go index 64338545..b17ff950 100644 --- a/controller/relay-text.go +++ b/relay/controller/text.go @@ -6,15 +6,24 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "math" "net/http" "one-api/common" "one-api/model" + "one-api/relay/channel/aiproxy" + "one-api/relay/channel/ali" + "one-api/relay/channel/anthropic" + "one-api/relay/channel/baidu" + "one-api/relay/channel/google" + "one-api/relay/channel/openai" + "one-api/relay/channel/tencent" + "one-api/relay/channel/xunfei" + "one-api/relay/channel/zhipu" + "one-api/relay/constant" + "one-api/relay/util" "strings" - "time" - - "github.com/gin-gonic/gin" ) const ( @@ -30,64 +39,47 @@ const ( APITypeGemini ) -var httpClient *http.Client -var impatientHTTPClient *http.Client - -func init() { - if common.RelayTimeout == 0 { - httpClient = &http.Client{} - } else { - httpClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, - } - } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } -} - -func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { 
channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") group := c.GetString("group") - var textRequest GeneralOpenAIRequest + var textRequest openai.GeneralOpenAIRequest err := common.UnmarshalBodyReusable(c, &textRequest) if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { - return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) } - if relayMode == RelayModeModerations && textRequest.Model == "" { + if relayMode == constant.RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } - if relayMode == RelayModeEmbeddings && textRequest.Model == "" { + if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { textRequest.Model = c.Param("model") } // request validation if textRequest.Model == "" { - return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) } switch relayMode { - case RelayModeCompletions: + case constant.RelayModeCompletions: if textRequest.Prompt == "" { - return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) } - case RelayModeChatCompletions: + case constant.RelayModeChatCompletions: if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) } - case RelayModeEmbeddings: - case RelayModeModerations: + case constant.RelayModeEmbeddings: + case constant.RelayModeModerations: if textRequest.Input == "" { - return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) } - case RelayModeEdits: + case constant.RelayModeEdits: if textRequest.Instruction == "" { - return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) } } // map model name @@ -97,7 +89,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[textRequest.Model] != "" { textRequest.Model = modelMap[textRequest.Model] @@ -130,12 +122,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL 
:= getFullRequestURL(baseURL, requestURL, channelType) + fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) switch apiType { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - apiVersion := GetAPIVersion(c) + apiVersion := util.GetAPIVersion(c) requestURL := strings.Split(requestURL, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) baseURL = c.GetString("base_url") @@ -148,7 +140,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { model_ = strings.TrimSuffix(model_, "-0613") requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) + fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType) } case APITypeClaude: fullRequestURL = "https://api.anthropic.com/v1/complete" @@ -171,8 +163,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") var err error - if apiKey, err = getBaiduAccessToken(apiKey); err != nil { - return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) + if apiKey, err = baidu.GetAccessToken(apiKey); err != nil { + return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) } fullRequestURL += "?access_token=" + apiKey case APITypePaLM: @@ -202,7 +194,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - if relayMode == RelayModeEmbeddings { + if relayMode == constant.RelayModeEmbeddings { fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" } case APITypeTencent: @@ -213,12 +205,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { var promptTokens int var completionTokens int switch relayMode { - case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) - case RelayModeCompletions: - promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModerations: - promptTokens = countTokenInput(textRequest.Input, textRequest.Model) + case constant.RelayModeChatCompletions: + promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model) + case constant.RelayModeCompletions: + promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model) + case constant.RelayModeModerations: + promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model) } preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { @@ -230,14 +222,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { preConsumedQuota := int(float64(preConsumedTokens) * ratio) userQuota, err := model.CacheGetUserQuota(userId) if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), 
"insufficient_user_quota", http.StatusForbidden) + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // in this case, we do not pre-consume quota @@ -248,14 +240,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } var requestBody io.Reader if isModelMapped { jsonStr, err := json.Marshal(textRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { @@ -263,86 +255,86 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } switch apiType { case APITypeClaude: - claudeRequest := requestOpenAI2Claude(textRequest) + claudeRequest := anthropic.ConvertRequest(textRequest) jsonStr, err := json.Marshal(claudeRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) case APITypeBaidu: var jsonData []byte var err error switch relayMode { - case RelayModeEmbeddings: - baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest) jsonData, err = json.Marshal(baiduEmbeddingRequest) default: - baiduRequest := requestOpenAI2Baidu(textRequest) + baiduRequest := baidu.ConvertRequest(textRequest) jsonData, err = json.Marshal(baiduRequest) } if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonData) case APITypePaLM: - palmRequest := requestOpenAI2PaLM(textRequest) + palmRequest := google.ConvertPaLMRequest(textRequest) jsonStr, err := json.Marshal(palmRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) case APITypeGemini: - geminiChatRequest := requestOpenAI2Gemini(textRequest) + geminiChatRequest := google.ConvertGeminiRequest(textRequest) jsonStr, err := json.Marshal(geminiChatRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) case APITypeZhipu: - zhipuRequest := requestOpenAI2Zhipu(textRequest) + zhipuRequest := zhipu.ConvertRequest(textRequest) jsonStr, err := json.Marshal(zhipuRequest) if err != nil { - return 
errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) case APITypeAli: var jsonStr []byte var err error switch relayMode { - case RelayModeEmbeddings: - aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) + case constant.RelayModeEmbeddings: + aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest) jsonStr, err = json.Marshal(aliEmbeddingRequest) default: - aliRequest := requestOpenAI2Ali(textRequest) + aliRequest := ali.ConvertRequest(textRequest) jsonStr, err = json.Marshal(aliRequest) } if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) case APITypeTencent: apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") - appId, secretId, secretKey, err := parseTencentConfig(apiKey) + appId, secretId, secretKey, err := tencent.ParseConfig(apiKey) if err != nil { - return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) } - tencentRequest := requestOpenAI2Tencent(textRequest) + tencentRequest := tencent.ConvertRequest(textRequest) tencentRequest.AppId = appId tencentRequest.SecretId = secretId jsonStr, err := json.Marshal(tencentRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } - sign := getTencentSign(*tencentRequest, secretKey) + sign := tencent.GetSign(*tencentRequest, secretKey) c.Request.Header.Set("Authorization", sign) requestBody = bytes.NewBuffer(jsonStr) case APITypeAIProxyLibrary: - aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) + aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) aiProxyLibraryRequest.LibraryId = c.GetString("library_id") jsonStr, err := json.Marshal(aiProxyLibraryRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } @@ -354,7 +346,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if apiType != APITypeXunfei { // cause xunfei use websocket req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -377,7 +369,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } req.Header.Set("anthropic-version", anthropicVersion) case APITypeZhipu: - token := getZhipuToken(apiKey) + token := zhipu.GetToken(apiKey) req.Header.Set("Authorization", token) case APITypeAli: req.Header.Set("Authorization", "Bearer "+apiKey) @@ -402,17 +394,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { req.Header.Set("Accept", "text/event-stream") } //req.Header.Set("Connection", 
c.Request.Header.Get("Connection")) - resp, err = httpClient.Do(req) + resp, err = util.HTTPClient.Do(req) if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") @@ -426,11 +418,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } }(c.Request.Context()) } - return relayErrorHandler(resp) + return util.RelayErrorHandler(resp) } } - var textResponse TextResponse + var textResponse openai.SlimTextResponse tokenName := c.GetString("token_name") defer func(ctx context.Context) { @@ -471,15 +463,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { switch apiType { case APITypeOpenAI: if isStream { - err, responseText := openaiStreamHandler(c, resp, relayMode) + err, responseText := openai.StreamHandler(c, resp, relayMode) if err != nil { return err } textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) return nil } else { - err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) + err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } @@ -490,15 +482,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeClaude: if isStream { - err, responseText := claudeStreamHandler(c, resp) + err, responseText := anthropic.StreamHandler(c, resp) if err != nil { return err } textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) return nil } else { - err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) + err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } @@ -509,7 +501,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeBaidu: if isStream { - err, usage := baiduStreamHandler(c, resp) + err, usage := baidu.StreamHandler(c, resp) if err != nil { return err } @@ -518,13 +510,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage + var err *openai.ErrorWithStatusCode + var usage *openai.Usage switch relayMode { - case RelayModeEmbeddings: - err, usage = baiduEmbeddingHandler(c, resp) + case constant.RelayModeEmbeddings: + err, usage = baidu.EmbeddingHandler(c, resp) default: - err, usage = baiduHandler(c, resp) + err, usage = baidu.Handler(c, resp) } if err != nil { return err @@ -536,15 +528,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypePaLM: if 
textRequest.Stream { // PaLM2 API does not support stream - err, responseText := palmStreamHandler(c, resp) + err, responseText := google.PaLMStreamHandler(c, resp) if err != nil { return err } textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) return nil } else { - err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) + err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } @@ -555,15 +547,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeGemini: if textRequest.Stream { - err, responseText := geminiChatStreamHandler(c, resp) + err, responseText := google.StreamHandler(c, resp) if err != nil { return err } textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) return nil } else { - err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) + err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } @@ -574,7 +566,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeZhipu: if isStream { - err, usage := zhipuStreamHandler(c, resp) + err, usage := zhipu.StreamHandler(c, resp) if err != nil { return err } @@ -585,7 +577,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens return nil } else { - err, usage := zhipuHandler(c, resp) + err, usage := zhipu.Handler(c, resp) if err != nil { return err } @@ -598,7 +590,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeAli: if isStream { - err, usage := aliStreamHandler(c, resp) + err, usage := ali.StreamHandler(c, resp) if err != nil { return err } @@ -607,13 +599,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage + var err *openai.ErrorWithStatusCode + var usage *openai.Usage switch relayMode { - case RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) + case constant.RelayModeEmbeddings: + err, usage = ali.EmbeddingHandler(c, resp) default: - err, usage = aliHandler(c, resp) + err, usage = ali.Handler(c, resp) } if err != nil { return err @@ -628,14 +620,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { auth = strings.TrimPrefix(auth, "Bearer ") splits := strings.Split(auth, "|") if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) } - var err *OpenAIErrorWithStatusCode - var usage *Usage + var err *openai.ErrorWithStatusCode + var usage *openai.Usage if isStream { - err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2]) } else { - err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) + err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2]) } if err != nil { return err @@ 
-646,7 +638,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return nil case APITypeAIProxyLibrary: if isStream { - err, usage := aiProxyLibraryStreamHandler(c, resp) + err, usage := aiproxy.StreamHandler(c, resp) if err != nil { return err } @@ -655,7 +647,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - err, usage := aiProxyLibraryHandler(c, resp) + err, usage := aiproxy.Handler(c, resp) if err != nil { return err } @@ -666,15 +658,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeTencent: if isStream { - err, responseText := tencentStreamHandler(c, resp) + err, responseText := tencent.StreamHandler(c, resp) if err != nil { return err } textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) return nil } else { - err, usage := tencentHandler(c, resp) + err, usage := tencent.Handler(c, resp) if err != nil { return err } @@ -684,6 +676,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return nil } default: - return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) + return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } } diff --git a/relay/util/common.go b/relay/util/common.go new file mode 100644 index 00000000..9d13b12e --- /dev/null +++ b/relay/util/common.go @@ -0,0 +1,166 @@ +package util + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + "one-api/relay/channel/openai" + "strconv" + "strings" + + "github.com/gin-gonic/gin" +) + +func ShouldDisableChannel(err *openai.Error, statusCode int) bool { + if !common.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *openai.Error) bool { + if !common.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} + +type GeneralErrorResponse struct { + Error openai.Error `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + +func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) { + ErrorWithStatusCode = &openai.ErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: openai.Error{ + Message: "", + Type: "upstream_error", 
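+ // generic defaults; overridden below when the upstream body parses as a recognizable error payload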
+ Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + ErrorWithStatusCode.Error = errResponse.Error + } else { + ErrorWithStatusCode.Error.Message = errResponse.ToMessage() + } + if ErrorWithStatusCode.Error.Message == "" { + ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case common.ChannelTypeOpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case common.ChannelTypeAzure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} + +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +} diff --git a/relay/util/init.go b/relay/util/init.go new file mode 100644 index 00000000..d308d900 --- /dev/null +++ b/relay/util/init.go @@ -0,0 +1,24 @@ +package util + +import ( + "net/http" + "one-api/common" + "time" +) + +var HTTPClient *http.Client +var ImpatientHTTPClient *http.Client + +func init() { + if common.RelayTimeout == 0 { + HTTPClient = &http.Client{} + } else { + HTTPClient = &http.Client{ + Timeout: time.Duration(common.RelayTimeout) * time.Second, + } + } + + ImpatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + } +}