add channel proxy

This commit is contained in:
Martial BE 2023-12-26 18:42:39 +08:00
parent eeb867da10
commit fb24d024a7
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
27 changed files with 181 additions and 33 deletions

View File

@ -6,23 +6,61 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"one-api/types"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/proxy"
)
var HttpClient *http.Client
var clientPool = &sync.Pool{
New: func() interface{} {
return &http.Client{}
},
}
func init() {
if RelayTimeout == 0 {
HttpClient = &http.Client{}
} else {
HttpClient = &http.Client{
Timeout: time.Duration(RelayTimeout) * time.Second,
func GetHttpClient(proxyAddr string) *http.Client {
client := clientPool.Get().(*http.Client)
if RelayTimeout > 0 {
client.Timeout = time.Duration(RelayTimeout) * time.Second
}
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
SysError("Error parsing proxy address: " + err.Error())
return client
}
switch proxyURL.Scheme {
case "http", "https":
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
SysError("Error creating SOCKS5 dialer: " + err.Error())
return client
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
}
}
return client
}
func PutHttpClient(c *http.Client) {
clientPool.Put(c)
}
type Client struct {
@ -92,12 +130,14 @@ func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http
return req, nil
}
func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
resp, err := HttpClient.Do(req)
client := GetHttpClient(proxyAddr)
resp, err := client.Do(req)
if err != nil {
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
PutHttpClient(client)
if !outputResp {
defer resp.Body.Close()
@ -210,8 +250,10 @@ func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.Open
return
}
func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
resp, err := HttpClient.Do(req)
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
client := GetHttpClient(proxyAddr)
resp, err := client.Do(req)
PutHttpClient(client)
if err != nil {
return
}

2
go.mod
View File

@ -58,7 +58,7 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect

2
go.sum
View File

@ -159,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -25,6 +25,7 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Proxy string `json:"proxy" gorm:"type:varchar(255);default:''"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

View File

@ -19,7 +19,7 @@ func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -20,7 +20,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response AIProxyUserOverviewResponse
_, errWithCode := common.SendRequest(req, &response, false)
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -157,10 +157,12 @@ func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage
usage = &types.Usage{}
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)

View File

@ -19,7 +19,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -19,7 +19,7 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -38,7 +38,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
for i := 0; i < 3; i++ {
// 休眠 2 秒
time.Sleep(2 * time.Second)
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false)
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy)
fmt.Println("getImageAzureResponse", getImageAzureResponse)
if errWithCode != nil {
return
@ -81,6 +81,7 @@ func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isMo
if request.Model == "dall-e-2" {
imageAzureResponse := &ImageAzureResponse{
Header: headers,
Proxy: p.Channel.Proxy,
}
errWithCode = p.SendRequest(req, imageAzureResponse, false)
} else {

View File

@ -10,6 +10,7 @@ type ImageAzureResponse struct {
Status string `json:"status,omitempty"`
Error ImageAzureError `json:"error,omitempty"`
Header map[string]string `json:"header,omitempty"`
Proxy string `json:"proxy,omitempty"`
}
type ImageAzureError struct {

View File

@ -105,10 +105,12 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo
return nil, err
}
resp, err := common.HttpClient.Do(req)
httpClient := common.GetHttpClient(p.Channel.Proxy)
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
common.PutHttpClient(httpClient)
defer resp.Body.Close()

View File

@ -130,10 +130,12 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag
usage = &types.Usage{}
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)

View File

@ -65,7 +65,7 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true)
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
if openAIErrorWithStatusCode != nil {
return
}
@ -108,10 +108,12 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
defer req.Body.Close()
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
defer resp.Body.Close()

View File

@ -142,10 +142,12 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
defer req.Body.Close()
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""

View File

@ -18,7 +18,7 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
// 发送请求
var response OpenAICreditGrants
_, errWithCode := common.SendRequest(req, &response, false)
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -217,10 +217,12 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*ty
defer req.Body.Close()
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""

View File

@ -20,7 +20,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var subscription OpenAISubscriptionResponse
_, errWithCode := common.SendRequest(req, &subscription, false)
_, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
@ -38,7 +38,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
return 0, err
}
usage := OpenAIUsageResponse{}
_, errWithCode = common.SendRequest(req, &usage, false)
_, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)

View File

@ -111,10 +111,12 @@ func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (reques
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
defer req.Body.Close()
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""

View File

@ -21,7 +21,7 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
// 发送请求
var response OpenAISBUsageResponse
_, errWithCode := common.SendRequest(req, &response, false)
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
if err != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}

View File

@ -134,10 +134,12 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
defer req.Body.Close()
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""

View File

@ -147,10 +147,12 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close()
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""

View File

@ -145,10 +145,12 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*typ
defer req.Body.Close()
// 发送请求
resp, err := common.HttpClient.Do(req)
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), nil

View File

@ -35,6 +35,7 @@ const validationSchema = Yup.object().shape({
type: Yup.number().required('渠道 不能为空'),
key: Yup.string().when('is_edit', { is: false, then: Yup.string().required('密钥 不能为空') }),
other: Yup.string(),
proxy: Yup.string(),
models: Yup.array().min(1, '模型 不能为空'),
groups: Yup.array().min(1, '用户组 不能为空'),
base_url: Yup.string().when('type', {
@ -442,6 +443,27 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
<FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.proxy && errors.proxy)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-proxy-label">{inputLabel.proxy}</InputLabel>
<OutlinedInput
id="channel-proxy-label"
label={inputLabel.proxy}
type="text"
value={values.proxy}
name="proxy"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{}}
aria-describedby="helper-text-channel-proxy-label"
/>
{touched.proxy && errors.proxy ? (
<FormHelperText error id="helper-tex-channel-proxy-label">
{errors.proxy}
</FormHelperText>
) : (
<FormHelperText id="helper-tex-channel-proxy-label"> {inputPrompt.proxy} </FormHelperText>
)}
</FormControl>
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">

View File

@ -0,0 +1,53 @@
import PropTypes from 'prop-types';
import { Tooltip, Stack, Container } from '@mui/material';
import Label from 'ui-component/Label';
import { styled } from '@mui/material/styles';
import { showSuccess } from 'utils/common';
const TooltipContainer = styled(Container)({
maxHeight: '250px',
overflow: 'auto',
'&::-webkit-scrollbar': {
width: '0px' // Set the width to 0 to hide the scrollbar
}
});
const NameLabel = ({ name, models }) => {
let modelMap = [];
modelMap = models.split(',');
modelMap.sort();
return (
<Tooltip
title={
<TooltipContainer>
<Stack spacing={1}>
{modelMap.map((item, index) => {
return (
<Label
variant="ghost"
key={index}
onClick={() => {
navigator.clipboard.writeText(item);
showSuccess('复制模型名称成功!');
}}
>
{item}
</Label>
);
})}
</Stack>
</TooltipContainer>
}
placement="top"
>
{name}
</Tooltip>
);
};
NameLabel.propTypes = {
group: PropTypes.string
};
export default NameLabel;

View File

@ -29,6 +29,7 @@ import TableSwitch from 'ui-component/Switch';
import ResponseTimeLabel from './ResponseTimeLabel';
import GroupLabel from './GroupLabel';
import NameLabel from './NameLabel';
import { IconDotsVertical, IconEdit, IconTrash, IconPencil } from '@tabler/icons-react';
@ -102,7 +103,9 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
<TableRow tabIndex={item.id}>
<TableCell>{item.id}</TableCell>
<TableCell>{item.name}</TableCell>
<TableCell>
<NameLabel name={item.name} models={item.models} />
</TableCell>
<TableCell>
<GroupLabel group={item.group} />

View File

@ -5,6 +5,7 @@ const defaultConfig = {
key: '',
base_url: '',
other: '',
proxy: '',
model_mapping: '',
models: [],
groups: ['default']
@ -15,6 +16,7 @@ const defaultConfig = {
base_url: '渠道API地址',
key: '密钥',
other: '其他参数',
proxy: '代理地址',
models: '模型',
model_mapping: '模型映射关系',
groups: '用户组'
@ -25,6 +27,7 @@ const defaultConfig = {
base_url: '可空请输入中转API地址例如通过cloudflare中转',
key: '请输入渠道对应的鉴权密钥',
other: '',
proxy: '单独设置代理地址支持http和socks5例如http://127.0.0.1:1080',
models: '请选择该渠道所支持的模型',
model_mapping:
'请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}',