feat: support custom http header
This commit is contained in:
parent
cf564f36fa
commit
66e02a4bcf
@ -146,6 +146,20 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
req.Header.Set("X-Remote-Addr", ip)
|
||||
}
|
||||
|
||||
custom_http_headers := channel.CustomHttpHeaders
|
||||
if custom_http_headers != "" {
|
||||
var custom_http_headers_map map[string]string
|
||||
err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for key, value := range custom_http_headers_map {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
|
@ -109,6 +109,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
custom_http_headers := c.GetString("custom_http_headers")
|
||||
if custom_http_headers != "" {
|
||||
var custom_http_headers_map map[string]string
|
||||
err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_custom_http_headers_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for key, value := range custom_http_headers_map {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
|
@ -320,6 +320,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
req.Header.Set("X-Remote-Addr", ip)
|
||||
}
|
||||
|
||||
custom_http_headers := c.GetString("custom_http_headers")
|
||||
if custom_http_headers != "" {
|
||||
var custom_http_headers_map map[string]string
|
||||
err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_custom_http_headers_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for key, value := range custom_http_headers_map {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
|
@ -105,6 +105,7 @@ func Distribute() func(c *gin.Context) {
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
c.Set("model_mapping", channel.ModelMapping)
|
||||
c.Set("custom_http_headers", channel.CustomHttpHeaders)
|
||||
c.Set("enable_ip_randomization", channel.EnableIpRandomization)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.BaseURL)
|
||||
|
@ -26,7 +26,8 @@ type Channel struct {
|
||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
|
||||
// Additional fields, default value is false
|
||||
EnableIpRandomization bool `json:"enable_ip_randomization"`
|
||||
EnableIpRandomization bool `json:"enable_ip_randomization"`
|
||||
CustomHttpHeaders string `json:"custom_http_headers"`
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
|
@ -23,6 +23,10 @@ const MODEL_MAPPING_EXAMPLE = {
|
||||
'gpt-4-32k-0314': 'gpt-4-32k',
|
||||
};
|
||||
|
||||
const CUSTOM_HTTP_HEADERS_EXAMPLE = {
|
||||
'X-OpenAI-Organization': 'OpenAI',
|
||||
};
|
||||
|
||||
const EditChannel = () => {
|
||||
const params = useParams();
|
||||
const channelId = params.id;
|
||||
@ -35,6 +39,7 @@ const EditChannel = () => {
|
||||
base_url: '',
|
||||
other: '',
|
||||
model_mapping: '',
|
||||
custom_http_headers: '',
|
||||
models: [],
|
||||
groups: ['default'],
|
||||
enable_ip_randomization: false,
|
||||
@ -85,6 +90,13 @@ const EditChannel = () => {
|
||||
2,
|
||||
);
|
||||
}
|
||||
if (data.custom_http_headers !== '') {
|
||||
data.custom_http_headers = JSON.stringify(
|
||||
JSON.parse(data.custom_http_headers),
|
||||
null,
|
||||
2,
|
||||
);
|
||||
}
|
||||
setInputs(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@ -153,6 +165,13 @@ const EditChannel = () => {
|
||||
showInfo('模型映射必须是合法的 JSON 格式!');
|
||||
return;
|
||||
}
|
||||
if (
|
||||
inputs.custom_http_headers !== '' &&
|
||||
!verifyJSON(inputs.custom_http_headers)
|
||||
) {
|
||||
showInfo('自定义 HTTP 头必须是合法的 JSON 格式!');
|
||||
return;
|
||||
}
|
||||
let localInputs = inputs;
|
||||
if (localInputs.base_url.endsWith('/')) {
|
||||
localInputs.base_url = localInputs.base_url.slice(
|
||||
@ -394,6 +413,21 @@ const EditChannel = () => {
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</Form.Field>
|
||||
<Form.Field>
|
||||
<Form.TextArea
|
||||
label='自定义 HTTP 头'
|
||||
placeholder={`此项可选,为一个 JSON 文本,键为 HTTP 头名称,值为 HTTP 头内容,例如:\n${JSON.stringify(
|
||||
CUSTOM_HTTP_HEADERS_EXAMPLE,
|
||||
null,
|
||||
2,
|
||||
)}`}
|
||||
name='custom_http_headers'
|
||||
onChange={handleInputChange}
|
||||
value={inputs.custom_http_headers}
|
||||
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</Form.Field>
|
||||
{batch ? (
|
||||
<Form.Field>
|
||||
<Form.TextArea
|
||||
|
Loading…
Reference in New Issue
Block a user