From b81808e83994369c3028f907b27a800957fa90c7 Mon Sep 17 00:00:00 2001 From: Buer <42402987+MartialBE@users.noreply.github.com> Date: Mon, 18 Mar 2024 16:00:35 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20amazon=20bedrock?= =?UTF-8?q?=20anthropic=20(#114)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🚧 WIP: bedrock * ✨ feat: support amazon bedrock anthropic --- common/constants.go | 4 +- common/model-ratio.go | 10 +- go.mod | 9 +- go.sum | 14 +- providers/bedrock/base.go | 127 ++++++++ providers/bedrock/category/base.go | 54 ++++ providers/bedrock/category/claude.go | 63 ++++ providers/bedrock/chat.go | 81 +++++ providers/bedrock/sigv4/LICENSE | 27 ++ providers/bedrock/sigv4/const.go | 36 +++ providers/bedrock/sigv4/header.go | 91 ++++++ providers/bedrock/sigv4/helper.go | 302 +++++++++++++++++ providers/bedrock/sigv4/key_deriver.go | 125 ++++++++ providers/bedrock/sigv4/sign_time.go | 36 +++ providers/bedrock/sigv4/signer.go | 428 +++++++++++++++++++++++++ providers/bedrock/sigv4/util.go | 26 ++ providers/bedrock/stream_reader.go | 128 ++++++++ providers/bedrock/type.go | 11 + providers/claude/chat.go | 36 ++- providers/claude/type.go | 2 +- providers/providers.go | 2 + web/src/constants/ChannelConstants.js | 6 + web/src/views/Channel/type/Config.js | 28 +- 23 files changed, 1617 insertions(+), 29 deletions(-) create mode 100644 providers/bedrock/base.go create mode 100644 providers/bedrock/category/base.go create mode 100644 providers/bedrock/category/claude.go create mode 100644 providers/bedrock/chat.go create mode 100644 providers/bedrock/sigv4/LICENSE create mode 100644 providers/bedrock/sigv4/const.go create mode 100644 providers/bedrock/sigv4/header.go create mode 100644 providers/bedrock/sigv4/helper.go create mode 100644 providers/bedrock/sigv4/key_deriver.go create mode 100644 providers/bedrock/sigv4/sign_time.go create mode 100644 providers/bedrock/sigv4/signer.go create mode 100644 providers/bedrock/sigv4/util.go create mode 100644 providers/bedrock/stream_reader.go create mode 100644 providers/bedrock/type.go diff --git a/common/constants.go b/common/constants.go index 88b007d2..16ac3056 100644 --- a/common/constants.go +++ b/common/constants.go @@ -198,6 +198,7 @@ const ( ChannelTypeMoonshot = 29 ChannelTypeMistral = 30 ChannelTypeGroq = 31 + ChannelTypeBedrock = 32 ) var ChannelBaseURLs = []string{ @@ -232,7 +233,8 @@ var ChannelBaseURLs = []string{ "https://api.deepseek.com", //28 "https://api.moonshot.cn", //29 "https://api.mistral.ai", //30 - "https://api.groq.com/openai", //30 + "https://api.groq.com/openai", //31 + "", //32 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index a84e9803..28886e31 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -89,10 +89,14 @@ func init() { // $0.80/million tokens $2.40/million tokens "claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic}, // $8.00/million tokens $24.00/million tokens - "claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic}, - "claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic}, - "claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic}, + "claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic}, + "claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic}, + // $15 / M $75 / M + "claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic}, + // $3 / M $15 / M "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic}, + // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens + "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic}, // ¥0.004 / 1k tokens ¥0.008 / 1k tokens "ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu}, diff --git a/go.mod b/go.mod index 8ffa0b6e..c7a1a5a0 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ module one-api go 1.18 require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -24,8 +25,10 @@ require ( gorm.io/gorm v1.25.0 ) +require github.com/aws/smithy-go v1.20.1 // indirect + require ( - github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 // indirect + github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect @@ -46,7 +49,7 @@ require ( github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/joho/godotenv v1.5.1 // indirect + github.com/joho/godotenv v1.5.1 github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect @@ -59,7 +62,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/net v0.19.0 // indirect + golang.org/x/net v0.19.0 golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index bd2e401e..d151bfe5 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 h1:1T7RcpzlldaJ3qpZi0lNg/lBsfPCK+8n8Wc+R8EhAkU= github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24/go.mod h1:kL1v4iIjlalwm3gCYGvF4NLa3hs+aKEfRkNJvj4aoDU= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 h1:gTK2uhtAPtFcdRRJilZPx8uJLL2J85xK11nKtWL0wfU= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1/go.mod h1:sxpLb+nZk7tIfCWChfd+h4QwHNUR57d8hA1cleTkjJo= +github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw= +github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -49,8 +53,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= @@ -58,6 +60,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= @@ -111,6 +115,7 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -161,8 +166,6 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -203,14 +206,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco= gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04= -gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= -gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y= gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc= gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= +gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= diff --git a/providers/bedrock/base.go b/providers/bedrock/base.go new file mode 100644 index 00000000..19e9584e --- /dev/null +++ b/providers/bedrock/base.go @@ -0,0 +1,127 @@ +package bedrock + +import ( + "bytes" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "one-api/types" + "strings" + "time" + + "one-api/providers/bedrock/category" + "one-api/providers/bedrock/sigv4" +) + +type BedrockProviderFactory struct{} + +// 创建 BedrockProvider +func (f BedrockProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + + bedrockProvider := &BedrockProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle), + }, + } + + getKeyConfig(bedrockProvider) + + return bedrockProvider +} + +type BedrockProvider struct { + base.BaseProvider + Region string + AccessKeyID string + SecretAccessKey string + SessionToken string + Category *category.Category +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://bedrock-runtime.%s.amazonaws.com", + ChatCompletions: "/model/%s/invoke", + } +} + +// 请求错误处理 +func requestErrorHandle(resp *http.Response) *types.OpenAIError { + bedrockError := &BedrockError{} + err := json.NewDecoder(resp.Body).Decode(bedrockError) + if err != nil { + return nil + } + + return errorHandle(bedrockError) +} + +// 错误处理 +func errorHandle(bedrockError *BedrockError) *types.OpenAIError { + if bedrockError.Message == "" { + return nil + } + return &types.OpenAIError{ + Message: bedrockError.Message, + Type: "Bedrock Error", + } +} + +func (p *BedrockProvider) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf(baseURL+requestURL, p.Region, modelName) +} + +func (p *BedrockProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + headers["Accept"] = "*/*" + + return headers +} + +func getKeyConfig(bedrock *BedrockProvider) { + keys := strings.Split(bedrock.Channel.Key, "|") + if len(keys) < 3 { + return + } + + bedrock.Region = keys[0] + bedrock.AccessKeyID = keys[1] + bedrock.SecretAccessKey = keys[2] + if len(keys) == 4 && keys[3] != "" { + bedrock.SessionToken = keys[3] + } +} + +func (p *BedrockProvider) Sign(req *http.Request) error { + var body []byte + if req.Body == nil { + body = []byte("") + } else { + var err error + body, err = io.ReadAll(req.Body) + if err != nil { + return errors.New("error getting request body: " + err.Error()) + } + req.Body = io.NopCloser(bytes.NewReader(body)) + } + sig, err := sigv4.New(sigv4.WithCredential(p.AccessKeyID, p.SecretAccessKey, p.SessionToken), sigv4.WithRegionService(p.Region, awsService)) + if err != nil { + return err + } + + reqBodyHashHex := fmt.Sprintf("%x", sha256.Sum256(body)) + sig.Sign(req, reqBodyHashHex, sigv4.NewTime(time.Now())) + + return nil +} diff --git a/providers/bedrock/category/base.go b/providers/bedrock/category/base.go new file mode 100644 index 00000000..f16b7741 --- /dev/null +++ b/providers/bedrock/category/base.go @@ -0,0 +1,54 @@ +package category + +import ( + "errors" + "net/http" + "one-api/common/requester" + "one-api/providers/base" + "one-api/types" + "strings" +) + +var CategoryMap = map[string]Category{} + +type Category struct { + ModelName string + ChatComplete ChatCompletionConvert + ResponseChatComplete ChatCompletionResponse + ResponseChatCompleteStrem ChatCompletionStreamResponse +} + +func GetCategory(modelName string) (*Category, error) { + modelName = GetModelName(modelName) + + // 点分割 + provider := strings.Split(modelName, ".")[0] + + if category, exists := CategoryMap[provider]; exists { + category.ModelName = modelName + return &category, nil + } + + return nil, errors.New("category_not_found") +} + +func GetModelName(modelName string) string { + bedrockMap := map[string]string{ + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-2.1": "anthropic.claude-v2:1", + "claude-2.0": "anthropic.claude-v2", + "claude-instant-1.2": "anthropic.claude-instant-v1", + } + + if value, exists := bedrockMap[modelName]; exists { + modelName = value + } + + return modelName +} + +type ChatCompletionConvert func(*types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) +type ChatCompletionResponse func(base.ProviderInterface, *http.Response, *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) + +type ChatCompletionStreamResponse func(base.ProviderInterface, *types.ChatCompletionRequest) requester.HandlerPrefix[string] diff --git a/providers/bedrock/category/claude.go b/providers/bedrock/category/claude.go new file mode 100644 index 00000000..542f9b77 --- /dev/null +++ b/providers/bedrock/category/claude.go @@ -0,0 +1,63 @@ +package category + +import ( + "encoding/json" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/base" + "one-api/providers/claude" + "one-api/types" +) + +const anthropicVersion = "bedrock-2023-05-31" + +type ClaudeRequest struct { + *claude.ClaudeRequest + AnthropicVersion string `json:"anthropic_version"` +} + +func init() { + CategoryMap["anthropic"] = Category{ + ChatComplete: ConvertClaudeFromChatOpenai, + ResponseChatComplete: ConvertClaudeToChatOpenai, + ResponseChatCompleteStrem: ClaudeChatCompleteStrem, + } +} + +func ConvertClaudeFromChatOpenai(request *types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) { + rawRequest, err := claude.ConvertFromChatOpenai(request) + if err != nil { + return nil, err + } + + claudeRequest := &ClaudeRequest{} + claudeRequest.ClaudeRequest = rawRequest + claudeRequest.AnthropicVersion = anthropicVersion + + // 删除model字段 + claudeRequest.Model = "" + claudeRequest.Stream = false + + return claudeRequest, nil +} + +func ConvertClaudeToChatOpenai(provider base.ProviderInterface, response *http.Response, request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + claudeResponse := &claude.ClaudeResponse{} + err := json.NewDecoder(response.Body).Decode(claudeResponse) + if err != nil { + return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + } + + return claude.ConvertToChatOpenai(provider, claudeResponse, request) +} + +func ClaudeChatCompleteStrem(provider base.ProviderInterface, request *types.ChatCompletionRequest) requester.HandlerPrefix[string] { + chatHandler := &claude.ClaudeStreamHandler{ + Usage: provider.GetUsage(), + Request: request, + Prefix: `{"type"`, + } + + return chatHandler.HandlerStream +} diff --git a/providers/bedrock/chat.go b/providers/bedrock/chat.go new file mode 100644 index 00000000..de24cdd2 --- /dev/null +++ b/providers/bedrock/chat.go @@ -0,0 +1,81 @@ +package bedrock + +import ( + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/bedrock/category" + "one-api/types" +) + +func (p *BedrockProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + // 发送请求 + response, errWithCode := p.Send(request) + if errWithCode != nil { + return nil, errWithCode + } + + defer response.Body.Close() + + return p.Category.ResponseChatComplete(p, response, request) +} + +func (p *BedrockProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { + // 发送请求 + response, errWithCode := p.Send(request) + if errWithCode != nil { + return nil, errWithCode + } + + return RequestStream(response, p.Category.ResponseChatCompleteStrem(p, request)) +} + +func (p *BedrockProvider) Send(request *types.ChatCompletionRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + return p.Requester.SendRequestRaw(req) +} + +func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + var err error + p.Category, err = category.GetCategory(request.Model) + if err != nil || p.Category.ChatComplete == nil || p.Category.ResponseChatComplete == nil { + return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError) + } + + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + + if request.Stream { + url += "-with-response-stream" + } + + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, p.Category.ModelName) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError) + } + + headers := p.GetRequestHeaders() + + bedrockRequest, errWithCode := p.Category.ChatComplete(request) + if errWithCode != nil { + return nil, errWithCode + } + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(bedrockRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + p.Sign(req) + + return req, nil +} diff --git a/providers/bedrock/sigv4/LICENSE b/providers/bedrock/sigv4/LICENSE new file mode 100644 index 00000000..36e4f1fb --- /dev/null +++ b/providers/bedrock/sigv4/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2023 Macks C. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Macks C. nor the names of its contributors +may be used to endorse or promote products derived from this software +without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/providers/bedrock/sigv4/const.go b/providers/bedrock/sigv4/const.go new file mode 100644 index 00000000..f95ff017 --- /dev/null +++ b/providers/bedrock/sigv4/const.go @@ -0,0 +1,36 @@ +package sigv4 + +const authorizationHeader = "Authorization" + +// Signature Version 4 (SigV4) Constants +const ( + // SigningAlgorithm is the name of the algorithm used in this package. + SigningAlgorithm = "AWS4-HMAC-SHA256" + // EmptyStringSHA256 is the hex encoded sha256 value of an empty string. + EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` + // UnsignedPayload indicates that the request payload body is unsigned. + UnsignedPayload = "UNSIGNED-PAYLOAD" + // AmzAlgorithmKey indicates the signing algorithm. + AmzAlgorithmKey = "X-Amz-Algorithm" + // AmzSecurityTokenKey indicates the security token to be used with temporary + // credentials. + AmzSecurityTokenKey = "X-Amz-Security-Token" + // AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'. + AmzDateKey = "X-Amz-Date" + // AmzCredentialKey is the access key ID and credential scope. + AmzCredentialKey = "X-Amz-Credential" + // AmzSignedHeadersKey is the set of headers signed for the request. + AmzSignedHeadersKey = "X-Amz-SignedHeaders" + // AmzSignatureKey is the query parameter to store the SigV4 signature. + AmzSignatureKey = "X-Amz-Signature" + // TimeFormat is the time format to be used in the X-Amz-Date header or query + // parameter. + TimeFormat = "20060102T150405Z" + // ShortTimeFormat is the shorten time format used in the credential scope. + ShortTimeFormat = "20060102" + // ContentSHAKey is the SHA256 of request body. + ContentSHAKey = "X-Amz-Content-Sha256" + // StreamingEventsPayload indicates that the request payload body is a signed + // event stream. + StreamingEventsPayload = "STREAMING-AWS4-HMAC-SHA256-EVENTS" +) diff --git a/providers/bedrock/sigv4/header.go b/providers/bedrock/sigv4/header.go new file mode 100644 index 00000000..8db400bd --- /dev/null +++ b/providers/bedrock/sigv4/header.go @@ -0,0 +1,91 @@ +package sigv4 + +// ignoredHeaders is a list of headers that are always ignored during signing. +var ignoreHeaders = map[string]struct{}{ + "Authorization": {}, + "User-Agent": {}, + "X-Amzn-Trace-Id": {}, + // also include lower case canonical versions + "authorization": {}, + "user-agent": {}, + "x-amzn-trace-id": {}, +} + +// requiredHeaderPrefix are header name prefixes that are mandatory for signing. +// If a header has one of these prefixes, it is a mandatory header. +var requiredHeaderPrefix = []string{"X-Amz-Object-Lock-", "X-Amz-Meta-"} + +// requiredHeaders is a list of headers that are mandatory for signing. +var requiredHeaders = map[string]struct{}{ + "Cache-Control": {}, + "Content-Disposition": {}, + "Content-Encoding": {}, + "Content-Language": {}, + "Content-Md5": {}, + "Content-Type": {}, + "Expires": {}, + "If-Match": {}, + "If-Modified-Since": {}, + "If-None-Match": {}, + "If-Unmodified-Since": {}, + "Range": {}, + "X-Amz-Acl": {}, + "X-Amz-Copy-Source": {}, + "X-Amz-Copy-Source-If-Match": {}, + "X-Amz-Copy-Source-If-Modified-Since": {}, + "X-Amz-Copy-Source-If-None-Match": {}, + "X-Amz-Copy-Source-If-Unmodified-Since": {}, + "X-Amz-Copy-Source-Range": {}, + "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": {}, + "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": {}, + "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": {}, + "X-Amz-Grant-Full-control": {}, + "X-Amz-Grant-Read": {}, + "X-Amz-Grant-Read-Acp": {}, + "X-Amz-Grant-Write": {}, + "X-Amz-Grant-Write-Acp": {}, + "X-Amz-Metadata-Directive": {}, + "X-Amz-Mfa": {}, + "X-Amz-Request-Payer": {}, + "X-Amz-Server-Side-Encryption": {}, + "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": {}, + "X-Amz-Server-Side-Encryption-Customer-Algorithm": {}, + "X-Amz-Server-Side-Encryption-Customer-Key": {}, + "X-Amz-Server-Side-Encryption-Customer-Key-Md5": {}, + "X-Amz-Storage-Class": {}, + "X-Amz-Website-Redirect-Location": {}, + "X-Amz-Content-Sha256": {}, + "X-Amz-Tagging": {}, +} + +// headerPredicate is a function that evaluates whether a header is of the +// specific type. For example, whether a header should be ignored during signing. +type headerPredicate func(header string) bool + +// isIgnoredHeader returns true if header must be ignored during signing. +func isIgnoredHeader(header string) bool { + _, ok := ignoreHeaders[header] + return ok +} + +// isRequiredHeader returns true if header is mandatory for signing. +func isRequiredHeader(header string) bool { + _, ok := requiredHeaders[header] + if ok { + return true + } + for _, v := range requiredHeaderPrefix { + if hasPrefixFold(header, v) { + return true + } + } + return false +} + +// isAllowQueryHoisting is a allowed list for Build query headers. +func isAllowQueryHoisting(header string) bool { + if isRequiredHeader(header) { + return false + } + return hasPrefixFold(header, "X-Amz-") +} diff --git a/providers/bedrock/sigv4/helper.go b/providers/bedrock/sigv4/helper.go new file mode 100644 index 00000000..5ee90771 --- /dev/null +++ b/providers/bedrock/sigv4/helper.go @@ -0,0 +1,302 @@ +package sigv4 + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +var ( + awsURLNoEscTable [256]bool + awsURLEscTable [256][2]byte +) + +func init() { + for i := 0; i < len(awsURLNoEscTable); i++ { + // every char except these must be escaped + awsURLNoEscTable[i] = (i >= 'A' && i <= 'Z') || + (i >= 'a' && i <= 'z') || + (i >= '0' && i <= '9') || + i == '-' || + i == '.' || + i == '_' || + i == '~' + // % + encoded := fmt.Sprintf("%02X", i) + awsURLEscTable[i] = [2]byte{encoded[0], encoded[1]} + } +} + +// hmacsha256 computes a HMAC-SHA256 of data given the provided key. +func hmacsha256(key, data, buf []byte) []byte { + hash := hmac.New(sha256.New, key) + hash.Write(data) + return hash.Sum(buf) +} + +// hasPrefixFold tests whether the string s begins with prefix, interpreted as +// UTF-8 strings, under Unicode case-folding. +func hasPrefixFold(s, prefix string) bool { + return len(s) >= len(prefix) && + strings.EqualFold(s[0:len(prefix)], prefix) +} + +// isSameDay returns true if a and b are the same date (dd-mm-yyyy). +func isSameDay(a, b time.Time) bool { + xYear, xMonth, xDay := a.Date() + yYear, yMonth, yDay := b.Date() + + if xYear != yYear || xMonth != yMonth { + return false + } + return xDay == yDay +} + +// hostOrURLHost returns r.Host, or if empty, r.URL.Host. +func hostOrURLHost(r *http.Request) string { + if r.Host != "" { + return r.Host + } + return r.URL.Host +} + +// parsePort returns the port part of u.Host, without the leading colon. Returns +// an empty string if u.Host doesn't contain port. +// +// Adapted from the Go 1.8 standard library (net/url). +func parsePort(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 || colon == len(hostport)-1 { + return "" + } + + // take care of ipv6 syntax: [a:b::]: + const ipv6Sep = "]:" + if i := strings.Index(hostport, ipv6Sep); i != -1 { + return hostport[i+len(ipv6Sep):] + } + if strings.Contains(hostport, "]") { + return "" + } + + return hostport[colon+1:] +} + +// stripPort returns Hostname portion of u.Host, i.e. without any port number. +// +// If hostport is an IPv6 literal with a port number, returns the IPv6 literal +// without the square brackets. IPv6 literals may include a zone identifier. +// +// Adapted from the Go 1.8 standard library (net/url). +func stripPort(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 { + return hostport + } + // ipv6: remove the [] + if i := strings.IndexByte(hostport, ']'); i != -1 { + return strings.TrimPrefix(hostport[:i], "[") + } + return hostport[:colon] +} + +// isDefaultPort returns true if the specified URI is using the standard port +// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs). +func isDefaultPort(scheme, port string) bool { + switch strings.ToLower(scheme) { + case "http": + return port == "80" + case "https": + return port == "443" + default: + return false + } +} + +func cloneURL(u *url.URL) *url.URL { + if u == nil { + return nil + } + u2 := new(url.URL) + *u2 = *u + if u.User != nil { + u2.User = new(url.Userinfo) + *u2.User = *u.User + } + return u2 +} + +// writeAWSURIPath writes the escaped URI component from the specified URL (using +// AWS canonical URI specification) into w. URI component is path without query +// string. +func writeAWSURIPath(w *bufio.Writer, u *url.URL, encodeSep bool, isEscaped bool) { + const schemeSep, pathSep, queryStart = "//", "/", "?" + + var p string + if u.Opaque == "" { + p = u.EscapedPath() + } else { + opaque := u.Opaque + // discard query string if any + if i := strings.Index(opaque, queryStart); i != -1 { + opaque = opaque[:i] + } + // if has scheme separator as prefix, discard it + if strings.HasPrefix(opaque, schemeSep) { + opaque = opaque[len(schemeSep):] + } + + // everything after the first /, including the / + if i := strings.Index(opaque, pathSep); i != -1 { + p = opaque[i:] + } + } + + if p == "" { + w.WriteByte('/') + return + } + + if isEscaped { + w.WriteString(p) + return + } + + // Loop thru first like in https://cs.opensource.google/go/go/+/refs/tags/go1.20.2:/src/net/url/url.go. + // It may add ~800ns but we save on memory alloc and catches cases where there + // is no need to escape. + plen := len(p) + strlen := plen + for i := 0; i < plen; i++ { + c := p[i] + if awsURLNoEscTable[c] || (c == '/' && !encodeSep) { + continue + } + strlen += 2 + } + + // path already canonical, no need to escape + if plen == strlen { + w.WriteString(p) + return + } + + for i := 0; i < plen; i++ { + c := p[i] + if awsURLNoEscTable[c] || (c == '/' && !encodeSep) { + w.WriteByte(c) + continue + } + w.Write([]byte{'%', awsURLEscTable[c][0], awsURLEscTable[c][1]}) + } +} + +// writeCanonicalQueryParams builds the canonical form of query and writes to w. +// +// Side effect: query values are sorted after this function returns. +func writeCanonicalQueryParams(w *bufio.Writer, query url.Values) { + qlen := len(query) + if qlen == 0 { + return + } + + keys := make([]string, 0, qlen) + for k := range query { + keys = append(keys, k) + } + sort.Strings(keys) + + for i, k := range keys { + keyEscaped := strings.Replace(url.QueryEscape(k), "+", "%20", -1) + vs := query[k] + + if i != 0 { + w.WriteByte('&') + } + + if len(vs) == 0 { + w.WriteString(keyEscaped) + w.WriteByte('=') + continue + } + + sort.Strings(vs) + for j, v := range vs { + if j != 0 { + w.WriteByte('&') + } + w.WriteString(keyEscaped) + w.WriteByte('=') + if v != "" { + w.WriteString(strings.Replace(url.QueryEscape(v), "+", "%20", -1)) + } + } + } +} + +// writeCanonicalString removes leading and trailing whitespaces (as defined by Unicode) +// in s, replaces consecutive spaces (' ') in s with a single space, and then +// write the result to w. +func writeCanonicalString(w *bufio.Writer, s string) { + const dblSpace = " " + + s = strings.TrimSpace(s) + + // bail if str doesn't contain " " + j := strings.Index(s, dblSpace) + if j < 0 { + w.WriteString(s) + return + } + + w.WriteString(s[:j]) + + // replace all " " with " " in a performant way + var lastIsSpace bool + for i, l := j, len(s); i < l; i++ { + if s[i] == ' ' { + if !lastIsSpace { + w.WriteByte(' ') + lastIsSpace = true + } + continue + } + lastIsSpace = false + w.WriteByte(s[i]) + } +} + +type debugHasher struct { + buf []byte +} + +func (dh *debugHasher) Write(b []byte) (int, error) { + dh.buf = append(dh.buf, b...) + return len(b), nil +} + +func (dh *debugHasher) Sum(b []byte) []byte { + return nil +} + +func (dh *debugHasher) Reset() { + // do nothing +} + +func (dh *debugHasher) Size() int { + return 0 +} + +func (dh *debugHasher) BlockSize() int { + return sha256.BlockSize +} + +func (dh *debugHasher) Println() { + fmt.Printf("---%s---\n", dh.buf) +} diff --git a/providers/bedrock/sigv4/key_deriver.go b/providers/bedrock/sigv4/key_deriver.go new file mode 100644 index 00000000..b2774c85 --- /dev/null +++ b/providers/bedrock/sigv4/key_deriver.go @@ -0,0 +1,125 @@ +package sigv4 + +import ( + "crypto/sha256" + "strings" + "sync" + "time" + gotime "time" +) + +var credScopeSuffixBytes = []byte{'a', 'w', 's', '4', '_', 'r', 'e', 'q', 'u', 'e', 's', 't'} + +// deriveKey calculates the signing key. See https://docs.aws.amazon.com/general/latest/gr/create-signed-request.html. +func deriveKey(secret, service, region string, t Time) []byte { + // enc( + // enc( + // enc( + // enc(AWS4, ), + // ), + // ), + // "aws4_request") + + // https://en.wikipedia.org/wiki/HMAC + // HMAC_SHA256 produces 32 bytes output + + f1 := len(secret) + 4 + f2 := f1 + len(t.ShortTimeFormat()) + f3 := f2 + len(region) + f4 := f3 + len(service) + + qs := make([]byte, 0, f4) + qs = append(qs, "AWS4"...) + qs = append(qs, secret...) + qs = append(qs, t.ShortTimeFormat()...) + qs = append(qs, region...) + qs = append(qs, service...) + + buf := make([]byte, 0, sha256.BlockSize) + buf = hmacsha256(qs[:f1], qs[f1:f2], buf) + buf = hmacsha256(buf, qs[f2:f3], buf[:0]) + buf = hmacsha256(buf, qs[f3:], buf[:0]) + return hmacsha256(buf, credScopeSuffixBytes, buf[:0]) +} + +// keyDeriver returns a signing key based on parameters such as credentials. +type keyDeriver interface { + DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte +} + +// signingKeyDeriver is the default implementation of keyDerivator. +type signingKeyDeriver struct { + cache derivedKeyCache +} + +// newKeyDeriver creates a keyDeriver using the default implementation. The +// signing key is cached per region/service, and updated when accessKey changes +// or signingTime is not on the same day for that region/service. +func newKeyDeriver() keyDeriver { + return &signingKeyDeriver{cache: newDerivedKeyCache()} +} + +// DeriveKey returns a derived signing key from the given credentials to be used +// with SigV4 signing. +func (k *signingKeyDeriver) DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte { + return k.cache.Get(accessKey, secret, service, region, sigtime) +} + +type derivedKeyCache struct { + mutex sync.RWMutex + values map[string]derivedKey + nowFunc func() gotime.Time +} + +type derivedKey struct { + Date gotime.Time + Credential []byte +} + +func newDerivedKeyCache() derivedKeyCache { + return derivedKeyCache{ + values: make(map[string]derivedKey), + nowFunc: gotime.Now, + } +} + +// Get returns key from cache or creates a new one. +func (s *derivedKeyCache) Get(accessKey, secret, service, region string, sigtime Time) []byte { + // /// + key := strings.Join([]string{accessKey, sigtime.ShortTimeFormat(), region, service}, "/") + + s.mutex.RLock() + cred, status := s.getFromCache(key) + s.mutex.RUnlock() + if status == 0 { + return cred + } + + cred = deriveKey(secret, service, region, sigtime) + + s.mutex.Lock() + if status == -1 { + delete(s.values, key) + } + s.values[key] = derivedKey{ + Date: sigtime.Time, + Credential: cred, + } + s.mutex.Unlock() + + return cred +} + +// getFromCache returns s.values[key]. Second result is 1 if key was not found, +// or -1 if the cached value has expired. +func (s *derivedKeyCache) getFromCache(key string) ([]byte, int) { + v, ok := s.values[key] + if !ok { + return nil, 1 + } + // evict from cache if item is a day older than system time + if s.nowFunc().Sub(v.Date) > 24*time.Hour { + return nil, -1 + } + return v.Credential, 0 +} diff --git a/providers/bedrock/sigv4/sign_time.go b/providers/bedrock/sigv4/sign_time.go new file mode 100644 index 00000000..7bf1b855 --- /dev/null +++ b/providers/bedrock/sigv4/sign_time.go @@ -0,0 +1,36 @@ +package sigv4 + +import ( + gotime "time" +) + +// Time wraps time.Time to cache its string format result. +type Time struct { + gotime.Time + short string + long string +} + +// NewTime creates a new signingTime with the specified time.Time. +func NewTime(t gotime.Time) Time { + return Time{Time: t.UTC()} +} + +// TimeFormat provides a time formatted in the X-Amz-Date format. +func (m *Time) TimeFormat() string { + return m.readOrFormat(&m.long, TimeFormat) +} + +// ShortTimeFormat provides a time formatted in short time format. +func (m *Time) ShortTimeFormat() string { + return m.readOrFormat(&m.short, ShortTimeFormat) +} + +func (m *Time) readOrFormat(target *string, format string) string { + if len(*target) > 0 { + return *target + } + v := m.Time.Format(format) + *target = v + return v +} diff --git a/providers/bedrock/sigv4/signer.go b/providers/bedrock/sigv4/signer.go new file mode 100644 index 00000000..bdde513d --- /dev/null +++ b/providers/bedrock/sigv4/signer.go @@ -0,0 +1,428 @@ +package sigv4 + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "errors" + "hash" + "net/http" + "net/textproto" + "net/url" + "sort" + "strconv" + "strings" +) + +// HTTPSigner is an AWS SigV4 signer that can sign HTTP requests. +type HTTPSigner interface { + // Sign AWS v4 requests with the provided payload hash, service name, region + // the request is made to, and time the request is signed at. Set sigtime + // to the future to create a request that cannot be used until the future time. + // + // payloadHash is the hex encoded SHA-256 hash of the request payload, and must + // not be empty, even if the request has no payload (aka body). If the request + // has no payload, use the hex encoded SHA-256 of an empty string, or the constant + // EmptyStringSHA256. You can use the utility function ContentSHA256Sum to + // calculate the hash of a http.Request body. + // + // Some services such as Amazon S3 accept alternative values for the payload + // hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be + // protected by sigv4. See https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html. + // + // Sign differs from Presign in that it will sign the request using HTTP headers. + // The passed in request r will be modified in place: modified fields include + // r.Host and r.Header. + Sign(r *http.Request, payloadHash string, sigtime Time) error + // Presign is like Sign, but does not modify request r. It returns a copy of + // r.URL with additional query parameters that contains signing information. + // The URL can be used to recreate an authenticated request without specifying + // headers. It also returns http.Header as a second result, which must be + // included in the reconstructed request. + // + // Header hoisting: use WithHeaderHoisting option function to specify whether + // headers in request r should be added as query parameters. Some headers cannot + // be hoisted, and are returned as the second result. + // + // Presign will not set the expires time of the presigned request automatically. + // To specify the expire duration for a request, add the "X-Amz-Expires" query + // parameter on the request with the value as the duration in seconds the + // presigned URL should be considered valid for. This parameter is not used + // by all AWS services, and is most notable used by Amazon S3 APIs. + // + // expires := 20*time.Minute + // query := req.URL.Query() + // query.Set("X-Amz-Expires", strconv.FormatInt(int64(expires/time.Second), 10) + // req.URL.RawQuery = query.Encode() + Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error) +} + +// HTTPSignerOption is an option parameter for HTTPSigner constructor function. +type HTTPSignerOption func(HTTPSigner) error + +// ErrInvalidOption means the option parameter is incompatible with the HTTPSigner. +var ErrInvalidOption = errors.New("cannot apply option to HTTPSigner") + +// httpV4Signer is the default implementation of HTTPSigner. +type httpV4Signer struct { + KeyDeriver keyDeriver + AccessKey string + Secret string + SessionToken string + Service string + Region string + HeaderHoisting bool + EscapeURLPath bool +} + +// WithCredential sets HTTPSigner credential fields. +func WithCredential(accessKey, secret, sessionToken string) HTTPSignerOption { + return func(signer HTTPSigner) error { + if sigv4, ok := signer.(*httpV4Signer); ok { + sigv4.AccessKey = accessKey + sigv4.Secret = secret + sigv4.SessionToken = sessionToken + return nil + } + return ErrInvalidOption + } +} + +// WithHeaderHoisting specifies whether HTTPSigner automatically hoist headers. +// Default is enabled. +func WithHeaderHoisting(enable bool) HTTPSignerOption { + return func(signer HTTPSigner) error { + if sigv4, ok := signer.(*httpV4Signer); ok { + sigv4.HeaderHoisting = enable + return nil + } + return ErrInvalidOption + } +} + +// WithEscapeURLPath specifies whether HTTPSigner automatically escapes URL paths. +// Default is enabled. +func WithEscapeURLPath(enable bool) HTTPSignerOption { + return func(signer HTTPSigner) error { + if sigv4, ok := signer.(*httpV4Signer); ok { + sigv4.EscapeURLPath = enable + return nil + } + return ErrInvalidOption + } +} + +// WithRegionService sets HTTPSigner region and service fields. +func WithRegionService(region, service string) HTTPSignerOption { + return func(signer HTTPSigner) error { + if sigv4, ok := signer.(*httpV4Signer); ok { + sigv4.Region = region + sigv4.Service = service + return nil + } + return ErrInvalidOption + } +} + +// New creates a HTTPSigner. +func New(opts ...HTTPSignerOption) (HTTPSigner, error) { + sigv4 := &httpV4Signer{ + KeyDeriver: newKeyDeriver(), + EscapeURLPath: true, + HeaderHoisting: true, + } + for _, o := range opts { + if o == nil { + continue + } + if err := o(sigv4); err != nil { + return nil, err + } + } + return sigv4, nil +} + +// Sign implements HTTPSigner. +func (s *httpV4Signer) Sign(r *http.Request, payloadHash string, sigtime Time) error { + if payloadHash == "" { + var err error + payloadHash, err = ContentSHA256Sum(r) + if err != nil { + return err + } + } + + // add mandatory headers to r.Header + setRequiredSigningHeaders(r.Header, sigtime, s.SessionToken) + // remove port in r.Host if any + r.Host = sanitizeHostForHeader(r) + + // parse URL query only once + query := r.URL.Query() + // sigBuf is used to act as a sha256 hash buffer + sigBuf := make([]byte, 0, sha256.Size) + + //hasher := &debugHasher{} + hasher := sha256.New() + reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, r.Header, query, + r.Host, payloadHash, s.EscapeURLPath, false, sigBuf) + + credentialScope := strings.Join([]string{ + sigtime.ShortTimeFormat(), + s.Region, + s.Service, + "aws4_request", + }, "/") + + keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service, + s.Region, sigtime) + sigHasher := hmac.New(sha256.New, keyBytes) + signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf) + + writeAuthorizationHeader(r.Header, s.AccessKey+"/"+credentialScope, + signedHeaderStr, signature) + + // done + return nil +} + +// Presign implements HTTPSigner. +func (s *httpV4Signer) Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error) { + if payloadHash == "" { + var err error + payloadHash, err = ContentSHA256Sum(r) + if err != nil { + return nil, nil, err + } + } + + query := r.URL.Query() + setRequiredSigningQuery(query, sigtime, s.SessionToken) + // sort each query key's values + for key := range query { + sort.Strings(query[key]) + } + + credentialScope := strings.Join([]string{ + sigtime.ShortTimeFormat(), + s.Region, + s.Service, + "aws4_request", + }, "/") + credentialStr := s.AccessKey + "/" + credentialScope + query.Set(AmzCredentialKey, credentialStr) + + var headersLeft http.Header + if s.HeaderHoisting { + headersLeft = make(http.Header, len(r.Header)) + for k, v := range r.Header { + if isAllowQueryHoisting(k) { + query[k] = v + } else { + headersLeft[k] = v + } + } + } + + // sigBuf is used to act as a sha256 hash buffer + sigBuf := make([]byte, 0, sha256.Size) + + hasher := sha256.New() + reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, headersLeft, + query, sanitizeHostForHeader(r), payloadHash, s.EscapeURLPath, true, sigBuf) + + keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service, + s.Region, sigtime) + sigHasher := hmac.New(sha256.New, keyBytes) + signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf) + query.Set(AmzSignatureKey, signature) + + u := cloneURL(r.URL) + u.RawQuery = strings.Replace(query.Encode(), "+", "%20", -1) + + // For the signed headers we canonicalize the header keys in the returned map. + // This avoids situations where standard library can sometimes add double + // headers. For example, the standard library will set the Host header, + // even if it is present in lower-case form. + signedHeader := strings.Split(signedHeaderStr, ";") + canonHeader := make(http.Header, len(signedHeader)) + for _, k := range signedHeader { + canonKey := textproto.CanonicalMIMEHeaderKey(k) + switch k { + case "host": + canonHeader[canonKey] = []string{sanitizeHostForHeader(r)} + case "content-length": + canonHeader[canonKey] = []string{strconv.FormatInt(r.ContentLength, 10)} + default: + canonHeader[canonKey] = append(canonHeader[canonKey], headersLeft[http.CanonicalHeaderKey(k)]...) + } + } + return u, canonHeader, nil +} + +// authorizationSignature returns `sig` as documented in step 4 of algorithm +// documentation. key is hSig in step 4. It calculates the result of step 3 +// internally. +func authorizationSignature(hasher hash.Hash, sigtime Time, credScope, requestHash string, buf []byte) string { + w := bufio.NewWriterSize(hasher, sha256.BlockSize) + + w.WriteString(SigningAlgorithm) + w.WriteByte('\n') + w.WriteString(sigtime.TimeFormat()) + w.WriteByte('\n') + w.WriteString(credScope) + w.WriteByte('\n') + w.WriteString(requestHash) + + w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer + //hasher.Println() + return hex.EncodeToString(hasher.Sum(buf[:0])) +} + +// canonicalRequestHash returns the hex-encoded sha256 sum of the canonical +// request string. Refer to step 2 of algorithm documentation. Expect hasher to +// be sha256.New. +func canonicalRequestHash( + hasher hash.Hash, r *http.Request, headers http.Header, query url.Values, + hostname, hashcode string, escapeURL, isPresign bool, buf []byte, +) (string, string) { + w := bufio.NewWriterSize(hasher, sha256.BlockSize) + + signedHeaders := make([]string, 0, len(headers)+2) + signedHeaders = append(signedHeaders, "host") + if r.ContentLength > 0 { + signedHeaders = append(signedHeaders, "content-length") + } + for k := range headers { + if strings.EqualFold(k, "content-length") || strings.EqualFold(k, "host") || isIgnoredHeader(k) { + continue + } + signedHeaders = append(signedHeaders, strings.ToLower(k)) + } + sort.Strings(signedHeaders) + signedHeaderStr := strings.Join(signedHeaders, ";") + + // for presigned requests, we need to add X-Amz-SignedHeaders to calculate the + // correct hash + if isPresign { + query.Set(AmzSignedHeadersKey, signedHeaderStr) + } + + // \n\n\n\n\n + + // HTTP_METHOD + w.WriteString(r.Method) + w.WriteByte('\n') + // CANONICAL_URI + writeAWSURIPath(w, r.URL, false, !escapeURL) + w.WriteByte('\n') + // CANONICAL_QUERY_PARAMS + writeCanonicalQueryParams(w, query) + w.WriteByte('\n') + // CANONICAL_HEADERS + for _, head := range signedHeaders { + switch head { + case "host": + w.WriteString(head) + w.WriteByte(':') + writeCanonicalString(w, hostname) + w.WriteByte('\n') + case "content-length": + w.WriteString(head) + w.WriteByte(':') + w.WriteString(strconv.FormatInt(r.ContentLength, 10)) + w.WriteByte('\n') + default: + w.WriteString(head) + w.WriteByte(':') + values := headers[http.CanonicalHeaderKey(head)] + for i, v := range values { + if i != 0 { + w.WriteByte(',') + } + writeCanonicalString(w, v) + } + w.WriteByte('\n') + } + } + w.WriteByte('\n') + // SIGNED_HEADERS + w.WriteString(signedHeaderStr) + w.WriteByte('\n') + // PAYLOAD_HASH + w.WriteString(hashcode) + + w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer + //hasher.Println() + return hex.EncodeToString(hasher.Sum(buf[:0])), signedHeaderStr +} + +// writeAuthorizationHeader writes the Authorization header into header: +// +// AWS4-HMAC-SHA256 Credential=, SignedHeaders=, Signature= +func writeAuthorizationHeader(headers http.Header, credentialStr, signedHeaders, signature string) { + const credentialPrefix = "Credential=" + const signedHeadersPrefix = "SignedHeaders=" + const signaturePrefix = "Signature=" + const commaSpace = ", " + + var parts strings.Builder + parts.Grow(len(SigningAlgorithm) + 1 + + len(credentialPrefix) + len(credentialStr) + 2 + + len(signedHeadersPrefix) + len(signedHeaders) + 2 + + len(signaturePrefix) + len(signature)) + + parts.WriteString(SigningAlgorithm) + parts.WriteRune(' ') + parts.WriteString(credentialPrefix) + parts.WriteString(credentialStr) + parts.WriteString(commaSpace) + parts.WriteString(signedHeadersPrefix) + parts.WriteString(signedHeaders) + parts.WriteString(commaSpace) + parts.WriteString(signaturePrefix) + parts.WriteString(signature) + + headers[authorizationHeader] = append(headers[authorizationHeader][:0], + parts.String()) +} + +// helpers + +// sanitizeHostForHeader is like hostOrURLHost, but without port if port is the +// default port for the scheme. For example, it removes ":80" suffix if the scheme +// is "http". +func sanitizeHostForHeader(r *http.Request) string { + host := hostOrURLHost(r) + port := parsePort(host) + if port != "" && isDefaultPort(r.URL.Scheme, port) { + return stripPort(host) + } + return host +} + +// setRequiredSigningHeaders modifies headers: sets X-Amz-Date to sigtime, and +// if credToken is non-empty, set X-Amz-Security-Token to credToken. This function +// overwrites existing headers values with the same key. +func setRequiredSigningHeaders(headers http.Header, sigtime Time, sessionToken string) { + amzDate := sigtime.TimeFormat() + headers[AmzDateKey] = append(headers[AmzDateKey][:0], amzDate) + if sessionToken != "" { + headers[AmzSecurityTokenKey] = append(headers[AmzSecurityTokenKey][:0], + sessionToken) + } +} + +// setRequiredSigningQuery is like setRequiredSigningHeaders, but modifies +// query values. This is used for presign requests. +func setRequiredSigningQuery(query url.Values, sigtime Time, sessionToken string) { + query.Set(AmzAlgorithmKey, SigningAlgorithm) + + amzDate := sigtime.TimeFormat() + query.Set(AmzDateKey, amzDate) + + if sessionToken != "" { + query.Set(AmzSecurityTokenKey, sessionToken) + } +} diff --git a/providers/bedrock/sigv4/util.go b/providers/bedrock/sigv4/util.go new file mode 100644 index 00000000..b1a8e0d8 --- /dev/null +++ b/providers/bedrock/sigv4/util.go @@ -0,0 +1,26 @@ +package sigv4 + +import ( + "crypto/sha256" + "encoding/hex" + "io" + "net/http" +) + +// ContentSHA256Sum calculates the hex-encoded SHA256 checksum of r.Body. It returns +// EmptyStringSHA256 if r.Body is nil, r.Method is TRACE or r.ContentLength is zero. +// Returns non-nil error if r.Body cannot be read. +func ContentSHA256Sum(r *http.Request) (string, error) { + // We need to check r.Body is non-nil, because io.Copy(dst, nil) panics. + // This is not documented in https://pkg.go.dev/io#Copy. + if r.Method == http.MethodTrace || r.ContentLength == 0 || r.Body == nil { + return EmptyStringSHA256, nil + } + + h := sha256.New() + _, err := io.Copy(h, r.Body) + if err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} diff --git a/providers/bedrock/stream_reader.go b/providers/bedrock/stream_reader.go new file mode 100644 index 00000000..835753f7 --- /dev/null +++ b/providers/bedrock/stream_reader.go @@ -0,0 +1,128 @@ +package bedrock + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/types" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi" + "github.com/aws/smithy-go" +) + +type streamReader[T any] struct { + reader *bufio.Reader + response *http.Response + + handlerPrefix requester.HandlerPrefix[T] + + DataChan chan T + ErrChan chan error +} + +func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) { + go stream.processLines() + + return stream.DataChan, stream.ErrChan +} + +//nolint:gocognit +func (stream *streamReader[T]) processLines() { + decode := eventstream.NewDecoder() + payloadBuf := make([]byte, 0*1024) + for { + payloadBuf = payloadBuf[0:0] + messgae, readErr := decode.Decode(stream.reader, payloadBuf) + if readErr != nil { + stream.ErrChan <- readErr + return + } + + line, err := stream.deserializeEventMessage(&messgae) + if err != nil { + stream.ErrChan <- common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + return + } + + stream.handlerPrefix(&line, stream.DataChan, stream.ErrChan) + + if line == nil { + continue + } + + if bytes.Equal(line, requester.StreamClosed) { + return + } + } +} + +func (stream *streamReader[T]) Close() { + stream.response.Body.Close() +} + +func (stream *streamReader[T]) deserializeEventMessage(msg *eventstream.Message) ([]byte, error) { + messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) + if messageType == nil { + return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) + } + + switch messageType.String() { + case eventstreamapi.EventMessageType: + var v BedrockResponseStream + if err := json.Unmarshal(msg.Payload, &v); err != nil { + return nil, err + } + buffer, err := base64.StdEncoding.DecodeString(v.Bytes) + if err != nil { + return nil, err + } + return buffer, nil + + case eventstreamapi.ExceptionMessageType: + exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader) + return nil, errors.New("Exception message :" + exceptionType.String()) + + case eventstreamapi.ErrorMessageType: + errorCode := "UnknownError" + errorMessage := errorCode + if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil { + errorCode = header.String() + } + if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil { + errorMessage = header.String() + } + return nil, &smithy.GenericAPIError{ + Code: errorCode, + Message: errorMessage, + } + + default: + return nil, errors.New("bedrock stream unknown error") + } +} + +func RequestStream[T any](resp *http.Response, handlerPrefix requester.HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) { + // 如果返回的头是json格式 说明有错误 + if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { + return nil, requester.HandleErrorResp(resp, requestErrorHandle) + } + + stream := &streamReader[T]{ + reader: bufio.NewReader(resp.Body), + response: resp, + handlerPrefix: handlerPrefix, + + DataChan: make(chan T), + ErrChan: make(chan error), + } + + return stream, nil +} diff --git a/providers/bedrock/type.go b/providers/bedrock/type.go new file mode 100644 index 00000000..f49a07b4 --- /dev/null +++ b/providers/bedrock/type.go @@ -0,0 +1,11 @@ +package bedrock + +const awsService = "bedrock" + +type BedrockError struct { + Message string `json:"message"` +} + +type BedrockResponseStream struct { + Bytes string `json:"bytes"` +} diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 29bee010..3434bfa3 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -8,13 +8,17 @@ import ( "one-api/common" "one-api/common/image" "one-api/common/requester" + "one-api/providers/base" "one-api/types" "strings" + + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" ) -type claudeStreamHandler struct { +type ClaudeStreamHandler struct { Usage *types.Usage Request *types.ChatCompletionRequest + Prefix string } func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { @@ -31,7 +35,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque return nil, errWithCode } - return p.convertToChatOpenai(claudeResponse, request) + return ConvertToChatOpenai(p, claudeResponse, request) } func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { @@ -47,12 +51,15 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio return nil, errWithCode } - chatHandler := &claudeStreamHandler{ + chatHandler := &ClaudeStreamHandler{ Usage: p.Usage, Request: request, + Prefix: `data: {"type"`, } - return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) + eventstream.NewDecoder() + + return requester.RequestStream(p.Requester, resp, chatHandler.HandlerStream) } func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -72,7 +79,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (* headers["Accept"] = "text/event-stream" } - claudeRequest, errWithCode := convertFromChatOpenai(request) + claudeRequest, errWithCode := ConvertFromChatOpenai(request) if errWithCode != nil { return nil, errWithCode } @@ -86,7 +93,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (* return req, nil } -func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) { +func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) { claudeRequest := ClaudeRequest{ Model: request.Model, Messages: []Message{}, @@ -142,7 +149,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest return &claudeRequest, nil } -func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { +func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { error := errorHandle(&response.Error) if error != nil { errWithCode = &types.OpenAIErrorWithStatusCode{ @@ -182,21 +189,24 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request * openaiResponse.Usage.CompletionTokens = completionTokens openaiResponse.Usage.TotalTokens = promptTokens + completionTokens - *p.Usage = *openaiResponse.Usage + usage := provider.GetUsage() + *usage = *openaiResponse.Usage return openaiResponse, nil } // 转换为OpenAI聊天流式请求体 -func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { +func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 - if !strings.HasPrefix(string(*rawLine), `data: {"type"`) { + if !strings.HasPrefix(string(*rawLine), h.Prefix) { *rawLine = nil return } - // 去除前缀 - *rawLine = (*rawLine)[6:] + if strings.HasPrefix(string(*rawLine), "data: ") { + // 去除前缀 + *rawLine = (*rawLine)[6:] + } var claudeResponse ClaudeStreamResponse err := json.Unmarshal(*rawLine, &claudeResponse) @@ -235,7 +245,7 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin } } -func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) { +func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) { choice := types.ChatCompletionStreamChoice{ Index: claudeResponse.Index, } diff --git a/providers/claude/type.go b/providers/claude/type.go index 810641b7..06b38663 100644 --- a/providers/claude/type.go +++ b/providers/claude/type.go @@ -32,7 +32,7 @@ type Message struct { } type ClaudeRequest struct { - Model string `json:"model"` + Model string `json:"model,omitempty"` System string `json:"system,omitempty"` Messages []Message `json:"messages"` MaxTokens int `json:"max_tokens"` diff --git a/providers/providers.go b/providers/providers.go index 4c1c11fe..e12efc8d 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -13,6 +13,7 @@ import ( "one-api/providers/baichuan" "one-api/providers/baidu" "one-api/providers/base" + "one-api/providers/bedrock" "one-api/providers/claude" "one-api/providers/closeai" "one-api/providers/deepseek" @@ -62,6 +63,7 @@ func init() { providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{} providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{} providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{} + providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{} } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index f4f70298..b8145088 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -101,6 +101,12 @@ export const CHANNEL_OPTIONS = { value: 31, color: 'primary' }, + 32: { + key: 32, + text: 'Amazon Bedrock', + value: 32, + color: 'orange' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 813fcf3d..63bffe56 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -60,8 +60,15 @@ const typeConfig = { }, 14: { input: { - models: ['claude-instant-1.2', 'claude-2.0', 'claude-2.1', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'], - test_model: 'claude-3-sonnet-20240229' + models: [ + 'claude-instant-1.2', + 'claude-2.0', + 'claude-2.1', + 'claude-3-opus-20240229', + 'claude-3-sonnet-20240229', + 'claude-3-haiku-20240307' + ], + test_model: 'claude-3-haiku-20240307' }, modelGroup: 'Anthropic' }, @@ -202,6 +209,23 @@ const typeConfig = { test_model: 'llama2-7b-2048' }, modelGroup: 'Groq' + }, + 32: { + input: { + models: [ + 'claude-instant-1.2', + 'claude-2.0', + 'claude-2.1', + 'claude-3-opus-20240229', + 'claude-3-sonnet-20240229', + 'claude-3-haiku-20240307' + ], + test_model: 'claude-3-haiku-20240307' + }, + prompt: { + key: '按照如下格式输入:Region|AccessKeyID|SecretAccessKey|SessionToken 其中SessionToken可不填空' + }, + modelGroup: 'Anthropic' } };