diff --git a/README.md b/README.md index e39b8d46..fdca3a31 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 + [x] [OpenAI-SB](https://openai-sb.com) + [x] [OpenAI Max](https://openaimax.com) + [x] [OhMyGPT](https://www.ohmygpt.com) + + [x] 自定义渠道 2. 支持通过负载均衡的方式访问多个渠道。 3. 支持单个访问渠道设置多个 API Key,利用起来你的多个 API Key。 4. 支持 HTTP SSE。 diff --git a/common/constants.go b/common/constants.go index ef1f7031..987f6b08 100644 --- a/common/constants.go +++ b/common/constants.go @@ -106,6 +106,7 @@ const ( ChannelTypeOpenAISB = 5 ChannelTypeOpenAIMax = 6 ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 ) var ChannelBaseURLs = []string{ @@ -117,4 +118,5 @@ var ChannelBaseURLs = []string{ "https://api.openai-sb.com", // 5 "https://api.openaimax.com", // 6 "https://api.ohmygpt.com", // 7 + "", // 8 } diff --git a/controller/relay.go b/controller/relay.go index 04e10e35..22645e2d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -11,6 +11,9 @@ import ( func Relay(c *gin.Context) { channelType := c.GetInt("channel") baseURL := common.ChannelBaseURLs[channelType] + if channelType == common.ChannelTypeCustom { + baseURL = c.GetString("base_url") + } req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, c.Request.URL.String()), c.Request.Body) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/middleware/distributor.go b/middleware/distributor.go index 08193f8f..18df9d48 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -63,6 +63,9 @@ func Distribute() func(c *gin.Context) { } c.Set("channel", channel.Type) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + if channel.Type == common.ChannelTypeCustom { + c.Set("base_url", channel.BaseURL) + } c.Next() } } diff --git a/model/channel.go b/model/channel.go index ceaf2710..ef76bf54 100644 --- a/model/channel.go +++ b/model/channel.go @@ -14,6 +14,7 @@ type Channel struct { Weight int `json:"weight"` CreatedTime int64 `json:"created_time" gorm:"bigint"` AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + BaseURL string `json:"base_url" gorm:"column:base_url"` } func GetAllChannels(startIdx int, num int) ([]*Channel, error) { diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 5cd3d1e3..6cec5885 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,9 +1,10 @@ export const CHANNEL_OPTIONS = [ - { key: 1, text: 'OpenAI', value: 1, color: 'green' }, - { key: 2, text: 'API2D', value: 2, color: 'blue' }, - { key: 3, text: 'Azure', value: 3, color: 'olive' }, - { key: 4, text: 'CloseAI', value: 4, color: 'teal' }, - { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' }, - { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' }, - { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' } + { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 2, text: 'API2D', value: 2, color: 'blue' }, + { key: 3, text: 'Azure', value: 3, color: 'olive' }, + { key: 4, text: 'CloseAI', value: 4, color: 'teal' }, + { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' }, + { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' }, + { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' }, + { key: 8, text: '自定义', value: 8, color: 'pink' } ]; diff --git a/web/src/pages/Channel/AddChannel.js b/web/src/pages/Channel/AddChannel.js index 8547f6b0..332c3e31 100644 --- a/web/src/pages/Channel/AddChannel.js +++ b/web/src/pages/Channel/AddChannel.js @@ -7,7 +7,8 @@ const AddChannel = () => { const originInputs = { name: '', type: 1, - key: '' + key: '', + base_url: '', }; const [inputs, setInputs] = useState(originInputs); const { name, type, key } = inputs; @@ -18,6 +19,9 @@ const AddChannel = () => { const submit = async () => { if (inputs.name === '' || inputs.key === '') return; + if (inputs.base_url.endsWith('/')) { + inputs.base_url = inputs.base_url.slice(0, inputs.base_url.length - 1); + } const res = await API.post(`/api/channel/`, inputs); const { success, message } = res.data; if (success) { @@ -42,6 +46,20 @@ const AddChannel = () => { onChange={handleInputChange} /> + { + type === 8 && ( + + + + ) + } { name: '', key: '', type: 1, + base_url: '', }); const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -33,6 +34,9 @@ const EditChannel = () => { }, []); const submit = async () => { + if (inputs.base_url.endsWith('/')) { + inputs.base_url = inputs.base_url.slice(0, inputs.base_url.length - 1); + } let res = await API.put(`/api/channel/`, { ...inputs, id: parseInt(channelId) }); const { success, message } = res.data; if (success) { @@ -56,6 +60,20 @@ const EditChannel = () => { onChange={handleInputChange} /> + { + inputs.type === 8 && ( + + + + ) + }