From 2114bc1982da7145a6c9590737893d7182eded39 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Fri, 1 Dec 2023 18:36:30 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20=E5=88=A0=E9=99=A4=E6=97=A0?= =?UTF-8?q?=E6=95=88=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/billing.go | 20 ++- controller/group.go | 3 +- controller/relay.go | 228 +------------------------------- providers/openai/image_edits.go | 12 +- 4 files changed, 27 insertions(+), 236 deletions(-) diff --git a/controller/billing.go b/controller/billing.go index 42e86aea..3de7c847 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -1,9 +1,11 @@ package controller import ( - "github.com/gin-gonic/gin" "one-api/common" "one-api/model" + "one-api/types" + + "github.com/gin-gonic/gin" ) func GetSubscription(c *gin.Context) { @@ -21,13 +23,23 @@ func GetSubscription(c *gin.Context) { } else { userId := c.GetInt("id") remainQuota, err = model.GetUserQuota(userId) + if err != nil { + openAIError := types.OpenAIError{ + Message: err.Error(), + Type: "upstream_error", + } + c.JSON(200, gin.H{ + "error": openAIError, + }) + return + } usedQuota, err = model.GetUserUsedQuota(userId) } if expiredTime <= 0 { expiredTime = 0 } if err != nil { - openAIError := OpenAIError{ + openAIError := types.OpenAIError{ Message: err.Error(), Type: "upstream_error", } @@ -53,7 +65,6 @@ func GetSubscription(c *gin.Context) { AccessUntil: expiredTime, } c.JSON(200, subscription) - return } func GetUsage(c *gin.Context) { @@ -69,7 +80,7 @@ func GetUsage(c *gin.Context) { quota, err = model.GetUserUsedQuota(userId) } if err != nil { - openAIError := OpenAIError{ + openAIError := types.OpenAIError{ Message: err.Error(), Type: "one_api_error", } @@ -87,5 +98,4 @@ func GetUsage(c *gin.Context) { TotalUsage: amount * 100, } c.JSON(200, usage) - return } diff --git a/controller/group.go b/controller/group.go index 2b2f6006..109e2bce 100644 --- a/controller/group.go +++ b/controller/group.go @@ -1,9 +1,10 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" + + "github.com/gin-gonic/gin" ) func GetGroups(c *gin.Context) { diff --git a/controller/relay.go b/controller/relay.go index f73e7593..fb0a1582 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -11,216 +11,6 @@ import ( "github.com/gin-gonic/gin" ) -type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` -} - -type ImageURL struct { - Url string `json:"url,omitempty"` - Detail string `json:"detail,omitempty"` -} - -type TextContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` -} - -type ImageContent struct { - Type string `json:"type,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -func (m Message) StringContent() string { - content, ok := m.Content.(string) - if ok { - return content - } - contentList, ok := m.Content.([]any) - if ok { - var contentStr string - for _, contentItem := range contentList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - if contentMap["type"] == "text" { - if subStr, ok := contentMap["text"].(string); ok { - contentStr += subStr - } - } - } - return contentStr - } - return "" -} - -// https://platform.openai.com/docs/api-reference/chat - -type ResponseFormat struct { - Type string `json:"type,omitempty"` -} - -type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` -} - -func (r GeneralOpenAIRequest) ParseInput() []string { - if r.Input == nil { - return nil - } - var input []string - switch r.Input.(type) { - case string: - input = []string{r.Input.(string)} - case []any: - input = make([]string, 0, len(r.Input.([]any))) - for _, item := range r.Input.([]any) { - if str, ok := item.(string); ok { - input = append(input, str) - } - } - } - return input -} - -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` -} - -type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` - MaxTokens int `json:"max_tokens"` - //Stream bool `json:"stream"` -} - -// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n"` - Size string `json:"size"` - Quality string `json:"quality"` - ResponseFormat string `json:"response_format"` - Style string `json:"style"` - User string `json:"user"` -} - -type WhisperResponse struct { - Text string `json:"text,omitempty"` -} - -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type OpenAIError struct { - Message string `json:"message"` - Type string `json:"type"` - Param string `json:"param"` - Code any `json:"code"` -} - -type OpenAIErrorWithStatusCode struct { - OpenAIError - StatusCode int `json:"status_code"` -} - -type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` -} - -type OpenAITextResponseChoice struct { - Index int `json:"index"` - Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type OpenAITextResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` -} - -type OpenAIEmbeddingResponseItem struct { - Object string `json:"object"` - Index int `json:"index"` - Embedding []float64 `json:"embedding"` -} - -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` -} - -type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } -} - -type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason *string `json:"finish_reason"` -} - -type ChatCompletionsStreamResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - -type CompletionsStreamResponse struct { - Choices []struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` -} - func Relay(c *gin.Context) { defer c.Request.Body.Close() var err *types.OpenAIErrorWithStatusCode @@ -252,18 +42,8 @@ func Relay(c *gin.Context) { // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { // relayMode = RelayModeEdits - switch relayMode { - // case RelayModeImagesGenerations: - // err = relayImageHelper(c, relayMode) - // case RelayModeAudioSpeech: - // fallthrough - // case RelayModeAudioTranslation: - // fallthrough - // case RelayModeAudioTranscription: - // err = relayAudioHelper(c, relayMode) - default: - err = relayHelper(c, relayMode) - } + err = relayHelper(c, relayMode) + if err != nil { requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") @@ -294,7 +74,7 @@ func Relay(c *gin.Context) { } func RelayNotImplemented(c *gin.Context) { - err := OpenAIError{ + err := types.OpenAIError{ Message: "API not implemented", Type: "one_api_error", Param: "", @@ -306,7 +86,7 @@ func RelayNotImplemented(c *gin.Context) { } func RelayNotFound(c *gin.Context) { - err := OpenAIError{ + err := types.OpenAIError{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", diff --git a/providers/openai/image_edits.go b/providers/openai/image_edits.go index dce7c4ff..45bfbf32 100644 --- a/providers/openai/image_edits.go +++ b/providers/openai/image_edits.go @@ -52,7 +52,7 @@ func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isMod func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error { err := b.CreateFormFile("image", request.Image) if err != nil { - return fmt.Errorf("creating form file: %w", err) + return fmt.Errorf("creating form image: %w", err) } err = b.WriteField("prompt", request.Prompt) @@ -68,7 +68,7 @@ func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuil if request.Mask != nil { err = b.CreateFormFile("mask", request.Mask) if err != nil { - return fmt.Errorf("writing format: %w", err) + return fmt.Errorf("writing mask: %w", err) } } @@ -80,23 +80,23 @@ func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuil } if request.N != 0 { - err = b.WriteField("n", fmt.Sprintf("%.2f", request.N)) + err = b.WriteField("n", fmt.Sprintf("%d", request.N)) if err != nil { - return fmt.Errorf("writing temperature: %w", err) + return fmt.Errorf("writing n: %w", err) } } if request.Size != "" { err = b.WriteField("size", request.Size) if err != nil { - return fmt.Errorf("writing language: %w", err) + return fmt.Errorf("writing size: %w", err) } } if request.User != "" { err = b.WriteField("user", request.User) if err != nil { - return fmt.Errorf("writing language: %w", err) + return fmt.Errorf("writing user: %w", err) } }