From f70506eac1afe8a8cc55ba6fbc8ece6f9247cca0 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 6 Apr 2024 01:31:44 +0800 Subject: [PATCH] chore: reorganize relay related package --- controller/channel-test.go | 3 ++- controller/model.go | 4 ++-- relay/channel/aiproxy/adaptor.go | 12 ++++++------ relay/channel/ali/adaptor.go | 12 ++++++------ relay/channel/anthropic/adaptor.go | 12 ++++++------ relay/channel/azure/helper.go | 15 +++++++++++++++ relay/channel/baidu/adaptor.go | 12 ++++++------ relay/channel/common.go | 5 +++-- relay/channel/gemini/adaptor.go | 12 ++++++------ relay/channel/interface.go | 12 ++++++------ relay/channel/minimax/main.go | 4 ++-- relay/channel/ollama/adaptor.go | 12 ++++++------ relay/channel/openai/adaptor.go | 11 ++++++----- relay/channel/palm/adaptor.go | 12 ++++++------ relay/channel/tencent/adaptor.go | 12 ++++++------ relay/channel/xunfei/adaptor.go | 12 ++++++------ relay/channel/zhipu/adaptor.go | 14 +++++++------- relay/controller/audio.go | 3 ++- relay/controller/helper.go | 7 ++++--- relay/controller/image.go | 3 ++- relay/controller/text.go | 3 ++- relay/{util => meta}/relay_meta.go | 11 ++++++----- relay/util/common.go | 12 ------------ 23 files changed, 113 insertions(+), 102 deletions(-) create mode 100644 relay/channel/azure/helper.go rename relay/{util => meta}/relay_meta.go (88%) diff --git a/controller/channel-test.go b/controller/channel-test.go index 8f7cb17c..d9d6a73c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" @@ -57,7 +58,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) middleware.SetupContextForSelectedChannel(c, channel, "") - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) apiType := channeltype.ToAPIType(channel.Type) adaptor := helper.GetAdaptor(apiType) if adaptor == nil { diff --git a/controller/model.go b/controller/model.go index b8002e6f..dadafb5d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -8,8 +8,8 @@ import ( "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "net/http" "strings" ) @@ -105,7 +105,7 @@ func init() { channelId2Models = make(map[int][]string) for i := 1; i < channeltype.Dummy; i++ { adaptor := helper.GetAdaptor(channeltype.ToAPIType(i)) - meta := &util.RelayMeta{ + meta := &meta.Meta{ ChannelType: i, } adaptor.Init(meta) diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go index 386c5ba4..9132fe60 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/channel/aiproxy/adaptor.go @@ -6,8 +6,8 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,15 +15,15 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil @@ -45,11 +45,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index d46c082f..f69f9561 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -6,9 +6,9 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -18,11 +18,11 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { fullRequestURL := "" switch meta.Mode { case relaymode.Embeddings: @@ -36,7 +36,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) if meta.IsStream { req.Header.Set("Accept", "text/event-stream") @@ -76,11 +76,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return aliRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go index 12e01c52..6b1bc0f1 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/channel/anthropic/adaptor.go @@ -5,8 +5,8 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -14,15 +14,15 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-api-key", meta.APIKey) anthropicVersion := c.Request.Header.Get("anthropic-version") @@ -48,11 +48,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/azure/helper.go b/relay/channel/azure/helper.go new file mode 100644 index 00000000..29004d27 --- /dev/null +++ b/relay/channel/azure/helper.go @@ -0,0 +1,15 @@ +package azure + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" +) + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString(common.ConfigKeyAPIVersion) + } + return apiVersion +} diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index c2388dc1..fde63340 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -3,6 +3,7 @@ package baidu import ( "errors" "fmt" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" @@ -11,17 +12,16 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t suffix := "chat/" if strings.HasPrefix(meta.ActualModelName, "Embedding") { @@ -89,7 +89,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil @@ -116,11 +116,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/common.go b/relay/channel/common.go index c6e1abf2..794bd985 100644 --- a/relay/channel/common.go +++ b/relay/channel/common.go @@ -4,12 +4,13 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) -func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { +func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) if meta.IsStream && c.Request.Header.Get("Accept") == "" { @@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela } } -func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(meta) if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 30002c53..685c7e28 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -7,8 +7,8 @@ import ( "github.com/songquanpeng/one-api/common/helper" channelhelper "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -16,11 +16,11 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { version := helper.AssignOrDefault(meta.APIVersion, "v1") action := "generateContent" if meta.IsStream { @@ -29,7 +29,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channelhelper.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil @@ -49,11 +49,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/interface.go b/relay/channel/interface.go index 78d6ace1..5331c8cf 100644 --- a/relay/channel/interface.go +++ b/relay/channel/interface.go @@ -2,20 +2,20 @@ package channel import ( "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) type Adaptor interface { - Init(meta *util.RelayMeta) - GetRequestURL(meta *util.RelayMeta) (string, error) - SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error + Init(meta *meta.Meta) + GetRequestURL(meta *meta.Meta) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) ConvertImageRequest(request *model.ImageRequest) (any, error) - DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) + DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string GetChannelName() string } diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go index 4a0c9e0f..68741890 100644 --- a/relay/channel/minimax/main.go +++ b/relay/channel/minimax/main.go @@ -2,11 +2,11 @@ package minimax import ( "fmt" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/relaymode" - "github.com/songquanpeng/one-api/relay/util" ) -func GetRequestURL(meta *util.RelayMeta) (string, error) { +func GetRequestURL(meta *meta.Meta) (string, error) { if meta.Mode == relaymode.ChatCompletions { return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 50fb3ca3..b8790fd0 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -3,6 +3,7 @@ package ollama import ( "errors" "fmt" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" @@ -10,17 +11,16 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { // https://github.com/ollama/ollama/blob/main/docs/api.md fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) if meta.Mode == relaymode.Embeddings { @@ -29,7 +29,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil @@ -55,11 +55,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 9c2e8408..e8dd59fd 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" @@ -19,11 +20,11 @@ type Adaptor struct { ChannelType int } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { a.ChannelType = meta.ChannelType } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { switch meta.ChannelType { case channeltype.Azure: if meta.Mode == relaymode.ImagesGenerations { @@ -50,7 +51,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) if meta.ChannelType == channeltype.Azure { req.Header.Set("api-key", meta.APIKey) @@ -78,11 +79,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText, usage = StreamHandler(c, resp, meta.Mode) diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 8d3c9b0b..d0904ce2 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -6,8 +6,8 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,15 +15,15 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil @@ -43,11 +43,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 84ed8527..65aedb1f 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -6,8 +6,8 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -19,15 +19,15 @@ type Adaptor struct { Sign string } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", a.Sign) req.Header.Set("X-TC-Action", meta.ActualModelName) @@ -59,11 +59,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 8d83ff1b..0d51d7a8 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -5,8 +5,8 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" @@ -16,15 +16,15 @@ type Adaptor struct { request *model.GeneralOpenAIRequest } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return "", nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) // check DoResponse for auth part return nil @@ -45,14 +45,14 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} dummyResp.StatusCode = http.StatusOK return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { splits := strings.Split(meta.APIKey, "|") if len(splits) != 3 { return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 61c40e14..774a98ef 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -6,9 +6,9 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "github.com/songquanpeng/one-api/relay/util" "io" "math" "net/http" @@ -19,7 +19,7 @@ type Adaptor struct { APIVersion string } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } @@ -31,7 +31,7 @@ func (a *Adaptor) SetVersionByModeName(modelName string) { } } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { switch meta.Mode { case relaymode.ImagesGenerations: return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil @@ -49,7 +49,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channel.SetupCommonRequestHeader(c, req, meta) token := GetToken(meta.APIKey) req.Header.Set("Authorization", token) @@ -92,11 +92,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return newRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, _, usage = openai.StreamHandler(c, resp, meta.Mode) } else { @@ -105,7 +105,7 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R return } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { switch meta.Mode { case relaymode.Embeddings: err, usage = EmbeddingsHandler(c, resp) diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 912f7a3e..9b02d34a 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -14,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channel/azure" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -126,7 +127,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == channeltype.Azure { - apiVersion := util.GetAzureAPIVersion(c) + apiVersion := azure.GetAPIVersion(c) if relayMode == relaymode.AudioTranscription { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 4444bd51..f224051f 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -12,6 +12,7 @@ import ( billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/util" @@ -72,7 +73,7 @@ func getImageSizeRatio(model string, size string) float64 { return ratio } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { // model validation hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) if !hasValidSize { @@ -130,7 +131,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok return int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) @@ -159,7 +160,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return diff --git a/relay/controller/image.go b/relay/controller/image.go index 18e864f5..0aa96308 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -13,6 +13,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -30,7 +31,7 @@ func isWithinRange(element string, value int) bool { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) diff --git a/relay/controller/text.go b/relay/controller/text.go index 068cef8d..dfa7aed7 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -12,6 +12,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -21,7 +22,7 @@ import ( func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { diff --git a/relay/util/relay_meta.go b/relay/meta/relay_meta.go similarity index 88% rename from relay/util/relay_meta.go rename to relay/meta/relay_meta.go index 481e33dc..f42b9d4a 100644 --- a/relay/util/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,14 +1,15 @@ -package util +package meta import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/relay/channel/azure" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" "strings" ) -type RelayMeta struct { +type Meta struct { Mode int ChannelType int ChannelId int @@ -29,8 +30,8 @@ type RelayMeta struct { PromptTokens int // only for DoResponse } -func GetRelayMeta(c *gin.Context) *RelayMeta { - meta := RelayMeta{ +func GetByContext(c *gin.Context) *Meta { + meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path), ChannelType: c.GetInt("channel"), ChannelId: c.GetInt("channel_id"), @@ -46,7 +47,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { RequestURLPath: c.Request.URL.String(), } if meta.ChannelType == channeltype.Azure { - meta.APIVersion = GetAzureAPIVersion(c) + meta.APIVersion = azure.GetAPIVersion(c) } if meta.BaseURL == "" { meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] diff --git a/relay/util/common.go b/relay/util/common.go index 315f1253..b2f162be 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -3,7 +3,6 @@ package util import ( "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channeltype" @@ -12,8 +11,6 @@ import ( "net/http" "strconv" "strings" - - "github.com/gin-gonic/gin" ) func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { @@ -162,12 +159,3 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin } return fullRequestURL } - -func GetAzureAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString(common.ConfigKeyAPIVersion) - } - return apiVersion -}