perf: update config related code

This commit is contained in:
JustSong 2024-04-20 00:23:31 +08:00
parent 00fa86c000
commit e0dc6e29b2
5 changed files with 32 additions and 37 deletions

View File

@ -6,4 +6,7 @@ const (
KeyAPIVersion = KeyPrefix + "api_version"
KeyLibraryID = KeyPrefix + "library_id"
KeyPlugin = KeyPrefix + "plugin"
KeySK = KeyPrefix + "sk"
KeyAK = KeyPrefix + "ak"
KeyRegion = KeyPrefix + "region"
)

View File

@ -1,7 +1,6 @@
package ctxkey
var (
Channel = "channel"
RequestModel = "request_model"
ConvertedRequest = "converted_request"
OriginalModel = "original_model"

View File

@ -59,7 +59,6 @@ func Distribute() func(c *gin.Context) {
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set(ctxkey.Channel, channel)
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)

View File

@ -5,10 +5,10 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"net/http"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
@ -20,20 +20,16 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) {
ks := strings.Split(channel.Key, "\n")
if len(ks) != 2 {
return nil, errors.New("invalid key")
}
ak, sk := ks[0], ks[1]
func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) {
ak := c.GetString(config.KeyAK)
sk := c.GetString(config.KeySK)
region := c.GetString(config.KeyRegion)
client := bedrockruntime.New(bedrockruntime.Options{
Region: *channel.BaseURL,
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
@ -68,14 +64,7 @@ func awsModelID(requestModel string) (string, error) {
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
var channel *model.Channel
if channel_, ok := c.Get(ctxkey.Channel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channel_.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
awsCli, err := newAwsClient(c)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
@ -134,15 +123,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
var channel *model.Channel
if channel_, ok := c.Get(ctxkey.Channel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channel_.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
awsCli, err := newAwsClient(c)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}

View File

@ -54,6 +54,11 @@ const EditChannel = () => {
const [basicModels, setBasicModels] = useState([]);
const [fullModels, setFullModels] = useState([]);
const [customModel, setCustomModel] = useState('');
const [config, setConfig] = useState({
region: '',
sk: '',
ak: ''
});
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
if (name === 'type') {
@ -65,6 +70,10 @@ const EditChannel = () => {
}
};
const handleConfigChange = (e, { name, value }) => {
setConfig((inputs) => ({ ...inputs, [name]: value }));
};
const loadChannel = async () => {
let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data;
@ -83,6 +92,7 @@ const EditChannel = () => {
data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
}
setInputs(data);
setConfig(JSON.parse(data.config));
setBasicModels(getChannelModels(data.type));
} else {
showError(message);
@ -176,6 +186,7 @@ const EditChannel = () => {
let res;
localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(',');
localInputs.config = JSON.stringify(config);
if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
} else {
@ -352,7 +363,9 @@ const EditChannel = () => {
fluid
multiple
search
onLabelClick={(e, { value }) => {copy(value).then()}}
onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection
onChange={handleInputChange}
value={inputs.models}
@ -403,11 +416,11 @@ const EditChannel = () => {
<Form.Field>
<Form.Input
label='Region'
name='base_url'
name='region'
required
placeholder={'regione.g. us-west-2'}
onChange={handleInputChange}
value={inputs.base_url}
onChange={handleConfigChange}
value={config.region}
autoComplete=''
/>
<Form.Input
@ -415,8 +428,8 @@ const EditChannel = () => {
name='ak'
required
placeholder={'AWS IAM Access Key'}
onChange={handleInputChange}
value={inputs.ak}
onChange={handleConfigChange}
value={config.ak}
autoComplete=''
/>
<Form.Input
@ -424,8 +437,8 @@ const EditChannel = () => {
name='sk'
required
placeholder={'AWS IAM Secret Key'}
onChange={handleInputChange}
value={inputs.sk}
onChange={handleConfigChange}
value={config.sk}
autoComplete=''
/>
</Form.Field>