feat: support custom http header
This commit is contained in:
parent
cf564f36fa
commit
66e02a4bcf
@ -146,6 +146,20 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
req.Header.Set("X-Remote-Addr", ip)
|
req.Header.Set("X-Remote-Addr", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
custom_http_headers := channel.CustomHttpHeaders
|
||||||
|
if custom_http_headers != "" {
|
||||||
|
var custom_http_headers_map map[string]string
|
||||||
|
err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range custom_http_headers_map {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -109,6 +109,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
|
||||||
|
custom_http_headers := c.GetString("custom_http_headers")
|
||||||
|
if custom_http_headers != "" {
|
||||||
|
var custom_http_headers_map map[string]string
|
||||||
|
err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_custom_http_headers_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range custom_http_headers_map {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -320,6 +320,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
req.Header.Set("X-Remote-Addr", ip)
|
req.Header.Set("X-Remote-Addr", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
custom_http_headers := c.GetString("custom_http_headers")
|
||||||
|
if custom_http_headers != "" {
|
||||||
|
var custom_http_headers_map map[string]string
|
||||||
|
err := json.Unmarshal([]byte(custom_http_headers), &custom_http_headers_map)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_custom_http_headers_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range custom_http_headers_map {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -105,6 +105,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("model_mapping", channel.ModelMapping)
|
c.Set("model_mapping", channel.ModelMapping)
|
||||||
|
c.Set("custom_http_headers", channel.CustomHttpHeaders)
|
||||||
c.Set("enable_ip_randomization", channel.EnableIpRandomization)
|
c.Set("enable_ip_randomization", channel.EnableIpRandomization)
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.BaseURL)
|
c.Set("base_url", channel.BaseURL)
|
||||||
|
@ -26,7 +26,8 @@ type Channel struct {
|
|||||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
|
|
||||||
// Additional fields, default value is false
|
// Additional fields, default value is false
|
||||||
EnableIpRandomization bool `json:"enable_ip_randomization"`
|
EnableIpRandomization bool `json:"enable_ip_randomization"`
|
||||||
|
CustomHttpHeaders string `json:"custom_http_headers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
|
@ -23,6 +23,10 @@ const MODEL_MAPPING_EXAMPLE = {
|
|||||||
'gpt-4-32k-0314': 'gpt-4-32k',
|
'gpt-4-32k-0314': 'gpt-4-32k',
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const CUSTOM_HTTP_HEADERS_EXAMPLE = {
|
||||||
|
'X-OpenAI-Organization': 'OpenAI',
|
||||||
|
};
|
||||||
|
|
||||||
const EditChannel = () => {
|
const EditChannel = () => {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const channelId = params.id;
|
const channelId = params.id;
|
||||||
@ -35,6 +39,7 @@ const EditChannel = () => {
|
|||||||
base_url: '',
|
base_url: '',
|
||||||
other: '',
|
other: '',
|
||||||
model_mapping: '',
|
model_mapping: '',
|
||||||
|
custom_http_headers: '',
|
||||||
models: [],
|
models: [],
|
||||||
groups: ['default'],
|
groups: ['default'],
|
||||||
enable_ip_randomization: false,
|
enable_ip_randomization: false,
|
||||||
@ -85,6 +90,13 @@ const EditChannel = () => {
|
|||||||
2,
|
2,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
if (data.custom_http_headers !== '') {
|
||||||
|
data.custom_http_headers = JSON.stringify(
|
||||||
|
JSON.parse(data.custom_http_headers),
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
);
|
||||||
|
}
|
||||||
setInputs(data);
|
setInputs(data);
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
@ -153,6 +165,13 @@ const EditChannel = () => {
|
|||||||
showInfo('模型映射必须是合法的 JSON 格式!');
|
showInfo('模型映射必须是合法的 JSON 格式!');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
|
inputs.custom_http_headers !== '' &&
|
||||||
|
!verifyJSON(inputs.custom_http_headers)
|
||||||
|
) {
|
||||||
|
showInfo('自定义 HTTP 头必须是合法的 JSON 格式!');
|
||||||
|
return;
|
||||||
|
}
|
||||||
let localInputs = inputs;
|
let localInputs = inputs;
|
||||||
if (localInputs.base_url.endsWith('/')) {
|
if (localInputs.base_url.endsWith('/')) {
|
||||||
localInputs.base_url = localInputs.base_url.slice(
|
localInputs.base_url = localInputs.base_url.slice(
|
||||||
@ -394,6 +413,21 @@ const EditChannel = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
|
<Form.Field>
|
||||||
|
<Form.TextArea
|
||||||
|
label='自定义 HTTP 头'
|
||||||
|
placeholder={`此项可选,为一个 JSON 文本,键为 HTTP 头名称,值为 HTTP 头内容,例如:\n${JSON.stringify(
|
||||||
|
CUSTOM_HTTP_HEADERS_EXAMPLE,
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
)}`}
|
||||||
|
name='custom_http_headers'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.custom_http_headers}
|
||||||
|
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
{batch ? (
|
{batch ? (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
|
Loading…
Reference in New Issue
Block a user