From f9d914873ffa3281de58f88d8a7c02c6026765fd Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 6 Apr 2024 00:44:33 +0800 Subject: [PATCH] chore: reorganize constant related package --- common/constants.go | 38 ----------------------- controller/channel-billing.go | 19 ++++++------ controller/channel-test.go | 7 +++-- controller/model.go | 14 ++++----- controller/relay.go | 12 ++++---- main.go | 2 +- middleware/distributor.go | 11 +++---- relay/apitype/define.go | 17 +++++++++++ relay/channel/ali/adaptor.go | 14 ++++----- relay/channel/baidu/adaptor.go | 6 ++-- relay/channel/minimax/main.go | 6 ++-- relay/channel/ollama/adaptor.go | 8 ++--- relay/channel/openai/adaptor.go | 16 +++++----- relay/channel/openai/compatible.go | 38 +++++++++++------------ relay/channel/openai/main.go | 6 ++-- relay/channel/zhipu/adaptor.go | 14 ++++----- relay/channeltype/define.go | 39 ++++++++++++++++++++++++ relay/channeltype/helper.go | 30 +++++++++++++++++++ relay/constant/api_type.go | 48 ------------------------------ relay/constant/relay_mode.go | 41 ------------------------- relay/controller/audio.go | 17 ++++++----- relay/controller/helper.go | 14 +++++---- relay/controller/image.go | 9 +++--- relay/controller/text.go | 7 +++-- relay/helper/main.go | 24 +++++++-------- relay/relaymode/define.go | 14 +++++++++ relay/relaymode/helper.go | 29 ++++++++++++++++++ relay/util/common.go | 5 ++-- relay/util/relay_meta.go | 9 +++--- relay/util/validation.go | 12 ++++---- 30 files changed, 269 insertions(+), 257 deletions(-) create mode 100644 relay/apitype/define.go create mode 100644 relay/channeltype/define.go create mode 100644 relay/channeltype/helper.go delete mode 100644 relay/constant/api_type.go create mode 100644 relay/relaymode/define.go create mode 100644 relay/relaymode/helper.go diff --git a/common/constants.go b/common/constants.go index 04a56649..95b29683 100644 --- a/common/constants.go +++ b/common/constants.go @@ -38,44 +38,6 @@ const ( ChannelStatusAutoDisabled = 3 ) -const ( - ChannelTypeUnknown = iota - ChannelTypeOpenAI - ChannelTypeAPI2D - ChannelTypeAzure - ChannelTypeCloseAI - ChannelTypeOpenAISB - ChannelTypeOpenAIMax - ChannelTypeOhMyGPT - ChannelTypeCustom - ChannelTypeAILS - ChannelTypeAIProxy - ChannelTypePaLM - ChannelTypeAPI2GPT - ChannelTypeAIGC2D - ChannelTypeAnthropic - ChannelTypeBaidu - ChannelTypeZhipu - ChannelTypeAli - ChannelTypeXunfei - ChannelType360 - ChannelTypeOpenRouter - ChannelTypeAIProxyLibrary - ChannelTypeFastGPT - ChannelTypeTencent - ChannelTypeGemini - ChannelTypeMoonshot - ChannelTypeBaichuan - ChannelTypeMinimax - ChannelTypeMistral - ChannelTypeGroq - ChannelTypeOllama - ChannelTypeLingYiWanWu - ChannelTypeStepFun - - ChannelTypeDummy -) - var ChannelBaseURLs = []string{ "", // 0 "https://api.openai.com", // 1 diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 03c97349..b31850a3 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -209,23 +210,23 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { channel.BaseURL = &baseURL } switch channel.Type { - case common.ChannelTypeOpenAI: + case channeltype.OpenAI: if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - case common.ChannelTypeAzure: + case channeltype.Azure: return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: + case channeltype.Custom: baseURL = channel.GetBaseURL() - case common.ChannelTypeCloseAI: + case channeltype.CloseAI: return updateChannelCloseAIBalance(channel) - case common.ChannelTypeOpenAISB: + case channeltype.OpenAISB: return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: + case channeltype.AIProxy: return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: + case channeltype.API2GPT: return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: + case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) default: return 0, errors.New("尚未实现") @@ -305,7 +306,7 @@ func updateAllChannelsBalance() error { continue } // TODO: support Azure - if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { continue } balance, err := updateChannelBalance(channel) diff --git a/controller/channel-test.go b/controller/channel-test.go index 95f4d769..8f7cb17c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -12,9 +12,10 @@ import ( "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -57,7 +58,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Set("base_url", channel.GetBaseURL()) middleware.SetupContextForSelectedChannel(c, channel, "") meta := util.GetRelayMeta(c) - apiType := constant.ChannelType2APIType(channel.Type) + apiType := channeltype.ToAPIType(channel.Type) adaptor := helper.GetAdaptor(apiType) if adaptor == nil { return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil @@ -73,7 +74,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error request := buildTestRequest() request.Model = modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName - convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil } diff --git a/controller/model.go b/controller/model.go index a03c96cb..b8002e6f 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,10 +3,10 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -62,8 +62,8 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < apitype.Dummy; i++ { + if i == apitype.AIProxyLibrary { continue } adaptor := helper.GetAdaptor(i) @@ -82,7 +82,7 @@ func init() { } } for _, channelType := range openai.CompatibleChannels { - if channelType == common.ChannelTypeAzure { + if channelType == channeltype.Azure { continue } channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) @@ -103,8 +103,8 @@ func init() { modelsMap[model.Id] = model } channelId2Models = make(map[int][]string) - for i := 1; i < common.ChannelTypeDummy; i++ { - adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) + for i := 1; i < channeltype.Dummy; i++ { + adaptor := helper.GetAdaptor(channeltype.ToAPIType(i)) meta := &util.RelayMeta{ ChannelType: i, } diff --git a/controller/relay.go b/controller/relay.go index b34768df..36e96651 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -12,9 +12,9 @@ import ( "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -25,13 +25,13 @@ import ( func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) - case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: fallthrough - case constant.RelayModeAudioTranslation: + case relaymode.AudioTranslation: fallthrough - case constant.RelayModeAudioTranscription: + case relaymode.AudioTranscription: err = controller.RelayAudioHelper(c, relayMode) default: err = controller.RelayTextHelper(c) @@ -41,7 +41,7 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relaymode.GetByPath(c.Request.URL.Path) if config.DebugEnabled { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) diff --git a/main.go b/main.go index b20c6daf..92668408 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,7 @@ func main() { gin.SetMode(gin.ReleaseMode) } if config.DebugEnabled { - logger.SysLog("running in debug mode") + logger.SysLog("running in debug relaymode") } var err error // Initialize SQL Database diff --git a/middleware/distributor.go b/middleware/distributor.go index 04489a2b..29a1d5b3 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -6,6 +6,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" "net/http" "strconv" ) @@ -66,15 +67,15 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("base_url", channel.GetBaseURL()) // this is for backward compatibility switch channel.Type { - case common.ChannelTypeAzure: + case channeltype.Azure: c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeXunfei: + case channeltype.Xunfei: c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeGemini: + case channeltype.Gemini: c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeAIProxyLibrary: + case channeltype.AIProxyLibrary: c.Set(common.ConfigKeyLibraryID, channel.Other) - case common.ChannelTypeAli: + case channeltype.Ali: c.Set(common.ConfigKeyPlugin, channel.Other) } cfg, _ := channel.LoadConfig() diff --git a/relay/apitype/define.go b/relay/apitype/define.go new file mode 100644 index 00000000..82d32a50 --- /dev/null +++ b/relay/apitype/define.go @@ -0,0 +1,17 @@ +package apitype + +const ( + OpenAI = iota + Anthropic + PaLM + Baidu + Zhipu + Ali + Xunfei + AIProxyLibrary + Tencent + Gemini + Ollama + + Dummy // this one is only for count, do not add any channel after this +) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index f04bf56f..d46c082f 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -6,8 +6,8 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -25,9 +25,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { fullRequestURL := "" switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) default: fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut } req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.Mode == constant.RelayModeImagesGenerations { + if meta.Mode == relaymode.ImagesGenerations { req.Header.Set("X-DashScope-Async", "enable") } if c.GetString(common.ConfigKeyPlugin) != "" { @@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } switch relayMode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: aliEmbeddingRequest := ConvertEmbeddingRequest(*request) return aliEmbeddingRequest, nil default: @@ -85,9 +85,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel err, usage = StreamHandler(c, resp) } else { switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: err, usage = EmbeddingHandler(c, resp) - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err, usage = ImageHandler(c, resp) default: err, usage = Handler(c, resp) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 6096eb31..c2388dc1 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -3,13 +3,13 @@ package baidu import ( "errors" "fmt" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" ) @@ -100,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } switch relayMode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, nil default: @@ -125,7 +125,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel err, usage = StreamHandler(c, resp) } else { switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: err, usage = EmbeddingHandler(c, resp) default: err, usage = Handler(c, resp) diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go index a01821c2..4a0c9e0f 100644 --- a/relay/channel/minimax/main.go +++ b/relay/channel/minimax/main.go @@ -2,13 +2,13 @@ package minimax import ( "fmt" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" ) func GetRequestURL(meta *util.RelayMeta) (string, error) { - if meta.Mode == constant.RelayModeChatCompletions { + if meta.Mode == relaymode.ChatCompletions { return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil } - return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) + return "", fmt.Errorf("unsupported relay relaymode %d for minimax", meta.Mode) } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 33635c5c..50fb3ca3 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -3,12 +3,12 @@ package ollama import ( "errors" "fmt" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" ) @@ -23,7 +23,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { // https://github.com/ollama/ollama/blob/main/docs/api.md fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) - if meta.Mode == constant.RelayModeEmbeddings { + if meta.Mode == relaymode.Embeddings { fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) } return fullRequestURL, nil @@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } switch relayMode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) return ollamaEmbeddingRequest, nil default: @@ -64,7 +64,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel err, usage = StreamHandler(c, resp) } else { switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: err, usage = EmbeddingHandler(c, resp) default: err, usage = Handler(c, resp) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 3212d8f8..9c2e8408 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -4,11 +4,11 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -25,8 +25,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { switch meta.ChannelType { - case common.ChannelTypeAzure: - if meta.Mode == constant.RelayModeImagesGenerations { + case channeltype.Azure: + if meta.Mode == relaymode.ImagesGenerations { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) @@ -43,7 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil - case common.ChannelTypeMinimax: + case channeltype.Minimax: return minimax.GetRequestURL(meta) default: return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil @@ -52,12 +52,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { channel.SetupCommonRequestHeader(c, req, meta) - if meta.ChannelType == common.ChannelTypeAzure { + if meta.ChannelType == channeltype.Azure { req.Header.Set("api-key", meta.APIKey) return nil } req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.ChannelType == common.ChannelTypeOpenRouter { + if meta.ChannelType == channeltype.OpenRouter { req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") req.Header.Set("X-Title", "One API") } @@ -91,7 +91,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel } } else { switch meta.Mode { - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go index 2a1447ab..6698e941 100644 --- a/relay/channel/openai/compatible.go +++ b/relay/channel/openai/compatible.go @@ -1,7 +1,6 @@ package openai import ( - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/groq" @@ -10,39 +9,40 @@ import ( "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/channel/stepfun" + "github.com/songquanpeng/one-api/relay/channeltype" ) var CompatibleChannels = []int{ - common.ChannelTypeAzure, - common.ChannelType360, - common.ChannelTypeMoonshot, - common.ChannelTypeBaichuan, - common.ChannelTypeMinimax, - common.ChannelTypeMistral, - common.ChannelTypeGroq, - common.ChannelTypeLingYiWanWu, - common.ChannelTypeStepFun, + channeltype.Azure, + channeltype.AI360, + channeltype.Moonshot, + channeltype.Baichuan, + channeltype.Minimax, + channeltype.Mistral, + channeltype.Groq, + channeltype.LingYiWanWu, + channeltype.StepFun, } func GetCompatibleChannelMeta(channelType int) (string, []string) { switch channelType { - case common.ChannelTypeAzure: + case channeltype.Azure: return "azure", ModelList - case common.ChannelType360: + case channeltype.AI360: return "360", ai360.ModelList - case common.ChannelTypeMoonshot: + case channeltype.Moonshot: return "moonshot", moonshot.ModelList - case common.ChannelTypeBaichuan: + case channeltype.Baichuan: return "baichuan", baichuan.ModelList - case common.ChannelTypeMinimax: + case channeltype.Minimax: return "minimax", minimax.ModelList - case common.ChannelTypeMistral: + case channeltype.Mistral: return "mistralai", mistral.ModelList - case common.ChannelTypeGroq: + case channeltype.Groq: return "groq", groq.ModelList - case common.ChannelTypeLingYiWanWu: + case channeltype.LingYiWanWu: return "lingyiwanwu", lingyiwanwu.ModelList - case common.ChannelTypeStepFun: + case channeltype.StepFun: return "stepfun", stepfun.ModelList default: return "openai", ModelList diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index 63cb9ae8..68d8f48f 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -8,8 +8,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" @@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E data = data[6:] if !strings.HasPrefix(data, "[DONE]") { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: var streamResponse ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { @@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E if streamResponse.Usage != nil { usage = streamResponse.Usage } - case constant.RelayModeCompletions: + case relaymode.Completions: var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 14c581dd..61c40e14 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -6,8 +6,8 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "io" "math" @@ -33,9 +33,9 @@ func (a *Adaptor) SetVersionByModeName(modelName string) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { switch meta.Mode { - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil } a.SetVersionByModeName(meta.ActualModelName) @@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } switch relayMode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, nil default: @@ -107,10 +107,10 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: err, usage = EmbeddingsHandler(c, resp) return - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err, usage = openai.ImageHandler(c, resp) return } @@ -120,7 +120,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel if meta.IsStream { err, usage = StreamHandler(c, resp) } else { - if meta.Mode == constant.RelayModeEmbeddings { + if meta.Mode == relaymode.Embeddings { err, usage = EmbeddingsHandler(c, resp) } else { err, usage = Handler(c, resp) diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go new file mode 100644 index 00000000..80027a80 --- /dev/null +++ b/relay/channeltype/define.go @@ -0,0 +1,39 @@ +package channeltype + +const ( + Unknown = iota + OpenAI + API2D + Azure + CloseAI + OpenAISB + OpenAIMax + OhMyGPT + Custom + Ails + AIProxy + PaLM + API2GPT + AIGC2D + Anthropic + Baidu + Zhipu + Ali + Xunfei + AI360 + OpenRouter + AIProxyLibrary + FastGPT + Tencent + Gemini + Moonshot + Baichuan + Minimax + Mistral + Groq + Ollama + LingYiWanWu + StepFun + + Dummy +) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go new file mode 100644 index 00000000..01c2918c --- /dev/null +++ b/relay/channeltype/helper.go @@ -0,0 +1,30 @@ +package channeltype + +import "github.com/songquanpeng/one-api/relay/apitype" + +func ToAPIType(channelType int) int { + apiType := apitype.OpenAI + switch channelType { + case Anthropic: + apiType = apitype.Anthropic + case Baidu: + apiType = apitype.Baidu + case PaLM: + apiType = apitype.PaLM + case Zhipu: + apiType = apitype.Zhipu + case Ali: + apiType = apitype.Ali + case Xunfei: + apiType = apitype.Xunfei + case AIProxyLibrary: + apiType = apitype.AIProxyLibrary + case Tencent: + apiType = apitype.Tencent + case Gemini: + apiType = apitype.Gemini + case Ollama: + apiType = apitype.Ollama + } + return apiType +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go deleted file mode 100644 index b249f6a2..00000000 --- a/relay/constant/api_type.go +++ /dev/null @@ -1,48 +0,0 @@ -package constant - -import ( - "github.com/songquanpeng/one-api/common" -) - -const ( - APITypeOpenAI = iota - APITypeAnthropic - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini - APITypeOllama - - APITypeDummy // this one is only for count, do not add any channel after this -) - -func ChannelType2APIType(channelType int) int { - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeAnthropic - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - case common.ChannelTypeOllama: - apiType = APITypeOllama - } - return apiType -} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 5e2fe574..3f2495e1 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -1,42 +1 @@ package constant - -import "strings" - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - -func Path2RelayMode(path string) int { - relayMode := RelayModeUnknown - if strings.HasPrefix(path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - return relayMode -} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index cd118985..4fc0071b 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -13,8 +13,9 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/channeltype" relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -33,7 +34,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus tokenName := c.GetString("token_name") var ttsRequest openai.TextToSpeechRequest - if relayMode == constant.RelayModeAudioSpeech { + if relayMode == relaymode.AudioSpeech { // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid @@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus var quota int64 var preConsumedQuota int64 switch relayMode { - case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: @@ -122,12 +123,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure { + if channelType == channeltype.Azure { apiVersion := util.GetAzureAPIVersion(c) - if relayMode == constant.RelayModeAudioTranscription { + if relayMode == relaymode.AudioTranscription { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) - } else if relayMode == constant.RelayModeAudioSpeech { + } else if relayMode == relaymode.AudioSpeech { // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) } @@ -146,7 +147,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure { + if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -172,7 +173,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != constant.RelayModeAudioSpeech { + if relayMode != relaymode.AudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index d591984e..81bc57d9 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -10,8 +10,10 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/constant" relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" "math" "net/http" @@ -23,10 +25,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener if err != nil { return nil, err } - if relayMode == constant.RelayModeModerations && textRequest.Model == "" { + if relayMode == relaymode.Moderations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } - if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { + if relayMode == relaymode.Embeddings && textRequest.Model == "" { textRequest.Model = c.Param("model") } err = util.ValidateTextRequest(textRequest, relayMode) @@ -86,7 +88,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela // Number of generated images validation if !isWithinRange(imageRequest.Model, imageRequest.N) { // channel not azure - if meta.ChannelType != common.ChannelTypeAzure { + if meta.ChannelType != channeltype.Azure { return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } } @@ -110,11 +112,11 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) - case constant.RelayModeCompletions: + case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) - case constant.RelayModeModerations: + case relaymode.Moderations: return openai.CountTokenInput(textRequest.Input, textRequest.Model) } return 0 diff --git a/relay/controller/image.go b/relay/controller/image.go index ee0c4495..b2af115a 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -55,7 +56,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } var requestBody io.Reader - if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body + if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) @@ -71,11 +72,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } switch meta.ChannelType { - case common.ChannelTypeAli: + case channeltype.Ali: fallthrough - case common.ChannelTypeBaidu: + case channeltype.Baidu: fallthrough - case common.ChannelTypeZhipu: + case channeltype.Zhipu: finalRequest, err := adaptor.ConvertImageRequest(imageRequest) if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) diff --git a/relay/controller/text.go b/relay/controller/text.go index ba008713..6ec64614 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -7,8 +7,9 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -53,9 +54,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // get request body var requestBody io.Reader - if meta.APIType == constant.APITypeOpenAI { + if meta.APIType == apitype.OpenAI { // no need to convert request for openai - shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan + shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan if shouldResetRequestBody { jsonStr, err := json.Marshal(textRequest) if err != nil { diff --git a/relay/helper/main.go b/relay/helper/main.go index e7342329..c84392a3 100644 --- a/relay/helper/main.go +++ b/relay/helper/main.go @@ -1,6 +1,7 @@ package helper import ( + "github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/aiproxy" "github.com/songquanpeng/one-api/relay/channel/ali" @@ -13,32 +14,31 @@ import ( "github.com/songquanpeng/one-api/relay/channel/tencent" "github.com/songquanpeng/one-api/relay/channel/xunfei" "github.com/songquanpeng/one-api/relay/channel/zhipu" - "github.com/songquanpeng/one-api/relay/constant" ) func GetAdaptor(apiType int) channel.Adaptor { switch apiType { - case constant.APITypeAIProxyLibrary: + case apitype.AIProxyLibrary: return &aiproxy.Adaptor{} - case constant.APITypeAli: + case apitype.Ali: return &ali.Adaptor{} - case constant.APITypeAnthropic: + case apitype.Anthropic: return &anthropic.Adaptor{} - case constant.APITypeBaidu: + case apitype.Baidu: return &baidu.Adaptor{} - case constant.APITypeGemini: + case apitype.Gemini: return &gemini.Adaptor{} - case constant.APITypeOpenAI: + case apitype.OpenAI: return &openai.Adaptor{} - case constant.APITypePaLM: + case apitype.PaLM: return &palm.Adaptor{} - case constant.APITypeTencent: + case apitype.Tencent: return &tencent.Adaptor{} - case constant.APITypeXunfei: + case apitype.Xunfei: return &xunfei.Adaptor{} - case constant.APITypeZhipu: + case apitype.Zhipu: return &zhipu.Adaptor{} - case constant.APITypeOllama: + case apitype.Ollama: return &ollama.Adaptor{} } return nil diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go new file mode 100644 index 00000000..96d09438 --- /dev/null +++ b/relay/relaymode/define.go @@ -0,0 +1,14 @@ +package relaymode + +const ( + Unknown = iota + ChatCompletions + Completions + Embeddings + Moderations + ImagesGenerations + Edits + AudioSpeech + AudioTranscription + AudioTranslation +) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go new file mode 100644 index 00000000..926dd42e --- /dev/null +++ b/relay/relaymode/helper.go @@ -0,0 +1,29 @@ +package relaymode + +import "strings" + +func GetByPath(path string) int { + relayMode := Unknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = ChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = Completions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = Embeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = Embeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = Moderations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = ImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = Edits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = AudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = AudioTranscription + } else if strings.HasPrefix(path, "/v1/audio/translations") { + relayMode = AudioTranslation + } + return relayMode +} diff --git a/relay/util/common.go b/relay/util/common.go index 5d787204..3826e67f 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" relaymodel "github.com/songquanpeng/one-api/relay/model" "io" "net/http" @@ -155,9 +156,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { switch channelType { - case common.ChannelTypeOpenAI: + case channeltype.OpenAI: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: + case channeltype.Azure: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go index 31b9d2b4..481e33dc 100644 --- a/relay/util/relay_meta.go +++ b/relay/util/relay_meta.go @@ -3,7 +3,8 @@ package util import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/relaymode" "strings" ) @@ -30,7 +31,7 @@ type RelayMeta struct { func GetRelayMeta(c *gin.Context) *RelayMeta { meta := RelayMeta{ - Mode: constant.Path2RelayMode(c.Request.URL.Path), + Mode: relaymode.GetByPath(c.Request.URL.Path), ChannelType: c.GetInt("channel"), ChannelId: c.GetInt("channel_id"), TokenId: c.GetInt("token_id"), @@ -44,12 +45,12 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { Config: nil, RequestURLPath: c.Request.URL.String(), } - if meta.ChannelType == common.ChannelTypeAzure { + if meta.ChannelType == channeltype.Azure { meta.APIVersion = GetAzureAPIVersion(c) } if meta.BaseURL == "" { meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] } - meta.APIType = constant.ChannelType2APIType(meta.ChannelType) + meta.APIType = channeltype.ToAPIType(meta.ChannelType) return &meta } diff --git a/relay/util/validation.go b/relay/util/validation.go index ef8d840c..92ff55bc 100644 --- a/relay/util/validation.go +++ b/relay/util/validation.go @@ -2,8 +2,8 @@ package util import ( "errors" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "math" ) @@ -15,20 +15,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) return errors.New("model is required") } switch relayMode { - case constant.RelayModeCompletions: + case relaymode.Completions: if textRequest.Prompt == "" { return errors.New("field prompt is required") } - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: if textRequest.Messages == nil || len(textRequest.Messages) == 0 { return errors.New("field messages is required") } - case constant.RelayModeEmbeddings: - case constant.RelayModeModerations: + case relaymode.Embeddings: + case relaymode.Moderations: if textRequest.Input == "" { return errors.New("field input is required") } - case constant.RelayModeEdits: + case relaymode.Edits: if textRequest.Instruction == "" { return errors.New("field instruction is required") }