diff --git a/controller/channel-test.go b/controller/channel-test.go index e9f0dc9c..52ede600 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -146,6 +146,20 @@ func testChannel(channel *model.Channel, request ChatRequest) error { req.Header.Set("X-Remote-Addr", ip) } + custom_http_headers := channel.CustomHttpHeaders + if custom_http_headers != "" { + var custom_http_headers_map map[string]string + err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map) + + if err != nil { + return err + } + + for key, value := range custom_http_headers_map { + req.Header.Set(key, value) + } + } + client := &http.Client{} resp, err := client.Do(req) if err != nil { diff --git a/controller/relay-image.go b/controller/relay-image.go index 7a37be80..3afa864a 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -109,6 +109,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) + custom_http_headers := c.GetString("custom_http_headers") + if custom_http_headers != "" { + var custom_http_headers_map map[string]string + err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map) + + if err != nil { + return errorWrapper(err, "unmarshal_custom_http_headers_failed", http.StatusInternalServerError) + } + + for key, value := range custom_http_headers_map { + req.Header.Set(key, value) + } + } + client := &http.Client{} resp, err := client.Do(req) if err != nil { diff --git a/controller/relay-text.go b/controller/relay-text.go index 4d404109..ef15d323 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -320,6 +320,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { req.Header.Set("X-Remote-Addr", ip) } + custom_http_headers := c.GetString("custom_http_headers") + if custom_http_headers != "" { + var custom_http_headers_map map[string]string + err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map) + + if err != nil { + return errorWrapper(err, "unmarshal_custom_http_headers_failed", http.StatusInternalServerError) + } + + for key, value := range custom_http_headers_map { + req.Header.Set(key, value) + } + } + client := &http.Client{} resp, err := client.Do(req) if err != nil { diff --git a/middleware/distributor.go b/middleware/distributor.go index 2b6ccbf9..a0582817 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -105,6 +105,7 @@ func Distribute() func(c *gin.Context) { c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("model_mapping", channel.ModelMapping) + c.Set("custom_http_headers", channel.CustomHttpHeaders) c.Set("enable_ip_randomization", channel.EnableIpRandomization) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) diff --git a/model/channel.go b/model/channel.go index fee4e964..bded0c78 100644 --- a/model/channel.go +++ b/model/channel.go @@ -26,7 +26,8 @@ type Channel struct { ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` // Additional fields, default value is false - EnableIpRandomization bool `json:"enable_ip_randomization"` + EnableIpRandomization bool `json:"enable_ip_randomization"` + CustomHttpHeaders string `json:"custom_http_headers"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/web/src/pages/Channel/EditChannel.jsx b/web/src/pages/Channel/EditChannel.jsx index 8918cf45..68681038 100644 --- a/web/src/pages/Channel/EditChannel.jsx +++ b/web/src/pages/Channel/EditChannel.jsx @@ -23,6 +23,10 @@ const MODEL_MAPPING_EXAMPLE = { 'gpt-4-32k-0314': 'gpt-4-32k', }; +const CUSTOM_HTTP_HEADERS_EXAMPLE = { + 'X-OpenAI-Organization': 'OpenAI', +}; + const EditChannel = () => { const params = useParams(); const channelId = params.id; @@ -35,6 +39,7 @@ const EditChannel = () => { base_url: '', other: '', model_mapping: '', + custom_http_headers: '', models: [], groups: ['default'], enable_ip_randomization: false, @@ -85,6 +90,13 @@ const EditChannel = () => { 2, ); } + if (data.custom_http_headers !== '') { + data.custom_http_headers = JSON.stringify( + JSON.parse(data.custom_http_headers), + null, + 2, + ); + } setInputs(data); } else { showError(message); @@ -153,6 +165,13 @@ const EditChannel = () => { showInfo('模型映射必须是合法的 JSON 格式!'); return; } + if ( + inputs.custom_http_headers !== '' && + !verifyJSON(inputs.custom_http_headers) + ) { + showInfo('自定义 HTTP 头必须是合法的 JSON 格式!'); + return; + } let localInputs = inputs; if (localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice( @@ -394,6 +413,21 @@ const EditChannel = () => { autoComplete='new-password' /> + + + {batch ? (