✨ feat: support amazon bedrock anthropic (#114)
* 🚧 WIP: bedrock * ✨ feat: support amazon bedrock anthropic
This commit is contained in:
parent
6f76007292
commit
b81808e839
@ -198,6 +198,7 @@ const (
|
|||||||
ChannelTypeMoonshot = 29
|
ChannelTypeMoonshot = 29
|
||||||
ChannelTypeMistral = 30
|
ChannelTypeMistral = 30
|
||||||
ChannelTypeGroq = 31
|
ChannelTypeGroq = 31
|
||||||
|
ChannelTypeBedrock = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@ -232,7 +233,8 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.deepseek.com", //28
|
"https://api.deepseek.com", //28
|
||||||
"https://api.moonshot.cn", //29
|
"https://api.moonshot.cn", //29
|
||||||
"https://api.mistral.ai", //30
|
"https://api.mistral.ai", //30
|
||||||
"https://api.groq.com/openai", //30
|
"https://api.groq.com/openai", //31
|
||||||
|
"", //32
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -89,10 +89,14 @@ func init() {
|
|||||||
// $0.80/million tokens $2.40/million tokens
|
// $0.80/million tokens $2.40/million tokens
|
||||||
"claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
|
"claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
|
||||||
// $8.00/million tokens $24.00/million tokens
|
// $8.00/million tokens $24.00/million tokens
|
||||||
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
|
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
|
||||||
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
|
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
|
||||||
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
|
// $15 / M $75 / M
|
||||||
|
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
|
||||||
|
// $3 / M $15 / M
|
||||||
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
|
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
|
||||||
|
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
|
||||||
|
"claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic},
|
||||||
|
|
||||||
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
|
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
|
||||||
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu},
|
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu},
|
||||||
|
9
go.mod
9
go.mod
@ -4,6 +4,7 @@ module one-api
|
|||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1
|
||||||
github.com/gin-contrib/cors v1.4.0
|
github.com/gin-contrib/cors v1.4.0
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/gin-contrib/gzip v0.0.6
|
||||||
github.com/gin-contrib/sessions v0.0.5
|
github.com/gin-contrib/sessions v0.0.5
|
||||||
@ -24,8 +25,10 @@ require (
|
|||||||
gorm.io/gorm v1.25.0
|
gorm.io/gorm v1.25.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require github.com/aws/smithy-go v1.20.1 // indirect
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 // indirect
|
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
@ -46,7 +49,7 @@ require (
|
|||||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/joho/godotenv v1.5.1 // indirect
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
@ -59,7 +62,7 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/net v0.19.0 // indirect
|
golang.org/x/net v0.19.0
|
||||||
golang.org/x/sys v0.15.0 // indirect
|
golang.org/x/sys v0.15.0 // indirect
|
||||||
golang.org/x/text v0.14.0 // indirect
|
golang.org/x/text v0.14.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
|
14
go.sum
14
go.sum
@ -1,5 +1,9 @@
|
|||||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 h1:1T7RcpzlldaJ3qpZi0lNg/lBsfPCK+8n8Wc+R8EhAkU=
|
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 h1:1T7RcpzlldaJ3qpZi0lNg/lBsfPCK+8n8Wc+R8EhAkU=
|
||||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24/go.mod h1:kL1v4iIjlalwm3gCYGvF4NLa3hs+aKEfRkNJvj4aoDU=
|
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24/go.mod h1:kL1v4iIjlalwm3gCYGvF4NLa3hs+aKEfRkNJvj4aoDU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 h1:gTK2uhtAPtFcdRRJilZPx8uJLL2J85xK11nKtWL0wfU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1/go.mod h1:sxpLb+nZk7tIfCWChfd+h4QwHNUR57d8hA1cleTkjJo=
|
||||||
|
github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
|
||||||
|
github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
@ -49,8 +53,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
|||||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
|
||||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
|
||||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
@ -58,6 +60,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
|||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||||
|
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
@ -111,6 +115,7 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
|||||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
|
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@ -161,8 +166,6 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq
|
|||||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
|
||||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
|
||||||
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||||
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@ -203,14 +206,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
|
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
|
||||||
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
|
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
|
||||||
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
|
|
||||||
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
|
|
||||||
gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y=
|
gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y=
|
||||||
gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc=
|
gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc=
|
||||||
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
|
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
|
||||||
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
|
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
|
||||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
||||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
||||||
|
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
|
||||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
||||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
||||||
|
127
providers/bedrock/base.go
Normal file
127
providers/bedrock/base.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/providers/bedrock/category"
|
||||||
|
"one-api/providers/bedrock/sigv4"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BedrockProviderFactory struct{}
|
||||||
|
|
||||||
|
// 创建 BedrockProvider
|
||||||
|
func (f BedrockProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||||
|
|
||||||
|
bedrockProvider := &BedrockProvider{
|
||||||
|
BaseProvider: base.BaseProvider{
|
||||||
|
Config: getConfig(),
|
||||||
|
Channel: channel,
|
||||||
|
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
getKeyConfig(bedrockProvider)
|
||||||
|
|
||||||
|
return bedrockProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type BedrockProvider struct {
|
||||||
|
base.BaseProvider
|
||||||
|
Region string
|
||||||
|
AccessKeyID string
|
||||||
|
SecretAccessKey string
|
||||||
|
SessionToken string
|
||||||
|
Category *category.Category
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConfig() base.ProviderConfig {
|
||||||
|
return base.ProviderConfig{
|
||||||
|
BaseURL: "https://bedrock-runtime.%s.amazonaws.com",
|
||||||
|
ChatCompletions: "/model/%s/invoke",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求错误处理
|
||||||
|
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||||
|
bedrockError := &BedrockError{}
|
||||||
|
err := json.NewDecoder(resp.Body).Decode(bedrockError)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorHandle(bedrockError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 错误处理
|
||||||
|
func errorHandle(bedrockError *BedrockError) *types.OpenAIError {
|
||||||
|
if bedrockError.Message == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Message: bedrockError.Message,
|
||||||
|
Type: "Bedrock Error",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BedrockProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
return fmt.Sprintf(baseURL+requestURL, p.Region, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BedrockProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
p.CommonRequestHeaders(headers)
|
||||||
|
headers["Accept"] = "*/*"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeyConfig(bedrock *BedrockProvider) {
|
||||||
|
keys := strings.Split(bedrock.Channel.Key, "|")
|
||||||
|
if len(keys) < 3 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrock.Region = keys[0]
|
||||||
|
bedrock.AccessKeyID = keys[1]
|
||||||
|
bedrock.SecretAccessKey = keys[2]
|
||||||
|
if len(keys) == 4 && keys[3] != "" {
|
||||||
|
bedrock.SessionToken = keys[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BedrockProvider) Sign(req *http.Request) error {
|
||||||
|
var body []byte
|
||||||
|
if req.Body == nil {
|
||||||
|
body = []byte("")
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
body, err = io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("error getting request body: " + err.Error())
|
||||||
|
}
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
}
|
||||||
|
sig, err := sigv4.New(sigv4.WithCredential(p.AccessKeyID, p.SecretAccessKey, p.SessionToken), sigv4.WithRegionService(p.Region, awsService))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBodyHashHex := fmt.Sprintf("%x", sha256.Sum256(body))
|
||||||
|
sig.Sign(req, reqBodyHashHex, sigv4.NewTime(time.Now()))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
54
providers/bedrock/category/base.go
Normal file
54
providers/bedrock/category/base.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package category
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var CategoryMap = map[string]Category{}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
ModelName string
|
||||||
|
ChatComplete ChatCompletionConvert
|
||||||
|
ResponseChatComplete ChatCompletionResponse
|
||||||
|
ResponseChatCompleteStrem ChatCompletionStreamResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCategory(modelName string) (*Category, error) {
|
||||||
|
modelName = GetModelName(modelName)
|
||||||
|
|
||||||
|
// 点分割
|
||||||
|
provider := strings.Split(modelName, ".")[0]
|
||||||
|
|
||||||
|
if category, exists := CategoryMap[provider]; exists {
|
||||||
|
category.ModelName = modelName
|
||||||
|
return &category, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("category_not_found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelName(modelName string) string {
|
||||||
|
bedrockMap := map[string]string{
|
||||||
|
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"claude-2.1": "anthropic.claude-v2:1",
|
||||||
|
"claude-2.0": "anthropic.claude-v2",
|
||||||
|
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, exists := bedrockMap[modelName]; exists {
|
||||||
|
modelName = value
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionConvert func(*types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode)
|
||||||
|
type ChatCompletionResponse func(base.ProviderInterface, *http.Response, *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||||
|
|
||||||
|
type ChatCompletionStreamResponse func(base.ProviderInterface, *types.ChatCompletionRequest) requester.HandlerPrefix[string]
|
63
providers/bedrock/category/claude.go
Normal file
63
providers/bedrock/category/claude.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package category
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/providers/claude"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const anthropicVersion = "bedrock-2023-05-31"
|
||||||
|
|
||||||
|
type ClaudeRequest struct {
|
||||||
|
*claude.ClaudeRequest
|
||||||
|
AnthropicVersion string `json:"anthropic_version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
CategoryMap["anthropic"] = Category{
|
||||||
|
ChatComplete: ConvertClaudeFromChatOpenai,
|
||||||
|
ResponseChatComplete: ConvertClaudeToChatOpenai,
|
||||||
|
ResponseChatCompleteStrem: ClaudeChatCompleteStrem,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertClaudeFromChatOpenai(request *types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
rawRequest, err := claude.ConvertFromChatOpenai(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeRequest := &ClaudeRequest{}
|
||||||
|
claudeRequest.ClaudeRequest = rawRequest
|
||||||
|
claudeRequest.AnthropicVersion = anthropicVersion
|
||||||
|
|
||||||
|
// 删除model字段
|
||||||
|
claudeRequest.Model = ""
|
||||||
|
claudeRequest.Stream = false
|
||||||
|
|
||||||
|
return claudeRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertClaudeToChatOpenai(provider base.ProviderInterface, response *http.Response, request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
claudeResponse := &claude.ClaudeResponse{}
|
||||||
|
err := json.NewDecoder(response.Body).Decode(claudeResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return claude.ConvertToChatOpenai(provider, claudeResponse, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClaudeChatCompleteStrem(provider base.ProviderInterface, request *types.ChatCompletionRequest) requester.HandlerPrefix[string] {
|
||||||
|
chatHandler := &claude.ClaudeStreamHandler{
|
||||||
|
Usage: provider.GetUsage(),
|
||||||
|
Request: request,
|
||||||
|
Prefix: `{"type"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
return chatHandler.HandlerStream
|
||||||
|
}
|
81
providers/bedrock/chat.go
Normal file
81
providers/bedrock/chat.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/providers/bedrock/category"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *BedrockProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
// 发送请求
|
||||||
|
response, errWithCode := p.Send(request)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
return p.Category.ResponseChatComplete(p, response, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BedrockProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
|
// 发送请求
|
||||||
|
response, errWithCode := p.Send(request)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
|
||||||
|
return RequestStream(response, p.Category.ResponseChatCompleteStrem(p, request))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BedrockProvider) Send(request *types.ChatCompletionRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
req, errWithCode := p.getChatRequest(request)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
defer req.Body.Close()
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
return p.Requester.SendRequestRaw(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
var err error
|
||||||
|
p.Category, err = category.GetCategory(request.Model)
|
||||||
|
if err != nil || p.Category.ChatComplete == nil || p.Category.ResponseChatComplete == nil {
|
||||||
|
return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
url += "-with-response-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求地址
|
||||||
|
fullRequestURL := p.GetFullRequestURL(url, p.Category.ModelName)
|
||||||
|
if fullRequestURL == "" {
|
||||||
|
return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
bedrockRequest, errWithCode := p.Category.ChatComplete(request)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建请求
|
||||||
|
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(bedrockRequest), p.Requester.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
p.Sign(req)
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
27
providers/bedrock/sigv4/LICENSE
Normal file
27
providers/bedrock/sigv4/LICENSE
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
Copyright (c) 2023 Macks C. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Macks C. nor the names of its contributors
|
||||||
|
may be used to endorse or promote products derived from this software
|
||||||
|
without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
36
providers/bedrock/sigv4/const.go
Normal file
36
providers/bedrock/sigv4/const.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
const authorizationHeader = "Authorization"
|
||||||
|
|
||||||
|
// Signature Version 4 (SigV4) Constants
|
||||||
|
const (
|
||||||
|
// SigningAlgorithm is the name of the algorithm used in this package.
|
||||||
|
SigningAlgorithm = "AWS4-HMAC-SHA256"
|
||||||
|
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string.
|
||||||
|
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
|
||||||
|
// UnsignedPayload indicates that the request payload body is unsigned.
|
||||||
|
UnsignedPayload = "UNSIGNED-PAYLOAD"
|
||||||
|
// AmzAlgorithmKey indicates the signing algorithm.
|
||||||
|
AmzAlgorithmKey = "X-Amz-Algorithm"
|
||||||
|
// AmzSecurityTokenKey indicates the security token to be used with temporary
|
||||||
|
// credentials.
|
||||||
|
AmzSecurityTokenKey = "X-Amz-Security-Token"
|
||||||
|
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'.
|
||||||
|
AmzDateKey = "X-Amz-Date"
|
||||||
|
// AmzCredentialKey is the access key ID and credential scope.
|
||||||
|
AmzCredentialKey = "X-Amz-Credential"
|
||||||
|
// AmzSignedHeadersKey is the set of headers signed for the request.
|
||||||
|
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
|
||||||
|
// AmzSignatureKey is the query parameter to store the SigV4 signature.
|
||||||
|
AmzSignatureKey = "X-Amz-Signature"
|
||||||
|
// TimeFormat is the time format to be used in the X-Amz-Date header or query
|
||||||
|
// parameter.
|
||||||
|
TimeFormat = "20060102T150405Z"
|
||||||
|
// ShortTimeFormat is the shorten time format used in the credential scope.
|
||||||
|
ShortTimeFormat = "20060102"
|
||||||
|
// ContentSHAKey is the SHA256 of request body.
|
||||||
|
ContentSHAKey = "X-Amz-Content-Sha256"
|
||||||
|
// StreamingEventsPayload indicates that the request payload body is a signed
|
||||||
|
// event stream.
|
||||||
|
StreamingEventsPayload = "STREAMING-AWS4-HMAC-SHA256-EVENTS"
|
||||||
|
)
|
91
providers/bedrock/sigv4/header.go
Normal file
91
providers/bedrock/sigv4/header.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
// ignoredHeaders is a list of headers that are always ignored during signing.
|
||||||
|
var ignoreHeaders = map[string]struct{}{
|
||||||
|
"Authorization": {},
|
||||||
|
"User-Agent": {},
|
||||||
|
"X-Amzn-Trace-Id": {},
|
||||||
|
// also include lower case canonical versions
|
||||||
|
"authorization": {},
|
||||||
|
"user-agent": {},
|
||||||
|
"x-amzn-trace-id": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// requiredHeaderPrefix are header name prefixes that are mandatory for signing.
|
||||||
|
// If a header has one of these prefixes, it is a mandatory header.
|
||||||
|
var requiredHeaderPrefix = []string{"X-Amz-Object-Lock-", "X-Amz-Meta-"}
|
||||||
|
|
||||||
|
// requiredHeaders is a list of headers that are mandatory for signing.
|
||||||
|
var requiredHeaders = map[string]struct{}{
|
||||||
|
"Cache-Control": {},
|
||||||
|
"Content-Disposition": {},
|
||||||
|
"Content-Encoding": {},
|
||||||
|
"Content-Language": {},
|
||||||
|
"Content-Md5": {},
|
||||||
|
"Content-Type": {},
|
||||||
|
"Expires": {},
|
||||||
|
"If-Match": {},
|
||||||
|
"If-Modified-Since": {},
|
||||||
|
"If-None-Match": {},
|
||||||
|
"If-Unmodified-Since": {},
|
||||||
|
"Range": {},
|
||||||
|
"X-Amz-Acl": {},
|
||||||
|
"X-Amz-Copy-Source": {},
|
||||||
|
"X-Amz-Copy-Source-If-Match": {},
|
||||||
|
"X-Amz-Copy-Source-If-Modified-Since": {},
|
||||||
|
"X-Amz-Copy-Source-If-None-Match": {},
|
||||||
|
"X-Amz-Copy-Source-If-Unmodified-Since": {},
|
||||||
|
"X-Amz-Copy-Source-Range": {},
|
||||||
|
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": {},
|
||||||
|
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": {},
|
||||||
|
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": {},
|
||||||
|
"X-Amz-Grant-Full-control": {},
|
||||||
|
"X-Amz-Grant-Read": {},
|
||||||
|
"X-Amz-Grant-Read-Acp": {},
|
||||||
|
"X-Amz-Grant-Write": {},
|
||||||
|
"X-Amz-Grant-Write-Acp": {},
|
||||||
|
"X-Amz-Metadata-Directive": {},
|
||||||
|
"X-Amz-Mfa": {},
|
||||||
|
"X-Amz-Request-Payer": {},
|
||||||
|
"X-Amz-Server-Side-Encryption": {},
|
||||||
|
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": {},
|
||||||
|
"X-Amz-Server-Side-Encryption-Customer-Algorithm": {},
|
||||||
|
"X-Amz-Server-Side-Encryption-Customer-Key": {},
|
||||||
|
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": {},
|
||||||
|
"X-Amz-Storage-Class": {},
|
||||||
|
"X-Amz-Website-Redirect-Location": {},
|
||||||
|
"X-Amz-Content-Sha256": {},
|
||||||
|
"X-Amz-Tagging": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// headerPredicate is a function that evaluates whether a header is of the
|
||||||
|
// specific type. For example, whether a header should be ignored during signing.
|
||||||
|
type headerPredicate func(header string) bool
|
||||||
|
|
||||||
|
// isIgnoredHeader returns true if header must be ignored during signing.
|
||||||
|
func isIgnoredHeader(header string) bool {
|
||||||
|
_, ok := ignoreHeaders[header]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRequiredHeader returns true if header is mandatory for signing.
|
||||||
|
func isRequiredHeader(header string) bool {
|
||||||
|
_, ok := requiredHeaders[header]
|
||||||
|
if ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, v := range requiredHeaderPrefix {
|
||||||
|
if hasPrefixFold(header, v) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAllowQueryHoisting is a allowed list for Build query headers.
|
||||||
|
func isAllowQueryHoisting(header string) bool {
|
||||||
|
if isRequiredHeader(header) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return hasPrefixFold(header, "X-Amz-")
|
||||||
|
}
|
302
providers/bedrock/sigv4/helper.go
Normal file
302
providers/bedrock/sigv4/helper.go
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
awsURLNoEscTable [256]bool
|
||||||
|
awsURLEscTable [256][2]byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
for i := 0; i < len(awsURLNoEscTable); i++ {
|
||||||
|
// every char except these must be escaped
|
||||||
|
awsURLNoEscTable[i] = (i >= 'A' && i <= 'Z') ||
|
||||||
|
(i >= 'a' && i <= 'z') ||
|
||||||
|
(i >= '0' && i <= '9') ||
|
||||||
|
i == '-' ||
|
||||||
|
i == '.' ||
|
||||||
|
i == '_' ||
|
||||||
|
i == '~'
|
||||||
|
// %<hex><hex>
|
||||||
|
encoded := fmt.Sprintf("%02X", i)
|
||||||
|
awsURLEscTable[i] = [2]byte{encoded[0], encoded[1]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hmacsha256 computes a HMAC-SHA256 of data given the provided key.
|
||||||
|
func hmacsha256(key, data, buf []byte) []byte {
|
||||||
|
hash := hmac.New(sha256.New, key)
|
||||||
|
hash.Write(data)
|
||||||
|
return hash.Sum(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasPrefixFold tests whether the string s begins with prefix, interpreted as
|
||||||
|
// UTF-8 strings, under Unicode case-folding.
|
||||||
|
func hasPrefixFold(s, prefix string) bool {
|
||||||
|
return len(s) >= len(prefix) &&
|
||||||
|
strings.EqualFold(s[0:len(prefix)], prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSameDay returns true if a and b are the same date (dd-mm-yyyy).
|
||||||
|
func isSameDay(a, b time.Time) bool {
|
||||||
|
xYear, xMonth, xDay := a.Date()
|
||||||
|
yYear, yMonth, yDay := b.Date()
|
||||||
|
|
||||||
|
if xYear != yYear || xMonth != yMonth {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return xDay == yDay
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostOrURLHost returns r.Host, or if empty, r.URL.Host.
|
||||||
|
func hostOrURLHost(r *http.Request) string {
|
||||||
|
if r.Host != "" {
|
||||||
|
return r.Host
|
||||||
|
}
|
||||||
|
return r.URL.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePort returns the port part of u.Host, without the leading colon. Returns
|
||||||
|
// an empty string if u.Host doesn't contain port.
|
||||||
|
//
|
||||||
|
// Adapted from the Go 1.8 standard library (net/url).
|
||||||
|
func parsePort(hostport string) string {
|
||||||
|
colon := strings.IndexByte(hostport, ':')
|
||||||
|
if colon == -1 || colon == len(hostport)-1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// take care of ipv6 syntax: [a:b::]:<port>
|
||||||
|
const ipv6Sep = "]:"
|
||||||
|
if i := strings.Index(hostport, ipv6Sep); i != -1 {
|
||||||
|
return hostport[i+len(ipv6Sep):]
|
||||||
|
}
|
||||||
|
if strings.Contains(hostport, "]") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostport[colon+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripPort returns Hostname portion of u.Host, i.e. without any port number.
|
||||||
|
//
|
||||||
|
// If hostport is an IPv6 literal with a port number, returns the IPv6 literal
|
||||||
|
// without the square brackets. IPv6 literals may include a zone identifier.
|
||||||
|
//
|
||||||
|
// Adapted from the Go 1.8 standard library (net/url).
|
||||||
|
func stripPort(hostport string) string {
|
||||||
|
colon := strings.IndexByte(hostport, ':')
|
||||||
|
if colon == -1 {
|
||||||
|
return hostport
|
||||||
|
}
|
||||||
|
// ipv6: remove the []
|
||||||
|
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||||
|
return strings.TrimPrefix(hostport[:i], "[")
|
||||||
|
}
|
||||||
|
return hostport[:colon]
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDefaultPort returns true if the specified URI is using the standard port
|
||||||
|
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs).
|
||||||
|
func isDefaultPort(scheme, port string) bool {
|
||||||
|
switch strings.ToLower(scheme) {
|
||||||
|
case "http":
|
||||||
|
return port == "80"
|
||||||
|
case "https":
|
||||||
|
return port == "443"
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneURL(u *url.URL) *url.URL {
|
||||||
|
if u == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
u2 := new(url.URL)
|
||||||
|
*u2 = *u
|
||||||
|
if u.User != nil {
|
||||||
|
u2.User = new(url.Userinfo)
|
||||||
|
*u2.User = *u.User
|
||||||
|
}
|
||||||
|
return u2
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeAWSURIPath writes the escaped URI component from the specified URL (using
|
||||||
|
// AWS canonical URI specification) into w. URI component is path without query
|
||||||
|
// string.
|
||||||
|
func writeAWSURIPath(w *bufio.Writer, u *url.URL, encodeSep bool, isEscaped bool) {
|
||||||
|
const schemeSep, pathSep, queryStart = "//", "/", "?"
|
||||||
|
|
||||||
|
var p string
|
||||||
|
if u.Opaque == "" {
|
||||||
|
p = u.EscapedPath()
|
||||||
|
} else {
|
||||||
|
opaque := u.Opaque
|
||||||
|
// discard query string if any
|
||||||
|
if i := strings.Index(opaque, queryStart); i != -1 {
|
||||||
|
opaque = opaque[:i]
|
||||||
|
}
|
||||||
|
// if has scheme separator as prefix, discard it
|
||||||
|
if strings.HasPrefix(opaque, schemeSep) {
|
||||||
|
opaque = opaque[len(schemeSep):]
|
||||||
|
}
|
||||||
|
|
||||||
|
// everything after the first /, including the /
|
||||||
|
if i := strings.Index(opaque, pathSep); i != -1 {
|
||||||
|
p = opaque[i:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p == "" {
|
||||||
|
w.WriteByte('/')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isEscaped {
|
||||||
|
w.WriteString(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop thru first like in https://cs.opensource.google/go/go/+/refs/tags/go1.20.2:/src/net/url/url.go.
|
||||||
|
// It may add ~800ns but we save on memory alloc and catches cases where there
|
||||||
|
// is no need to escape.
|
||||||
|
plen := len(p)
|
||||||
|
strlen := plen
|
||||||
|
for i := 0; i < plen; i++ {
|
||||||
|
c := p[i]
|
||||||
|
if awsURLNoEscTable[c] || (c == '/' && !encodeSep) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
strlen += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// path already canonical, no need to escape
|
||||||
|
if plen == strlen {
|
||||||
|
w.WriteString(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < plen; i++ {
|
||||||
|
c := p[i]
|
||||||
|
if awsURLNoEscTable[c] || (c == '/' && !encodeSep) {
|
||||||
|
w.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
w.Write([]byte{'%', awsURLEscTable[c][0], awsURLEscTable[c][1]})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCanonicalQueryParams builds the canonical form of query and writes to w.
|
||||||
|
//
|
||||||
|
// Side effect: query values are sorted after this function returns.
|
||||||
|
func writeCanonicalQueryParams(w *bufio.Writer, query url.Values) {
|
||||||
|
qlen := len(query)
|
||||||
|
if qlen == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, qlen)
|
||||||
|
for k := range query {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for i, k := range keys {
|
||||||
|
keyEscaped := strings.Replace(url.QueryEscape(k), "+", "%20", -1)
|
||||||
|
vs := query[k]
|
||||||
|
|
||||||
|
if i != 0 {
|
||||||
|
w.WriteByte('&')
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vs) == 0 {
|
||||||
|
w.WriteString(keyEscaped)
|
||||||
|
w.WriteByte('=')
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(vs)
|
||||||
|
for j, v := range vs {
|
||||||
|
if j != 0 {
|
||||||
|
w.WriteByte('&')
|
||||||
|
}
|
||||||
|
w.WriteString(keyEscaped)
|
||||||
|
w.WriteByte('=')
|
||||||
|
if v != "" {
|
||||||
|
w.WriteString(strings.Replace(url.QueryEscape(v), "+", "%20", -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCanonicalString removes leading and trailing whitespaces (as defined by Unicode)
|
||||||
|
// in s, replaces consecutive spaces (' ') in s with a single space, and then
|
||||||
|
// write the result to w.
|
||||||
|
func writeCanonicalString(w *bufio.Writer, s string) {
|
||||||
|
const dblSpace = " "
|
||||||
|
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
// bail if str doesn't contain " "
|
||||||
|
j := strings.Index(s, dblSpace)
|
||||||
|
if j < 0 {
|
||||||
|
w.WriteString(s)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteString(s[:j])
|
||||||
|
|
||||||
|
// replace all " " with " " in a performant way
|
||||||
|
var lastIsSpace bool
|
||||||
|
for i, l := j, len(s); i < l; i++ {
|
||||||
|
if s[i] == ' ' {
|
||||||
|
if !lastIsSpace {
|
||||||
|
w.WriteByte(' ')
|
||||||
|
lastIsSpace = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lastIsSpace = false
|
||||||
|
w.WriteByte(s[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type debugHasher struct {
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dh *debugHasher) Write(b []byte) (int, error) {
|
||||||
|
dh.buf = append(dh.buf, b...)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dh *debugHasher) Sum(b []byte) []byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dh *debugHasher) Reset() {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dh *debugHasher) Size() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dh *debugHasher) BlockSize() int {
|
||||||
|
return sha256.BlockSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dh *debugHasher) Println() {
|
||||||
|
fmt.Printf("---%s---\n", dh.buf)
|
||||||
|
}
|
125
providers/bedrock/sigv4/key_deriver.go
Normal file
125
providers/bedrock/sigv4/key_deriver.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
gotime "time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var credScopeSuffixBytes = []byte{'a', 'w', 's', '4', '_', 'r', 'e', 'q', 'u', 'e', 's', 't'}
|
||||||
|
|
||||||
|
// deriveKey calculates the signing key. See https://docs.aws.amazon.com/general/latest/gr/create-signed-request.html.
|
||||||
|
func deriveKey(secret, service, region string, t Time) []byte {
|
||||||
|
// enc(
|
||||||
|
// enc(
|
||||||
|
// enc(
|
||||||
|
// enc(AWS4<secret>, <short_time>),
|
||||||
|
// <region>),
|
||||||
|
// <service>),
|
||||||
|
// "aws4_request")
|
||||||
|
|
||||||
|
// https://en.wikipedia.org/wiki/HMAC
|
||||||
|
// HMAC_SHA256 produces 32 bytes output
|
||||||
|
|
||||||
|
f1 := len(secret) + 4
|
||||||
|
f2 := f1 + len(t.ShortTimeFormat())
|
||||||
|
f3 := f2 + len(region)
|
||||||
|
f4 := f3 + len(service)
|
||||||
|
|
||||||
|
qs := make([]byte, 0, f4)
|
||||||
|
qs = append(qs, "AWS4"...)
|
||||||
|
qs = append(qs, secret...)
|
||||||
|
qs = append(qs, t.ShortTimeFormat()...)
|
||||||
|
qs = append(qs, region...)
|
||||||
|
qs = append(qs, service...)
|
||||||
|
|
||||||
|
buf := make([]byte, 0, sha256.BlockSize)
|
||||||
|
buf = hmacsha256(qs[:f1], qs[f1:f2], buf)
|
||||||
|
buf = hmacsha256(buf, qs[f2:f3], buf[:0])
|
||||||
|
buf = hmacsha256(buf, qs[f3:], buf[:0])
|
||||||
|
return hmacsha256(buf, credScopeSuffixBytes, buf[:0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// keyDeriver returns a signing key based on parameters such as credentials.
|
||||||
|
type keyDeriver interface {
|
||||||
|
DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// signingKeyDeriver is the default implementation of keyDerivator.
|
||||||
|
type signingKeyDeriver struct {
|
||||||
|
cache derivedKeyCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// newKeyDeriver creates a keyDeriver using the default implementation. The
|
||||||
|
// signing key is cached per region/service, and updated when accessKey changes
|
||||||
|
// or signingTime is not on the same day for that region/service.
|
||||||
|
func newKeyDeriver() keyDeriver {
|
||||||
|
return &signingKeyDeriver{cache: newDerivedKeyCache()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveKey returns a derived signing key from the given credentials to be used
|
||||||
|
// with SigV4 signing.
|
||||||
|
func (k *signingKeyDeriver) DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte {
|
||||||
|
return k.cache.Get(accessKey, secret, service, region, sigtime)
|
||||||
|
}
|
||||||
|
|
||||||
|
type derivedKeyCache struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
values map[string]derivedKey
|
||||||
|
nowFunc func() gotime.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type derivedKey struct {
|
||||||
|
Date gotime.Time
|
||||||
|
Credential []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDerivedKeyCache() derivedKeyCache {
|
||||||
|
return derivedKeyCache{
|
||||||
|
values: make(map[string]derivedKey),
|
||||||
|
nowFunc: gotime.Now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns key from cache or creates a new one.
|
||||||
|
func (s *derivedKeyCache) Get(accessKey, secret, service, region string, sigtime Time) []byte {
|
||||||
|
// <accessKey>/<YYYYMMDD>/<region>/<service>
|
||||||
|
key := strings.Join([]string{accessKey, sigtime.ShortTimeFormat(), region, service}, "/")
|
||||||
|
|
||||||
|
s.mutex.RLock()
|
||||||
|
cred, status := s.getFromCache(key)
|
||||||
|
s.mutex.RUnlock()
|
||||||
|
if status == 0 {
|
||||||
|
return cred
|
||||||
|
}
|
||||||
|
|
||||||
|
cred = deriveKey(secret, service, region, sigtime)
|
||||||
|
|
||||||
|
s.mutex.Lock()
|
||||||
|
if status == -1 {
|
||||||
|
delete(s.values, key)
|
||||||
|
}
|
||||||
|
s.values[key] = derivedKey{
|
||||||
|
Date: sigtime.Time,
|
||||||
|
Credential: cred,
|
||||||
|
}
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
return cred
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFromCache returns s.values[key]. Second result is 1 if key was not found,
|
||||||
|
// or -1 if the cached value has expired.
|
||||||
|
func (s *derivedKeyCache) getFromCache(key string) ([]byte, int) {
|
||||||
|
v, ok := s.values[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, 1
|
||||||
|
}
|
||||||
|
// evict from cache if item is a day older than system time
|
||||||
|
if s.nowFunc().Sub(v.Date) > 24*time.Hour {
|
||||||
|
return nil, -1
|
||||||
|
}
|
||||||
|
return v.Credential, 0
|
||||||
|
}
|
36
providers/bedrock/sigv4/sign_time.go
Normal file
36
providers/bedrock/sigv4/sign_time.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
import (
|
||||||
|
gotime "time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Time wraps time.Time to cache its string format result.
|
||||||
|
type Time struct {
|
||||||
|
gotime.Time
|
||||||
|
short string
|
||||||
|
long string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTime creates a new signingTime with the specified time.Time.
|
||||||
|
func NewTime(t gotime.Time) Time {
|
||||||
|
return Time{Time: t.UTC()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeFormat provides a time formatted in the X-Amz-Date format.
|
||||||
|
func (m *Time) TimeFormat() string {
|
||||||
|
return m.readOrFormat(&m.long, TimeFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShortTimeFormat provides a time formatted in short time format.
|
||||||
|
func (m *Time) ShortTimeFormat() string {
|
||||||
|
return m.readOrFormat(&m.short, ShortTimeFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Time) readOrFormat(target *string, format string) string {
|
||||||
|
if len(*target) > 0 {
|
||||||
|
return *target
|
||||||
|
}
|
||||||
|
v := m.Time.Format(format)
|
||||||
|
*target = v
|
||||||
|
return v
|
||||||
|
}
|
428
providers/bedrock/sigv4/signer.go
Normal file
428
providers/bedrock/sigv4/signer.go
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"hash"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPSigner is an AWS SigV4 signer that can sign HTTP requests.
|
||||||
|
type HTTPSigner interface {
|
||||||
|
// Sign AWS v4 requests with the provided payload hash, service name, region
|
||||||
|
// the request is made to, and time the request is signed at. Set sigtime
|
||||||
|
// to the future to create a request that cannot be used until the future time.
|
||||||
|
//
|
||||||
|
// payloadHash is the hex encoded SHA-256 hash of the request payload, and must
|
||||||
|
// not be empty, even if the request has no payload (aka body). If the request
|
||||||
|
// has no payload, use the hex encoded SHA-256 of an empty string, or the constant
|
||||||
|
// EmptyStringSHA256. You can use the utility function ContentSHA256Sum to
|
||||||
|
// calculate the hash of a http.Request body.
|
||||||
|
//
|
||||||
|
// Some services such as Amazon S3 accept alternative values for the payload
|
||||||
|
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
|
||||||
|
// protected by sigv4. See https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html.
|
||||||
|
//
|
||||||
|
// Sign differs from Presign in that it will sign the request using HTTP headers.
|
||||||
|
// The passed in request r will be modified in place: modified fields include
|
||||||
|
// r.Host and r.Header.
|
||||||
|
Sign(r *http.Request, payloadHash string, sigtime Time) error
|
||||||
|
// Presign is like Sign, but does not modify request r. It returns a copy of
|
||||||
|
// r.URL with additional query parameters that contains signing information.
|
||||||
|
// The URL can be used to recreate an authenticated request without specifying
|
||||||
|
// headers. It also returns http.Header as a second result, which must be
|
||||||
|
// included in the reconstructed request.
|
||||||
|
//
|
||||||
|
// Header hoisting: use WithHeaderHoisting option function to specify whether
|
||||||
|
// headers in request r should be added as query parameters. Some headers cannot
|
||||||
|
// be hoisted, and are returned as the second result.
|
||||||
|
//
|
||||||
|
// Presign will not set the expires time of the presigned request automatically.
|
||||||
|
// To specify the expire duration for a request, add the "X-Amz-Expires" query
|
||||||
|
// parameter on the request with the value as the duration in seconds the
|
||||||
|
// presigned URL should be considered valid for. This parameter is not used
|
||||||
|
// by all AWS services, and is most notable used by Amazon S3 APIs.
|
||||||
|
//
|
||||||
|
// expires := 20*time.Minute
|
||||||
|
// query := req.URL.Query()
|
||||||
|
// query.Set("X-Amz-Expires", strconv.FormatInt(int64(expires/time.Second), 10)
|
||||||
|
// req.URL.RawQuery = query.Encode()
|
||||||
|
Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPSignerOption is an option parameter for HTTPSigner constructor function.
|
||||||
|
type HTTPSignerOption func(HTTPSigner) error
|
||||||
|
|
||||||
|
// ErrInvalidOption means the option parameter is incompatible with the HTTPSigner.
|
||||||
|
var ErrInvalidOption = errors.New("cannot apply option to HTTPSigner")
|
||||||
|
|
||||||
|
// httpV4Signer is the default implementation of HTTPSigner.
|
||||||
|
type httpV4Signer struct {
|
||||||
|
KeyDeriver keyDeriver
|
||||||
|
AccessKey string
|
||||||
|
Secret string
|
||||||
|
SessionToken string
|
||||||
|
Service string
|
||||||
|
Region string
|
||||||
|
HeaderHoisting bool
|
||||||
|
EscapeURLPath bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCredential sets HTTPSigner credential fields.
|
||||||
|
func WithCredential(accessKey, secret, sessionToken string) HTTPSignerOption {
|
||||||
|
return func(signer HTTPSigner) error {
|
||||||
|
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||||
|
sigv4.AccessKey = accessKey
|
||||||
|
sigv4.Secret = secret
|
||||||
|
sigv4.SessionToken = sessionToken
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ErrInvalidOption
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeaderHoisting specifies whether HTTPSigner automatically hoist headers.
|
||||||
|
// Default is enabled.
|
||||||
|
func WithHeaderHoisting(enable bool) HTTPSignerOption {
|
||||||
|
return func(signer HTTPSigner) error {
|
||||||
|
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||||
|
sigv4.HeaderHoisting = enable
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ErrInvalidOption
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEscapeURLPath specifies whether HTTPSigner automatically escapes URL paths.
|
||||||
|
// Default is enabled.
|
||||||
|
func WithEscapeURLPath(enable bool) HTTPSignerOption {
|
||||||
|
return func(signer HTTPSigner) error {
|
||||||
|
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||||
|
sigv4.EscapeURLPath = enable
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ErrInvalidOption
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRegionService sets HTTPSigner region and service fields.
|
||||||
|
func WithRegionService(region, service string) HTTPSignerOption {
|
||||||
|
return func(signer HTTPSigner) error {
|
||||||
|
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||||
|
sigv4.Region = region
|
||||||
|
sigv4.Service = service
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ErrInvalidOption
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a HTTPSigner.
|
||||||
|
func New(opts ...HTTPSignerOption) (HTTPSigner, error) {
|
||||||
|
sigv4 := &httpV4Signer{
|
||||||
|
KeyDeriver: newKeyDeriver(),
|
||||||
|
EscapeURLPath: true,
|
||||||
|
HeaderHoisting: true,
|
||||||
|
}
|
||||||
|
for _, o := range opts {
|
||||||
|
if o == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := o(sigv4); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sigv4, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign implements HTTPSigner.
|
||||||
|
func (s *httpV4Signer) Sign(r *http.Request, payloadHash string, sigtime Time) error {
|
||||||
|
if payloadHash == "" {
|
||||||
|
var err error
|
||||||
|
payloadHash, err = ContentSHA256Sum(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add mandatory headers to r.Header
|
||||||
|
setRequiredSigningHeaders(r.Header, sigtime, s.SessionToken)
|
||||||
|
// remove port in r.Host if any
|
||||||
|
r.Host = sanitizeHostForHeader(r)
|
||||||
|
|
||||||
|
// parse URL query only once
|
||||||
|
query := r.URL.Query()
|
||||||
|
// sigBuf is used to act as a sha256 hash buffer
|
||||||
|
sigBuf := make([]byte, 0, sha256.Size)
|
||||||
|
|
||||||
|
//hasher := &debugHasher{}
|
||||||
|
hasher := sha256.New()
|
||||||
|
reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, r.Header, query,
|
||||||
|
r.Host, payloadHash, s.EscapeURLPath, false, sigBuf)
|
||||||
|
|
||||||
|
credentialScope := strings.Join([]string{
|
||||||
|
sigtime.ShortTimeFormat(),
|
||||||
|
s.Region,
|
||||||
|
s.Service,
|
||||||
|
"aws4_request",
|
||||||
|
}, "/")
|
||||||
|
|
||||||
|
keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service,
|
||||||
|
s.Region, sigtime)
|
||||||
|
sigHasher := hmac.New(sha256.New, keyBytes)
|
||||||
|
signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf)
|
||||||
|
|
||||||
|
writeAuthorizationHeader(r.Header, s.AccessKey+"/"+credentialScope,
|
||||||
|
signedHeaderStr, signature)
|
||||||
|
|
||||||
|
// done
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Presign implements HTTPSigner.
|
||||||
|
func (s *httpV4Signer) Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error) {
|
||||||
|
if payloadHash == "" {
|
||||||
|
var err error
|
||||||
|
payloadHash, err = ContentSHA256Sum(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query := r.URL.Query()
|
||||||
|
setRequiredSigningQuery(query, sigtime, s.SessionToken)
|
||||||
|
// sort each query key's values
|
||||||
|
for key := range query {
|
||||||
|
sort.Strings(query[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
credentialScope := strings.Join([]string{
|
||||||
|
sigtime.ShortTimeFormat(),
|
||||||
|
s.Region,
|
||||||
|
s.Service,
|
||||||
|
"aws4_request",
|
||||||
|
}, "/")
|
||||||
|
credentialStr := s.AccessKey + "/" + credentialScope
|
||||||
|
query.Set(AmzCredentialKey, credentialStr)
|
||||||
|
|
||||||
|
var headersLeft http.Header
|
||||||
|
if s.HeaderHoisting {
|
||||||
|
headersLeft = make(http.Header, len(r.Header))
|
||||||
|
for k, v := range r.Header {
|
||||||
|
if isAllowQueryHoisting(k) {
|
||||||
|
query[k] = v
|
||||||
|
} else {
|
||||||
|
headersLeft[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sigBuf is used to act as a sha256 hash buffer
|
||||||
|
sigBuf := make([]byte, 0, sha256.Size)
|
||||||
|
|
||||||
|
hasher := sha256.New()
|
||||||
|
reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, headersLeft,
|
||||||
|
query, sanitizeHostForHeader(r), payloadHash, s.EscapeURLPath, true, sigBuf)
|
||||||
|
|
||||||
|
keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service,
|
||||||
|
s.Region, sigtime)
|
||||||
|
sigHasher := hmac.New(sha256.New, keyBytes)
|
||||||
|
signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf)
|
||||||
|
query.Set(AmzSignatureKey, signature)
|
||||||
|
|
||||||
|
u := cloneURL(r.URL)
|
||||||
|
u.RawQuery = strings.Replace(query.Encode(), "+", "%20", -1)
|
||||||
|
|
||||||
|
// For the signed headers we canonicalize the header keys in the returned map.
|
||||||
|
// This avoids situations where standard library can sometimes add double
|
||||||
|
// headers. For example, the standard library will set the Host header,
|
||||||
|
// even if it is present in lower-case form.
|
||||||
|
signedHeader := strings.Split(signedHeaderStr, ";")
|
||||||
|
canonHeader := make(http.Header, len(signedHeader))
|
||||||
|
for _, k := range signedHeader {
|
||||||
|
canonKey := textproto.CanonicalMIMEHeaderKey(k)
|
||||||
|
switch k {
|
||||||
|
case "host":
|
||||||
|
canonHeader[canonKey] = []string{sanitizeHostForHeader(r)}
|
||||||
|
case "content-length":
|
||||||
|
canonHeader[canonKey] = []string{strconv.FormatInt(r.ContentLength, 10)}
|
||||||
|
default:
|
||||||
|
canonHeader[canonKey] = append(canonHeader[canonKey], headersLeft[http.CanonicalHeaderKey(k)]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return u, canonHeader, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorizationSignature returns `sig` as documented in step 4 of algorithm
|
||||||
|
// documentation. key is hSig in step 4. It calculates the result of step 3
|
||||||
|
// internally.
|
||||||
|
func authorizationSignature(hasher hash.Hash, sigtime Time, credScope, requestHash string, buf []byte) string {
|
||||||
|
w := bufio.NewWriterSize(hasher, sha256.BlockSize)
|
||||||
|
|
||||||
|
w.WriteString(SigningAlgorithm)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
w.WriteString(sigtime.TimeFormat())
|
||||||
|
w.WriteByte('\n')
|
||||||
|
w.WriteString(credScope)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
w.WriteString(requestHash)
|
||||||
|
|
||||||
|
w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer
|
||||||
|
//hasher.Println()
|
||||||
|
return hex.EncodeToString(hasher.Sum(buf[:0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// canonicalRequestHash returns the hex-encoded sha256 sum of the canonical
|
||||||
|
// request string. Refer to step 2 of algorithm documentation. Expect hasher to
|
||||||
|
// be sha256.New.
|
||||||
|
func canonicalRequestHash(
|
||||||
|
hasher hash.Hash, r *http.Request, headers http.Header, query url.Values,
|
||||||
|
hostname, hashcode string, escapeURL, isPresign bool, buf []byte,
|
||||||
|
) (string, string) {
|
||||||
|
w := bufio.NewWriterSize(hasher, sha256.BlockSize)
|
||||||
|
|
||||||
|
signedHeaders := make([]string, 0, len(headers)+2)
|
||||||
|
signedHeaders = append(signedHeaders, "host")
|
||||||
|
if r.ContentLength > 0 {
|
||||||
|
signedHeaders = append(signedHeaders, "content-length")
|
||||||
|
}
|
||||||
|
for k := range headers {
|
||||||
|
if strings.EqualFold(k, "content-length") || strings.EqualFold(k, "host") || isIgnoredHeader(k) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
signedHeaders = append(signedHeaders, strings.ToLower(k))
|
||||||
|
}
|
||||||
|
sort.Strings(signedHeaders)
|
||||||
|
signedHeaderStr := strings.Join(signedHeaders, ";")
|
||||||
|
|
||||||
|
// for presigned requests, we need to add X-Amz-SignedHeaders to calculate the
|
||||||
|
// correct hash
|
||||||
|
if isPresign {
|
||||||
|
query.Set(AmzSignedHeadersKey, signedHeaderStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// <METHOD>\n<URI>\n<QUERY>\n<HEADERS>\n<SIGNED_HEADERS>\n<PAYLOAD_HASH>
|
||||||
|
|
||||||
|
// HTTP_METHOD
|
||||||
|
w.WriteString(r.Method)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
// CANONICAL_URI
|
||||||
|
writeAWSURIPath(w, r.URL, false, !escapeURL)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
// CANONICAL_QUERY_PARAMS
|
||||||
|
writeCanonicalQueryParams(w, query)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
// CANONICAL_HEADERS
|
||||||
|
for _, head := range signedHeaders {
|
||||||
|
switch head {
|
||||||
|
case "host":
|
||||||
|
w.WriteString(head)
|
||||||
|
w.WriteByte(':')
|
||||||
|
writeCanonicalString(w, hostname)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
case "content-length":
|
||||||
|
w.WriteString(head)
|
||||||
|
w.WriteByte(':')
|
||||||
|
w.WriteString(strconv.FormatInt(r.ContentLength, 10))
|
||||||
|
w.WriteByte('\n')
|
||||||
|
default:
|
||||||
|
w.WriteString(head)
|
||||||
|
w.WriteByte(':')
|
||||||
|
values := headers[http.CanonicalHeaderKey(head)]
|
||||||
|
for i, v := range values {
|
||||||
|
if i != 0 {
|
||||||
|
w.WriteByte(',')
|
||||||
|
}
|
||||||
|
writeCanonicalString(w, v)
|
||||||
|
}
|
||||||
|
w.WriteByte('\n')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteByte('\n')
|
||||||
|
// SIGNED_HEADERS
|
||||||
|
w.WriteString(signedHeaderStr)
|
||||||
|
w.WriteByte('\n')
|
||||||
|
// PAYLOAD_HASH
|
||||||
|
w.WriteString(hashcode)
|
||||||
|
|
||||||
|
w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer
|
||||||
|
//hasher.Println()
|
||||||
|
return hex.EncodeToString(hasher.Sum(buf[:0])), signedHeaderStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeAuthorizationHeader writes the Authorization header into header:
|
||||||
|
//
|
||||||
|
// AWS4-HMAC-SHA256 Credential=<cred>, SignedHeaders=<signed_headers>, Signature=<sig>
|
||||||
|
func writeAuthorizationHeader(headers http.Header, credentialStr, signedHeaders, signature string) {
|
||||||
|
const credentialPrefix = "Credential="
|
||||||
|
const signedHeadersPrefix = "SignedHeaders="
|
||||||
|
const signaturePrefix = "Signature="
|
||||||
|
const commaSpace = ", "
|
||||||
|
|
||||||
|
var parts strings.Builder
|
||||||
|
parts.Grow(len(SigningAlgorithm) + 1 +
|
||||||
|
len(credentialPrefix) + len(credentialStr) + 2 +
|
||||||
|
len(signedHeadersPrefix) + len(signedHeaders) + 2 +
|
||||||
|
len(signaturePrefix) + len(signature))
|
||||||
|
|
||||||
|
parts.WriteString(SigningAlgorithm)
|
||||||
|
parts.WriteRune(' ')
|
||||||
|
parts.WriteString(credentialPrefix)
|
||||||
|
parts.WriteString(credentialStr)
|
||||||
|
parts.WriteString(commaSpace)
|
||||||
|
parts.WriteString(signedHeadersPrefix)
|
||||||
|
parts.WriteString(signedHeaders)
|
||||||
|
parts.WriteString(commaSpace)
|
||||||
|
parts.WriteString(signaturePrefix)
|
||||||
|
parts.WriteString(signature)
|
||||||
|
|
||||||
|
headers[authorizationHeader] = append(headers[authorizationHeader][:0],
|
||||||
|
parts.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
// sanitizeHostForHeader is like hostOrURLHost, but without port if port is the
|
||||||
|
// default port for the scheme. For example, it removes ":80" suffix if the scheme
|
||||||
|
// is "http".
|
||||||
|
func sanitizeHostForHeader(r *http.Request) string {
|
||||||
|
host := hostOrURLHost(r)
|
||||||
|
port := parsePort(host)
|
||||||
|
if port != "" && isDefaultPort(r.URL.Scheme, port) {
|
||||||
|
return stripPort(host)
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// setRequiredSigningHeaders modifies headers: sets X-Amz-Date to sigtime, and
|
||||||
|
// if credToken is non-empty, set X-Amz-Security-Token to credToken. This function
|
||||||
|
// overwrites existing headers values with the same key.
|
||||||
|
func setRequiredSigningHeaders(headers http.Header, sigtime Time, sessionToken string) {
|
||||||
|
amzDate := sigtime.TimeFormat()
|
||||||
|
headers[AmzDateKey] = append(headers[AmzDateKey][:0], amzDate)
|
||||||
|
if sessionToken != "" {
|
||||||
|
headers[AmzSecurityTokenKey] = append(headers[AmzSecurityTokenKey][:0],
|
||||||
|
sessionToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setRequiredSigningQuery is like setRequiredSigningHeaders, but modifies
|
||||||
|
// query values. This is used for presign requests.
|
||||||
|
func setRequiredSigningQuery(query url.Values, sigtime Time, sessionToken string) {
|
||||||
|
query.Set(AmzAlgorithmKey, SigningAlgorithm)
|
||||||
|
|
||||||
|
amzDate := sigtime.TimeFormat()
|
||||||
|
query.Set(AmzDateKey, amzDate)
|
||||||
|
|
||||||
|
if sessionToken != "" {
|
||||||
|
query.Set(AmzSecurityTokenKey, sessionToken)
|
||||||
|
}
|
||||||
|
}
|
26
providers/bedrock/sigv4/util.go
Normal file
26
providers/bedrock/sigv4/util.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package sigv4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContentSHA256Sum calculates the hex-encoded SHA256 checksum of r.Body. It returns
|
||||||
|
// EmptyStringSHA256 if r.Body is nil, r.Method is TRACE or r.ContentLength is zero.
|
||||||
|
// Returns non-nil error if r.Body cannot be read.
|
||||||
|
func ContentSHA256Sum(r *http.Request) (string, error) {
|
||||||
|
// We need to check r.Body is non-nil, because io.Copy(dst, nil) panics.
|
||||||
|
// This is not documented in https://pkg.go.dev/io#Copy.
|
||||||
|
if r.Method == http.MethodTrace || r.ContentLength == 0 || r.Body == nil {
|
||||||
|
return EmptyStringSHA256, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
_, err := io.Copy(h, r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(h.Sum(nil)), nil
|
||||||
|
}
|
128
providers/bedrock/stream_reader.go
Normal file
128
providers/bedrock/stream_reader.go
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
|
||||||
|
"github.com/aws/smithy-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
type streamReader[T any] struct {
|
||||||
|
reader *bufio.Reader
|
||||||
|
response *http.Response
|
||||||
|
|
||||||
|
handlerPrefix requester.HandlerPrefix[T]
|
||||||
|
|
||||||
|
DataChan chan T
|
||||||
|
ErrChan chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
|
||||||
|
go stream.processLines()
|
||||||
|
|
||||||
|
return stream.DataChan, stream.ErrChan
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gocognit
|
||||||
|
func (stream *streamReader[T]) processLines() {
|
||||||
|
decode := eventstream.NewDecoder()
|
||||||
|
payloadBuf := make([]byte, 0*1024)
|
||||||
|
for {
|
||||||
|
payloadBuf = payloadBuf[0:0]
|
||||||
|
messgae, readErr := decode.Decode(stream.reader, payloadBuf)
|
||||||
|
if readErr != nil {
|
||||||
|
stream.ErrChan <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := stream.deserializeEventMessage(&messgae)
|
||||||
|
if err != nil {
|
||||||
|
stream.ErrChan <- common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.handlerPrefix(&line, stream.DataChan, stream.ErrChan)
|
||||||
|
|
||||||
|
if line == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(line, requester.StreamClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *streamReader[T]) Close() {
|
||||||
|
stream.response.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *streamReader[T]) deserializeEventMessage(msg *eventstream.Message) ([]byte, error) {
|
||||||
|
messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
|
||||||
|
if messageType == nil {
|
||||||
|
return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch messageType.String() {
|
||||||
|
case eventstreamapi.EventMessageType:
|
||||||
|
var v BedrockResponseStream
|
||||||
|
if err := json.Unmarshal(msg.Payload, &v); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
buffer, err := base64.StdEncoding.DecodeString(v.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buffer, nil
|
||||||
|
|
||||||
|
case eventstreamapi.ExceptionMessageType:
|
||||||
|
exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader)
|
||||||
|
return nil, errors.New("Exception message :" + exceptionType.String())
|
||||||
|
|
||||||
|
case eventstreamapi.ErrorMessageType:
|
||||||
|
errorCode := "UnknownError"
|
||||||
|
errorMessage := errorCode
|
||||||
|
if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
|
||||||
|
errorCode = header.String()
|
||||||
|
}
|
||||||
|
if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
|
||||||
|
errorMessage = header.String()
|
||||||
|
}
|
||||||
|
return nil, &smithy.GenericAPIError{
|
||||||
|
Code: errorCode,
|
||||||
|
Message: errorMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("bedrock stream unknown error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequestStream[T any](resp *http.Response, handlerPrefix requester.HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
|
||||||
|
// 如果返回的头是json格式 说明有错误
|
||||||
|
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
|
||||||
|
return nil, requester.HandleErrorResp(resp, requestErrorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := &streamReader[T]{
|
||||||
|
reader: bufio.NewReader(resp.Body),
|
||||||
|
response: resp,
|
||||||
|
handlerPrefix: handlerPrefix,
|
||||||
|
|
||||||
|
DataChan: make(chan T),
|
||||||
|
ErrChan: make(chan error),
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
|
}
|
11
providers/bedrock/type.go
Normal file
11
providers/bedrock/type.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package bedrock
|
||||||
|
|
||||||
|
const awsService = "bedrock"
|
||||||
|
|
||||||
|
type BedrockError struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BedrockResponseStream struct {
|
||||||
|
Bytes string `json:"bytes"`
|
||||||
|
}
|
@ -8,13 +8,17 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/image"
|
"one-api/common/image"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||||
)
|
)
|
||||||
|
|
||||||
type claudeStreamHandler struct {
|
type ClaudeStreamHandler struct {
|
||||||
Usage *types.Usage
|
Usage *types.Usage
|
||||||
Request *types.ChatCompletionRequest
|
Request *types.ChatCompletionRequest
|
||||||
|
Prefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -31,7 +35,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
|||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.convertToChatOpenai(claudeResponse, request)
|
return ConvertToChatOpenai(p, claudeResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -47,12 +51,15 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
|||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
}
|
}
|
||||||
|
|
||||||
chatHandler := &claudeStreamHandler{
|
chatHandler := &ClaudeStreamHandler{
|
||||||
Usage: p.Usage,
|
Usage: p.Usage,
|
||||||
Request: request,
|
Request: request,
|
||||||
|
Prefix: `data: {"type"`,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
eventstream.NewDecoder()
|
||||||
|
|
||||||
|
return requester.RequestStream(p.Requester, resp, chatHandler.HandlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -72,7 +79,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
|
|||||||
headers["Accept"] = "text/event-stream"
|
headers["Accept"] = "text/event-stream"
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeRequest, errWithCode := convertFromChatOpenai(request)
|
claudeRequest, errWithCode := ConvertFromChatOpenai(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
}
|
}
|
||||||
@ -86,7 +93,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) {
|
func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) {
|
||||||
claudeRequest := ClaudeRequest{
|
claudeRequest := ClaudeRequest{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Messages: []Message{},
|
Messages: []Message{},
|
||||||
@ -142,7 +149,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest
|
|||||||
return &claudeRequest, nil
|
return &claudeRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
error := errorHandle(&response.Error)
|
error := errorHandle(&response.Error)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||||
@ -182,21 +189,24 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
|
|||||||
openaiResponse.Usage.CompletionTokens = completionTokens
|
openaiResponse.Usage.CompletionTokens = completionTokens
|
||||||
openaiResponse.Usage.TotalTokens = promptTokens + completionTokens
|
openaiResponse.Usage.TotalTokens = promptTokens + completionTokens
|
||||||
|
|
||||||
*p.Usage = *openaiResponse.Usage
|
usage := provider.GetUsage()
|
||||||
|
*usage = *openaiResponse.Usage
|
||||||
|
|
||||||
return openaiResponse, nil
|
return openaiResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), `data: {"type"`) {
|
if !strings.HasPrefix(string(*rawLine), h.Prefix) {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
if strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = (*rawLine)[6:]
|
// 去除前缀
|
||||||
|
*rawLine = (*rawLine)[6:]
|
||||||
|
}
|
||||||
|
|
||||||
var claudeResponse ClaudeStreamResponse
|
var claudeResponse ClaudeStreamResponse
|
||||||
err := json.Unmarshal(*rawLine, &claudeResponse)
|
err := json.Unmarshal(*rawLine, &claudeResponse)
|
||||||
@ -235,7 +245,7 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
|
func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
|
||||||
choice := types.ChatCompletionStreamChoice{
|
choice := types.ChatCompletionStreamChoice{
|
||||||
Index: claudeResponse.Index,
|
Index: claudeResponse.Index,
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ type Message struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeRequest struct {
|
type ClaudeRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model,omitempty"`
|
||||||
System string `json:"system,omitempty"`
|
System string `json:"system,omitempty"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"one-api/providers/baichuan"
|
"one-api/providers/baichuan"
|
||||||
"one-api/providers/baidu"
|
"one-api/providers/baidu"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
|
"one-api/providers/bedrock"
|
||||||
"one-api/providers/claude"
|
"one-api/providers/claude"
|
||||||
"one-api/providers/closeai"
|
"one-api/providers/closeai"
|
||||||
"one-api/providers/deepseek"
|
"one-api/providers/deepseek"
|
||||||
@ -62,6 +63,7 @@ func init() {
|
|||||||
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
|
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
||||||
|
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,6 +101,12 @@ export const CHANNEL_OPTIONS = {
|
|||||||
value: 31,
|
value: 31,
|
||||||
color: 'primary'
|
color: 'primary'
|
||||||
},
|
},
|
||||||
|
32: {
|
||||||
|
key: 32,
|
||||||
|
text: 'Amazon Bedrock',
|
||||||
|
value: 32,
|
||||||
|
color: 'orange'
|
||||||
|
},
|
||||||
24: {
|
24: {
|
||||||
key: 24,
|
key: 24,
|
||||||
text: 'Azure Speech',
|
text: 'Azure Speech',
|
||||||
|
@ -60,8 +60,15 @@ const typeConfig = {
|
|||||||
},
|
},
|
||||||
14: {
|
14: {
|
||||||
input: {
|
input: {
|
||||||
models: ['claude-instant-1.2', 'claude-2.0', 'claude-2.1', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'],
|
models: [
|
||||||
test_model: 'claude-3-sonnet-20240229'
|
'claude-instant-1.2',
|
||||||
|
'claude-2.0',
|
||||||
|
'claude-2.1',
|
||||||
|
'claude-3-opus-20240229',
|
||||||
|
'claude-3-sonnet-20240229',
|
||||||
|
'claude-3-haiku-20240307'
|
||||||
|
],
|
||||||
|
test_model: 'claude-3-haiku-20240307'
|
||||||
},
|
},
|
||||||
modelGroup: 'Anthropic'
|
modelGroup: 'Anthropic'
|
||||||
},
|
},
|
||||||
@ -202,6 +209,23 @@ const typeConfig = {
|
|||||||
test_model: 'llama2-7b-2048'
|
test_model: 'llama2-7b-2048'
|
||||||
},
|
},
|
||||||
modelGroup: 'Groq'
|
modelGroup: 'Groq'
|
||||||
|
},
|
||||||
|
32: {
|
||||||
|
input: {
|
||||||
|
models: [
|
||||||
|
'claude-instant-1.2',
|
||||||
|
'claude-2.0',
|
||||||
|
'claude-2.1',
|
||||||
|
'claude-3-opus-20240229',
|
||||||
|
'claude-3-sonnet-20240229',
|
||||||
|
'claude-3-haiku-20240307'
|
||||||
|
],
|
||||||
|
test_model: 'claude-3-haiku-20240307'
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
key: '按照如下格式输入:Region|AccessKeyID|SecretAccessKey|SessionToken 其中SessionToken可不填空'
|
||||||
|
},
|
||||||
|
modelGroup: 'Anthropic'
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user