diff --git a/controller/relay.go b/controller/relay.go index c41f4d8a..15c6a6bb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -68,12 +68,8 @@ func relayHelper(c *gin.Context) error { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") consumeQuota := c.GetBool("consume_quota") - baseURL := common.ChannelBaseURLs[channelType] - if channelType == common.ChannelTypeCustom { - baseURL = c.GetString("base_url") - } var textRequest TextRequest - if consumeQuota { + if consumeQuota || channelType == common.ChannelTypeAzure { requestBody, err := io.ReadAll(c.Request.Body) if err != nil { return err @@ -89,12 +85,30 @@ func relayHelper(c *gin.Context) error { // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) } + baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body) + if channelType == common.ChannelTypeCustom { + baseURL = c.GetString("base_url") + } + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + if channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + baseURL = c.GetString("base_url") + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := textRequest.Model + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) if err != nil { return err } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + if channelType == common.ChannelTypeAzure { + key := c.Request.Header.Get("Authorization") + key = strings.TrimPrefix(key, "Bearer ") + req.Header.Set("api-key", key) + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Connection", c.Request.Header.Get("Connection")) diff --git a/middleware/distributor.go b/middleware/distributor.go index 18df9d48..04e9f84d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -63,7 +63,7 @@ func Distribute() func(c *gin.Context) { } c.Set("channel", channel.Type) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - if channel.Type == common.ChannelTypeCustom { + if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure { c.Set("base_url", channel.BaseURL) } c.Next() diff --git a/web/src/App.js b/web/src/App.js index 0c060662..ec2d91f9 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -19,7 +19,6 @@ import Channel from './pages/Channel'; import Token from './pages/Token'; import EditToken from './pages/Token/EditToken'; import EditChannel from './pages/Channel/EditChannel'; -import AddChannel from './pages/Channel/AddChannel'; import Redemption from './pages/Redemption'; import EditRedemption from './pages/Redemption/EditRedemption'; @@ -93,7 +92,7 @@ function App() { path='/channel/add' element={ }> - + } /> diff --git a/web/src/pages/Channel/AddChannel.js b/web/src/pages/Channel/AddChannel.js deleted file mode 100644 index 57c16eba..00000000 --- a/web/src/pages/Channel/AddChannel.js +++ /dev/null @@ -1,95 +0,0 @@ -import React, { useState } from 'react'; -import { Button, Form, Header, Segment } from 'semantic-ui-react'; -import { API, showError, showSuccess } from '../../helpers'; -import { CHANNEL_OPTIONS } from '../../constants'; - -const AddChannel = () => { - const originInputs = { - name: '', - type: 1, - key: '', - base_url: '', - }; - const [inputs, setInputs] = useState(originInputs); - const { name, type, key } = inputs; - - const handleInputChange = (e, { name, value }) => { - setInputs((inputs) => ({ ...inputs, [name]: value })); - }; - - const submit = async () => { - if (inputs.name === '' || inputs.key === '') return; - if (inputs.base_url.endsWith('/')) { - inputs.base_url = inputs.base_url.slice(0, inputs.base_url.length - 1); - } - const res = await API.post(`/api/channel/`, inputs); - const { success, message } = res.data; - if (success) { - showSuccess('渠道创建成功!'); - setInputs(originInputs); - } else { - showError(message); - } - }; - - return ( - <> - -
创建新的渠道
-
- - - - { - type === 8 && ( - - - - ) - } - - - - - - - -
-
- - ); -}; - -export default AddChannel; diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7758323f..ed5384e2 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -1,5 +1,5 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Header, Segment } from 'semantic-ui-react'; +import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; import { useParams } from 'react-router-dom'; import { API, showError, showSuccess } from '../../helpers'; import { CHANNEL_OPTIONS } from '../../constants'; @@ -7,13 +7,15 @@ import { CHANNEL_OPTIONS } from '../../constants'; const EditChannel = () => { const params = useParams(); const channelId = params.id; - const [loading, setLoading] = useState(true); - const [inputs, setInputs] = useState({ + const isEdit = channelId !== undefined; + const [loading, setLoading] = useState(isEdit); + const originInputs = { name: '', - key: '', type: 1, - base_url: '', - }); + key: '', + base_url: '' + }; + const [inputs, setInputs] = useState(originInputs); const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); }; @@ -30,17 +32,31 @@ const EditChannel = () => { setLoading(false); }; useEffect(() => { - loadChannel().then(); + if (isEdit) { + loadChannel().then(); + } }, []); const submit = async () => { - if (inputs.base_url.endsWith('/')) { - inputs.base_url = inputs.base_url.slice(0, inputs.base_url.length - 1); + if (!isEdit && (inputs.name === '' || inputs.key === '')) return; + let localInputs = inputs; + if (localInputs.base_url.endsWith('/')) { + localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); + } + let res; + if (isEdit) { + res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); + } else { + res = await API.post(`/api/channel/`, localInputs); } - let res = await API.put(`/api/channel/`, { ...inputs, id: parseInt(channelId) }); const { success, message } = res.data; if (success) { - showSuccess('渠道更新成功!'); + if (isEdit) { + showSuccess('渠道更新成功!'); + } else { + showSuccess('渠道创建成功!'); + setInputs(originInputs); + } } else { showError(message); } @@ -49,7 +65,7 @@ const EditChannel = () => { return ( <> -
更新渠道信息
+
{isEdit ? '更新渠道信息' : '创建新的渠道'}
{ onChange={handleInputChange} /> + { + inputs.type === 3 && ( + <> + + 注意,创建资源时,部署名称必须和 OpenAI 官方的模型名称保持一致,因为 One API 会把请求体中的 model 参数替换为你的部署名称。 + + + + + + ) + } { inputs.type === 8 && ( { {