From f6eb4e56287a1cdb1e00bae299ea4f12e9b6477c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 25 Jun 2023 10:25:33 +0800 Subject: [PATCH] perf: validate the request first before send to OpenAI's server --- controller/relay-text.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/controller/relay-text.go b/controller/relay-text.go index 778991bd..e14e0632 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -29,6 +30,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeModeration && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } + // request validation + if textRequest.Model == "" { + return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + } + switch relayMode { + case RelayModeCompletions: + if textRequest.Prompt == "" { + return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeChatCompletions: + if len(textRequest.Messages) == 0 { + return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeEmbeddings: + case RelayModeModeration: + if textRequest.Input == "" { + return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest) + } + } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" {