perf: update config related code

This commit is contained in:
JustSong 2024-04-20 00:23:31 +08:00
parent 00fa86c000
commit e0dc6e29b2
5 changed files with 32 additions and 37 deletions

View File

@ -6,4 +6,7 @@ const (
KeyAPIVersion = KeyPrefix + "api_version" KeyAPIVersion = KeyPrefix + "api_version"
KeyLibraryID = KeyPrefix + "library_id" KeyLibraryID = KeyPrefix + "library_id"
KeyPlugin = KeyPrefix + "plugin" KeyPlugin = KeyPrefix + "plugin"
KeySK = KeyPrefix + "sk"
KeyAK = KeyPrefix + "ak"
KeyRegion = KeyPrefix + "region"
) )

View File

@ -1,7 +1,6 @@
package ctxkey package ctxkey
var ( var (
Channel = "channel"
RequestModel = "request_model" RequestModel = "request_model"
ConvertedRequest = "converted_request" ConvertedRequest = "converted_request"
OriginalModel = "original_model" OriginalModel = "original_model"

View File

@ -59,7 +59,6 @@ func Distribute() func(c *gin.Context) {
} }
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set(ctxkey.Channel, channel)
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)

View File

@ -5,10 +5,10 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"io" "io"
"net/http" "net/http"
"strings"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
@ -20,20 +20,16 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) { func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) {
ks := strings.Split(channel.Key, "\n") ak := c.GetString(config.KeyAK)
if len(ks) != 2 { sk := c.GetString(config.KeySK)
return nil, errors.New("invalid key") region := c.GetString(config.KeyRegion)
}
ak, sk := ks[0], ks[1]
client := bedrockruntime.New(bedrockruntime.Options{ client := bedrockruntime.New(bedrockruntime.Options{
Region: *channel.BaseURL, Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
}) })
@ -68,14 +64,7 @@ func awsModelID(requestModel string) (string, error) {
} }
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
var channel *model.Channel awsCli, err := newAwsClient(c)
if channel_, ok := c.Get(ctxkey.Channel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channel_.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil return wrapErr(errors.Wrap(err, "newAwsClient")), nil
} }
@ -134,15 +123,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
awsCli, err := newAwsClient(c)
var channel *model.Channel
if channel_, ok := c.Get(ctxkey.Channel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channel_.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil return wrapErr(errors.Wrap(err, "newAwsClient")), nil
} }

View File

@ -54,6 +54,11 @@ const EditChannel = () => {
const [basicModels, setBasicModels] = useState([]); const [basicModels, setBasicModels] = useState([]);
const [fullModels, setFullModels] = useState([]); const [fullModels, setFullModels] = useState([]);
const [customModel, setCustomModel] = useState(''); const [customModel, setCustomModel] = useState('');
const [config, setConfig] = useState({
region: '',
sk: '',
ak: ''
});
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
if (name === 'type') { if (name === 'type') {
@ -65,6 +70,10 @@ const EditChannel = () => {
} }
}; };
const handleConfigChange = (e, { name, value }) => {
setConfig((inputs) => ({ ...inputs, [name]: value }));
};
const loadChannel = async () => { const loadChannel = async () => {
let res = await API.get(`/api/channel/${channelId}`); let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
@ -83,6 +92,7 @@ const EditChannel = () => {
data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
} }
setInputs(data); setInputs(data);
setConfig(JSON.parse(data.config));
setBasicModels(getChannelModels(data.type)); setBasicModels(getChannelModels(data.type));
} else { } else {
showError(message); showError(message);
@ -176,6 +186,7 @@ const EditChannel = () => {
let res; let res;
localInputs.models = localInputs.models.join(','); localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(','); localInputs.group = localInputs.groups.join(',');
localInputs.config = JSON.stringify(config);
if (isEdit) { if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
} else { } else {
@ -352,7 +363,9 @@ const EditChannel = () => {
fluid fluid
multiple multiple
search search
onLabelClick={(e, { value }) => {copy(value).then()}} onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection selection
onChange={handleInputChange} onChange={handleInputChange}
value={inputs.models} value={inputs.models}
@ -403,11 +416,11 @@ const EditChannel = () => {
<Form.Field> <Form.Field>
<Form.Input <Form.Input
label='Region' label='Region'
name='base_url' name='region'
required required
placeholder={'regione.g. us-west-2'} placeholder={'regione.g. us-west-2'}
onChange={handleInputChange} onChange={handleConfigChange}
value={inputs.base_url} value={config.region}
autoComplete='' autoComplete=''
/> />
<Form.Input <Form.Input
@ -415,8 +428,8 @@ const EditChannel = () => {
name='ak' name='ak'
required required
placeholder={'AWS IAM Access Key'} placeholder={'AWS IAM Access Key'}
onChange={handleInputChange} onChange={handleConfigChange}
value={inputs.ak} value={config.ak}
autoComplete='' autoComplete=''
/> />
<Form.Input <Form.Input
@ -424,8 +437,8 @@ const EditChannel = () => {
name='sk' name='sk'
required required
placeholder={'AWS IAM Secret Key'} placeholder={'AWS IAM Secret Key'}
onChange={handleInputChange} onChange={handleConfigChange}
value={inputs.sk} value={config.sk}
autoComplete='' autoComplete=''
/> />
</Form.Field> </Form.Field>