✨ feat: support amazon bedrock anthropic (#114)
* 🚧 WIP: bedrock * ✨ feat: support amazon bedrock anthropic
This commit is contained in:
parent
6f76007292
commit
b81808e839
@ -198,6 +198,7 @@ const (
|
||||
ChannelTypeMoonshot = 29
|
||||
ChannelTypeMistral = 30
|
||||
ChannelTypeGroq = 31
|
||||
ChannelTypeBedrock = 32
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@ -232,7 +233,8 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.deepseek.com", //28
|
||||
"https://api.moonshot.cn", //29
|
||||
"https://api.mistral.ai", //30
|
||||
"https://api.groq.com/openai", //30
|
||||
"https://api.groq.com/openai", //31
|
||||
"", //32
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -89,10 +89,14 @@ func init() {
|
||||
// $0.80/million tokens $2.40/million tokens
|
||||
"claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
|
||||
// $8.00/million tokens $24.00/million tokens
|
||||
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
|
||||
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
|
||||
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
|
||||
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
|
||||
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
|
||||
// $15 / M $75 / M
|
||||
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
|
||||
// $3 / M $15 / M
|
||||
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
|
||||
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
|
||||
"claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic},
|
||||
|
||||
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
|
||||
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu},
|
||||
|
9
go.mod
9
go.mod
@ -4,6 +4,7 @@ module one-api
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@ -24,8 +25,10 @@ require (
|
||||
gorm.io/gorm v1.25.0
|
||||
)
|
||||
|
||||
require github.com/aws/smithy-go v1.20.1 // indirect
|
||||
|
||||
require (
|
||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 // indirect
|
||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
@ -46,7 +49,7 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
@ -59,7 +62,7 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/net v0.19.0 // indirect
|
||||
golang.org/x/net v0.19.0
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
|
14
go.sum
14
go.sum
@ -1,5 +1,9 @@
|
||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 h1:1T7RcpzlldaJ3qpZi0lNg/lBsfPCK+8n8Wc+R8EhAkU=
|
||||
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24/go.mod h1:kL1v4iIjlalwm3gCYGvF4NLa3hs+aKEfRkNJvj4aoDU=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 h1:gTK2uhtAPtFcdRRJilZPx8uJLL2J85xK11nKtWL0wfU=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1/go.mod h1:sxpLb+nZk7tIfCWChfd+h4QwHNUR57d8hA1cleTkjJo=
|
||||
github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
|
||||
github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
@ -49,8 +53,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
@ -58,6 +60,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
@ -111,6 +115,7 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@ -161,8 +166,6 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq
|
||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -203,14 +206,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
|
||||
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
|
||||
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
|
||||
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
|
||||
gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y=
|
||||
gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc=
|
||||
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
|
||||
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
|
||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
||||
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
|
||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
||||
|
127
providers/bedrock/base.go
Normal file
127
providers/bedrock/base.go
Normal file
@ -0,0 +1,127 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"one-api/providers/bedrock/category"
|
||||
"one-api/providers/bedrock/sigv4"
|
||||
)
|
||||
|
||||
type BedrockProviderFactory struct{}
|
||||
|
||||
// 创建 BedrockProvider
|
||||
func (f BedrockProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
|
||||
bedrockProvider := &BedrockProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
|
||||
getKeyConfig(bedrockProvider)
|
||||
|
||||
return bedrockProvider
|
||||
}
|
||||
|
||||
type BedrockProvider struct {
|
||||
base.BaseProvider
|
||||
Region string
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
SessionToken string
|
||||
Category *category.Category
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://bedrock-runtime.%s.amazonaws.com",
|
||||
ChatCompletions: "/model/%s/invoke",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
bedrockError := &BedrockError{}
|
||||
err := json.NewDecoder(resp.Body).Decode(bedrockError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(bedrockError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(bedrockError *BedrockError) *types.OpenAIError {
|
||||
if bedrockError.Message == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: bedrockError.Message,
|
||||
Type: "Bedrock Error",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BedrockProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf(baseURL+requestURL, p.Region, modelName)
|
||||
}
|
||||
|
||||
func (p *BedrockProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Accept"] = "*/*"
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func getKeyConfig(bedrock *BedrockProvider) {
|
||||
keys := strings.Split(bedrock.Channel.Key, "|")
|
||||
if len(keys) < 3 {
|
||||
return
|
||||
}
|
||||
|
||||
bedrock.Region = keys[0]
|
||||
bedrock.AccessKeyID = keys[1]
|
||||
bedrock.SecretAccessKey = keys[2]
|
||||
if len(keys) == 4 && keys[3] != "" {
|
||||
bedrock.SessionToken = keys[3]
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BedrockProvider) Sign(req *http.Request) error {
|
||||
var body []byte
|
||||
if req.Body == nil {
|
||||
body = []byte("")
|
||||
} else {
|
||||
var err error
|
||||
body, err = io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return errors.New("error getting request body: " + err.Error())
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
sig, err := sigv4.New(sigv4.WithCredential(p.AccessKeyID, p.SecretAccessKey, p.SessionToken), sigv4.WithRegionService(p.Region, awsService))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqBodyHashHex := fmt.Sprintf("%x", sha256.Sum256(body))
|
||||
sig.Sign(req, reqBodyHashHex, sigv4.NewTime(time.Now()))
|
||||
|
||||
return nil
|
||||
}
|
54
providers/bedrock/category/base.go
Normal file
54
providers/bedrock/category/base.go
Normal file
@ -0,0 +1,54 @@
|
||||
package category
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var CategoryMap = map[string]Category{}
|
||||
|
||||
type Category struct {
|
||||
ModelName string
|
||||
ChatComplete ChatCompletionConvert
|
||||
ResponseChatComplete ChatCompletionResponse
|
||||
ResponseChatCompleteStrem ChatCompletionStreamResponse
|
||||
}
|
||||
|
||||
func GetCategory(modelName string) (*Category, error) {
|
||||
modelName = GetModelName(modelName)
|
||||
|
||||
// 点分割
|
||||
provider := strings.Split(modelName, ".")[0]
|
||||
|
||||
if category, exists := CategoryMap[provider]; exists {
|
||||
category.ModelName = modelName
|
||||
return &category, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("category_not_found")
|
||||
}
|
||||
|
||||
func GetModelName(modelName string) string {
|
||||
bedrockMap := map[string]string{
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
}
|
||||
|
||||
if value, exists := bedrockMap[modelName]; exists {
|
||||
modelName = value
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
type ChatCompletionConvert func(*types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode)
|
||||
type ChatCompletionResponse func(base.ProviderInterface, *http.Response, *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
|
||||
type ChatCompletionStreamResponse func(base.ProviderInterface, *types.ChatCompletionRequest) requester.HandlerPrefix[string]
|
63
providers/bedrock/category/claude.go
Normal file
63
providers/bedrock/category/claude.go
Normal file
@ -0,0 +1,63 @@
|
||||
package category
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/claude"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
const anthropicVersion = "bedrock-2023-05-31"
|
||||
|
||||
type ClaudeRequest struct {
|
||||
*claude.ClaudeRequest
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
CategoryMap["anthropic"] = Category{
|
||||
ChatComplete: ConvertClaudeFromChatOpenai,
|
||||
ResponseChatComplete: ConvertClaudeToChatOpenai,
|
||||
ResponseChatCompleteStrem: ClaudeChatCompleteStrem,
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertClaudeFromChatOpenai(request *types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) {
|
||||
rawRequest, err := claude.ConvertFromChatOpenai(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claudeRequest := &ClaudeRequest{}
|
||||
claudeRequest.ClaudeRequest = rawRequest
|
||||
claudeRequest.AnthropicVersion = anthropicVersion
|
||||
|
||||
// 删除model字段
|
||||
claudeRequest.Model = ""
|
||||
claudeRequest.Stream = false
|
||||
|
||||
return claudeRequest, nil
|
||||
}
|
||||
|
||||
func ConvertClaudeToChatOpenai(provider base.ProviderInterface, response *http.Response, request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
claudeResponse := &claude.ClaudeResponse{}
|
||||
err := json.NewDecoder(response.Body).Decode(claudeResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return claude.ConvertToChatOpenai(provider, claudeResponse, request)
|
||||
}
|
||||
|
||||
func ClaudeChatCompleteStrem(provider base.ProviderInterface, request *types.ChatCompletionRequest) requester.HandlerPrefix[string] {
|
||||
chatHandler := &claude.ClaudeStreamHandler{
|
||||
Usage: provider.GetUsage(),
|
||||
Request: request,
|
||||
Prefix: `{"type"`,
|
||||
}
|
||||
|
||||
return chatHandler.HandlerStream
|
||||
}
|
81
providers/bedrock/chat.go
Normal file
81
providers/bedrock/chat.go
Normal file
@ -0,0 +1,81 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/providers/bedrock/category"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *BedrockProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
response, errWithCode := p.Send(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
return p.Category.ResponseChatComplete(p, response, request)
|
||||
}
|
||||
|
||||
func (p *BedrockProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
response, errWithCode := p.Send(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return RequestStream(response, p.Category.ResponseChatCompleteStrem(p, request))
|
||||
}
|
||||
|
||||
func (p *BedrockProvider) Send(request *types.ChatCompletionRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
return p.Requester.SendRequestRaw(req)
|
||||
}
|
||||
|
||||
func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
var err error
|
||||
p.Category, err = category.GetCategory(request.Model)
|
||||
if err != nil || p.Category.ChatComplete == nil || p.Category.ResponseChatComplete == nil {
|
||||
return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
url += "-with-response-stream"
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, p.Category.ModelName)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
bedrockRequest, errWithCode := p.Category.ChatComplete(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(bedrockRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Sign(req)
|
||||
|
||||
return req, nil
|
||||
}
|
27
providers/bedrock/sigv4/LICENSE
Normal file
27
providers/bedrock/sigv4/LICENSE
Normal file
@ -0,0 +1,27 @@
|
||||
Copyright (c) 2023 Macks C. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Macks C. nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software
|
||||
without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
36
providers/bedrock/sigv4/const.go
Normal file
36
providers/bedrock/sigv4/const.go
Normal file
@ -0,0 +1,36 @@
|
||||
package sigv4
|
||||
|
||||
const authorizationHeader = "Authorization"
|
||||
|
||||
// Signature Version 4 (SigV4) Constants
|
||||
const (
|
||||
// SigningAlgorithm is the name of the algorithm used in this package.
|
||||
SigningAlgorithm = "AWS4-HMAC-SHA256"
|
||||
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string.
|
||||
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
|
||||
// UnsignedPayload indicates that the request payload body is unsigned.
|
||||
UnsignedPayload = "UNSIGNED-PAYLOAD"
|
||||
// AmzAlgorithmKey indicates the signing algorithm.
|
||||
AmzAlgorithmKey = "X-Amz-Algorithm"
|
||||
// AmzSecurityTokenKey indicates the security token to be used with temporary
|
||||
// credentials.
|
||||
AmzSecurityTokenKey = "X-Amz-Security-Token"
|
||||
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'.
|
||||
AmzDateKey = "X-Amz-Date"
|
||||
// AmzCredentialKey is the access key ID and credential scope.
|
||||
AmzCredentialKey = "X-Amz-Credential"
|
||||
// AmzSignedHeadersKey is the set of headers signed for the request.
|
||||
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
|
||||
// AmzSignatureKey is the query parameter to store the SigV4 signature.
|
||||
AmzSignatureKey = "X-Amz-Signature"
|
||||
// TimeFormat is the time format to be used in the X-Amz-Date header or query
|
||||
// parameter.
|
||||
TimeFormat = "20060102T150405Z"
|
||||
// ShortTimeFormat is the shorten time format used in the credential scope.
|
||||
ShortTimeFormat = "20060102"
|
||||
// ContentSHAKey is the SHA256 of request body.
|
||||
ContentSHAKey = "X-Amz-Content-Sha256"
|
||||
// StreamingEventsPayload indicates that the request payload body is a signed
|
||||
// event stream.
|
||||
StreamingEventsPayload = "STREAMING-AWS4-HMAC-SHA256-EVENTS"
|
||||
)
|
91
providers/bedrock/sigv4/header.go
Normal file
91
providers/bedrock/sigv4/header.go
Normal file
@ -0,0 +1,91 @@
|
||||
package sigv4
|
||||
|
||||
// ignoredHeaders is a list of headers that are always ignored during signing.
|
||||
var ignoreHeaders = map[string]struct{}{
|
||||
"Authorization": {},
|
||||
"User-Agent": {},
|
||||
"X-Amzn-Trace-Id": {},
|
||||
// also include lower case canonical versions
|
||||
"authorization": {},
|
||||
"user-agent": {},
|
||||
"x-amzn-trace-id": {},
|
||||
}
|
||||
|
||||
// requiredHeaderPrefix are header name prefixes that are mandatory for signing.
|
||||
// If a header has one of these prefixes, it is a mandatory header.
|
||||
var requiredHeaderPrefix = []string{"X-Amz-Object-Lock-", "X-Amz-Meta-"}
|
||||
|
||||
// requiredHeaders is a list of headers that are mandatory for signing.
|
||||
var requiredHeaders = map[string]struct{}{
|
||||
"Cache-Control": {},
|
||||
"Content-Disposition": {},
|
||||
"Content-Encoding": {},
|
||||
"Content-Language": {},
|
||||
"Content-Md5": {},
|
||||
"Content-Type": {},
|
||||
"Expires": {},
|
||||
"If-Match": {},
|
||||
"If-Modified-Since": {},
|
||||
"If-None-Match": {},
|
||||
"If-Unmodified-Since": {},
|
||||
"Range": {},
|
||||
"X-Amz-Acl": {},
|
||||
"X-Amz-Copy-Source": {},
|
||||
"X-Amz-Copy-Source-If-Match": {},
|
||||
"X-Amz-Copy-Source-If-Modified-Since": {},
|
||||
"X-Amz-Copy-Source-If-None-Match": {},
|
||||
"X-Amz-Copy-Source-If-Unmodified-Since": {},
|
||||
"X-Amz-Copy-Source-Range": {},
|
||||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": {},
|
||||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": {},
|
||||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": {},
|
||||
"X-Amz-Grant-Full-control": {},
|
||||
"X-Amz-Grant-Read": {},
|
||||
"X-Amz-Grant-Read-Acp": {},
|
||||
"X-Amz-Grant-Write": {},
|
||||
"X-Amz-Grant-Write-Acp": {},
|
||||
"X-Amz-Metadata-Directive": {},
|
||||
"X-Amz-Mfa": {},
|
||||
"X-Amz-Request-Payer": {},
|
||||
"X-Amz-Server-Side-Encryption": {},
|
||||
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": {},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Algorithm": {},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Key": {},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": {},
|
||||
"X-Amz-Storage-Class": {},
|
||||
"X-Amz-Website-Redirect-Location": {},
|
||||
"X-Amz-Content-Sha256": {},
|
||||
"X-Amz-Tagging": {},
|
||||
}
|
||||
|
||||
// headerPredicate is a function that evaluates whether a header is of the
|
||||
// specific type. For example, whether a header should be ignored during signing.
|
||||
type headerPredicate func(header string) bool
|
||||
|
||||
// isIgnoredHeader returns true if header must be ignored during signing.
|
||||
func isIgnoredHeader(header string) bool {
|
||||
_, ok := ignoreHeaders[header]
|
||||
return ok
|
||||
}
|
||||
|
||||
// isRequiredHeader returns true if header is mandatory for signing.
|
||||
func isRequiredHeader(header string) bool {
|
||||
_, ok := requiredHeaders[header]
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
for _, v := range requiredHeaderPrefix {
|
||||
if hasPrefixFold(header, v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isAllowQueryHoisting is a allowed list for Build query headers.
|
||||
func isAllowQueryHoisting(header string) bool {
|
||||
if isRequiredHeader(header) {
|
||||
return false
|
||||
}
|
||||
return hasPrefixFold(header, "X-Amz-")
|
||||
}
|
302
providers/bedrock/sigv4/helper.go
Normal file
302
providers/bedrock/sigv4/helper.go
Normal file
@ -0,0 +1,302 @@
|
||||
package sigv4
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
awsURLNoEscTable [256]bool
|
||||
awsURLEscTable [256][2]byte
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i := 0; i < len(awsURLNoEscTable); i++ {
|
||||
// every char except these must be escaped
|
||||
awsURLNoEscTable[i] = (i >= 'A' && i <= 'Z') ||
|
||||
(i >= 'a' && i <= 'z') ||
|
||||
(i >= '0' && i <= '9') ||
|
||||
i == '-' ||
|
||||
i == '.' ||
|
||||
i == '_' ||
|
||||
i == '~'
|
||||
// %<hex><hex>
|
||||
encoded := fmt.Sprintf("%02X", i)
|
||||
awsURLEscTable[i] = [2]byte{encoded[0], encoded[1]}
|
||||
}
|
||||
}
|
||||
|
||||
// hmacsha256 computes a HMAC-SHA256 of data given the provided key.
|
||||
func hmacsha256(key, data, buf []byte) []byte {
|
||||
hash := hmac.New(sha256.New, key)
|
||||
hash.Write(data)
|
||||
return hash.Sum(buf)
|
||||
}
|
||||
|
||||
// hasPrefixFold tests whether the string s begins with prefix, interpreted as
|
||||
// UTF-8 strings, under Unicode case-folding.
|
||||
func hasPrefixFold(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) &&
|
||||
strings.EqualFold(s[0:len(prefix)], prefix)
|
||||
}
|
||||
|
||||
// isSameDay returns true if a and b are the same date (dd-mm-yyyy).
|
||||
func isSameDay(a, b time.Time) bool {
|
||||
xYear, xMonth, xDay := a.Date()
|
||||
yYear, yMonth, yDay := b.Date()
|
||||
|
||||
if xYear != yYear || xMonth != yMonth {
|
||||
return false
|
||||
}
|
||||
return xDay == yDay
|
||||
}
|
||||
|
||||
// hostOrURLHost returns r.Host, or if empty, r.URL.Host.
|
||||
func hostOrURLHost(r *http.Request) string {
|
||||
if r.Host != "" {
|
||||
return r.Host
|
||||
}
|
||||
return r.URL.Host
|
||||
}
|
||||
|
||||
// parsePort returns the port part of u.Host, without the leading colon. Returns
|
||||
// an empty string if u.Host doesn't contain port.
|
||||
//
|
||||
// Adapted from the Go 1.8 standard library (net/url).
|
||||
func parsePort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 || colon == len(hostport)-1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// take care of ipv6 syntax: [a:b::]:<port>
|
||||
const ipv6Sep = "]:"
|
||||
if i := strings.Index(hostport, ipv6Sep); i != -1 {
|
||||
return hostport[i+len(ipv6Sep):]
|
||||
}
|
||||
if strings.Contains(hostport, "]") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return hostport[colon+1:]
|
||||
}
|
||||
|
||||
// stripPort returns Hostname portion of u.Host, i.e. without any port number.
|
||||
//
|
||||
// If hostport is an IPv6 literal with a port number, returns the IPv6 literal
|
||||
// without the square brackets. IPv6 literals may include a zone identifier.
|
||||
//
|
||||
// Adapted from the Go 1.8 standard library (net/url).
|
||||
func stripPort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return hostport
|
||||
}
|
||||
// ipv6: remove the []
|
||||
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||
return strings.TrimPrefix(hostport[:i], "[")
|
||||
}
|
||||
return hostport[:colon]
|
||||
}
|
||||
|
||||
// isDefaultPort returns true if the specified URI is using the standard port
|
||||
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs).
|
||||
func isDefaultPort(scheme, port string) bool {
|
||||
switch strings.ToLower(scheme) {
|
||||
case "http":
|
||||
return port == "80"
|
||||
case "https":
|
||||
return port == "443"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func cloneURL(u *url.URL) *url.URL {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u2 := new(url.URL)
|
||||
*u2 = *u
|
||||
if u.User != nil {
|
||||
u2.User = new(url.Userinfo)
|
||||
*u2.User = *u.User
|
||||
}
|
||||
return u2
|
||||
}
|
||||
|
||||
// writeAWSURIPath writes the escaped URI component from the specified URL (using
|
||||
// AWS canonical URI specification) into w. URI component is path without query
|
||||
// string.
|
||||
func writeAWSURIPath(w *bufio.Writer, u *url.URL, encodeSep bool, isEscaped bool) {
|
||||
const schemeSep, pathSep, queryStart = "//", "/", "?"
|
||||
|
||||
var p string
|
||||
if u.Opaque == "" {
|
||||
p = u.EscapedPath()
|
||||
} else {
|
||||
opaque := u.Opaque
|
||||
// discard query string if any
|
||||
if i := strings.Index(opaque, queryStart); i != -1 {
|
||||
opaque = opaque[:i]
|
||||
}
|
||||
// if has scheme separator as prefix, discard it
|
||||
if strings.HasPrefix(opaque, schemeSep) {
|
||||
opaque = opaque[len(schemeSep):]
|
||||
}
|
||||
|
||||
// everything after the first /, including the /
|
||||
if i := strings.Index(opaque, pathSep); i != -1 {
|
||||
p = opaque[i:]
|
||||
}
|
||||
}
|
||||
|
||||
if p == "" {
|
||||
w.WriteByte('/')
|
||||
return
|
||||
}
|
||||
|
||||
if isEscaped {
|
||||
w.WriteString(p)
|
||||
return
|
||||
}
|
||||
|
||||
// Loop thru first like in https://cs.opensource.google/go/go/+/refs/tags/go1.20.2:/src/net/url/url.go.
|
||||
// It may add ~800ns but we save on memory alloc and catches cases where there
|
||||
// is no need to escape.
|
||||
plen := len(p)
|
||||
strlen := plen
|
||||
for i := 0; i < plen; i++ {
|
||||
c := p[i]
|
||||
if awsURLNoEscTable[c] || (c == '/' && !encodeSep) {
|
||||
continue
|
||||
}
|
||||
strlen += 2
|
||||
}
|
||||
|
||||
// path already canonical, no need to escape
|
||||
if plen == strlen {
|
||||
w.WriteString(p)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < plen; i++ {
|
||||
c := p[i]
|
||||
if awsURLNoEscTable[c] || (c == '/' && !encodeSep) {
|
||||
w.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
w.Write([]byte{'%', awsURLEscTable[c][0], awsURLEscTable[c][1]})
|
||||
}
|
||||
}
|
||||
|
||||
// writeCanonicalQueryParams builds the canonical form of query and writes to w.
|
||||
//
|
||||
// Side effect: query values are sorted after this function returns.
|
||||
func writeCanonicalQueryParams(w *bufio.Writer, query url.Values) {
|
||||
qlen := len(query)
|
||||
if qlen == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
keys := make([]string, 0, qlen)
|
||||
for k := range query {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for i, k := range keys {
|
||||
keyEscaped := strings.Replace(url.QueryEscape(k), "+", "%20", -1)
|
||||
vs := query[k]
|
||||
|
||||
if i != 0 {
|
||||
w.WriteByte('&')
|
||||
}
|
||||
|
||||
if len(vs) == 0 {
|
||||
w.WriteString(keyEscaped)
|
||||
w.WriteByte('=')
|
||||
continue
|
||||
}
|
||||
|
||||
sort.Strings(vs)
|
||||
for j, v := range vs {
|
||||
if j != 0 {
|
||||
w.WriteByte('&')
|
||||
}
|
||||
w.WriteString(keyEscaped)
|
||||
w.WriteByte('=')
|
||||
if v != "" {
|
||||
w.WriteString(strings.Replace(url.QueryEscape(v), "+", "%20", -1))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeCanonicalString removes leading and trailing whitespaces (as defined by Unicode)
|
||||
// in s, replaces consecutive spaces (' ') in s with a single space, and then
|
||||
// write the result to w.
|
||||
func writeCanonicalString(w *bufio.Writer, s string) {
|
||||
const dblSpace = " "
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// bail if str doesn't contain " "
|
||||
j := strings.Index(s, dblSpace)
|
||||
if j < 0 {
|
||||
w.WriteString(s)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteString(s[:j])
|
||||
|
||||
// replace all " " with " " in a performant way
|
||||
var lastIsSpace bool
|
||||
for i, l := j, len(s); i < l; i++ {
|
||||
if s[i] == ' ' {
|
||||
if !lastIsSpace {
|
||||
w.WriteByte(' ')
|
||||
lastIsSpace = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastIsSpace = false
|
||||
w.WriteByte(s[i])
|
||||
}
|
||||
}
|
||||
|
||||
type debugHasher struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (dh *debugHasher) Write(b []byte) (int, error) {
|
||||
dh.buf = append(dh.buf, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (dh *debugHasher) Sum(b []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dh *debugHasher) Reset() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (dh *debugHasher) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (dh *debugHasher) BlockSize() int {
|
||||
return sha256.BlockSize
|
||||
}
|
||||
|
||||
func (dh *debugHasher) Println() {
|
||||
fmt.Printf("---%s---\n", dh.buf)
|
||||
}
|
125
providers/bedrock/sigv4/key_deriver.go
Normal file
125
providers/bedrock/sigv4/key_deriver.go
Normal file
@ -0,0 +1,125 @@
|
||||
package sigv4
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
gotime "time"
|
||||
)
|
||||
|
||||
var credScopeSuffixBytes = []byte{'a', 'w', 's', '4', '_', 'r', 'e', 'q', 'u', 'e', 's', 't'}
|
||||
|
||||
// deriveKey calculates the signing key. See https://docs.aws.amazon.com/general/latest/gr/create-signed-request.html.
|
||||
func deriveKey(secret, service, region string, t Time) []byte {
|
||||
// enc(
|
||||
// enc(
|
||||
// enc(
|
||||
// enc(AWS4<secret>, <short_time>),
|
||||
// <region>),
|
||||
// <service>),
|
||||
// "aws4_request")
|
||||
|
||||
// https://en.wikipedia.org/wiki/HMAC
|
||||
// HMAC_SHA256 produces 32 bytes output
|
||||
|
||||
f1 := len(secret) + 4
|
||||
f2 := f1 + len(t.ShortTimeFormat())
|
||||
f3 := f2 + len(region)
|
||||
f4 := f3 + len(service)
|
||||
|
||||
qs := make([]byte, 0, f4)
|
||||
qs = append(qs, "AWS4"...)
|
||||
qs = append(qs, secret...)
|
||||
qs = append(qs, t.ShortTimeFormat()...)
|
||||
qs = append(qs, region...)
|
||||
qs = append(qs, service...)
|
||||
|
||||
buf := make([]byte, 0, sha256.BlockSize)
|
||||
buf = hmacsha256(qs[:f1], qs[f1:f2], buf)
|
||||
buf = hmacsha256(buf, qs[f2:f3], buf[:0])
|
||||
buf = hmacsha256(buf, qs[f3:], buf[:0])
|
||||
return hmacsha256(buf, credScopeSuffixBytes, buf[:0])
|
||||
}
|
||||
|
||||
// keyDeriver returns a signing key based on parameters such as credentials.
|
||||
type keyDeriver interface {
|
||||
DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte
|
||||
}
|
||||
|
||||
// signingKeyDeriver is the default implementation of keyDerivator.
|
||||
type signingKeyDeriver struct {
|
||||
cache derivedKeyCache
|
||||
}
|
||||
|
||||
// newKeyDeriver creates a keyDeriver using the default implementation. The
|
||||
// signing key is cached per region/service, and updated when accessKey changes
|
||||
// or signingTime is not on the same day for that region/service.
|
||||
func newKeyDeriver() keyDeriver {
|
||||
return &signingKeyDeriver{cache: newDerivedKeyCache()}
|
||||
}
|
||||
|
||||
// DeriveKey returns a derived signing key from the given credentials to be used
|
||||
// with SigV4 signing.
|
||||
func (k *signingKeyDeriver) DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte {
|
||||
return k.cache.Get(accessKey, secret, service, region, sigtime)
|
||||
}
|
||||
|
||||
type derivedKeyCache struct {
|
||||
mutex sync.RWMutex
|
||||
values map[string]derivedKey
|
||||
nowFunc func() gotime.Time
|
||||
}
|
||||
|
||||
type derivedKey struct {
|
||||
Date gotime.Time
|
||||
Credential []byte
|
||||
}
|
||||
|
||||
func newDerivedKeyCache() derivedKeyCache {
|
||||
return derivedKeyCache{
|
||||
values: make(map[string]derivedKey),
|
||||
nowFunc: gotime.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns key from cache or creates a new one.
|
||||
func (s *derivedKeyCache) Get(accessKey, secret, service, region string, sigtime Time) []byte {
|
||||
// <accessKey>/<YYYYMMDD>/<region>/<service>
|
||||
key := strings.Join([]string{accessKey, sigtime.ShortTimeFormat(), region, service}, "/")
|
||||
|
||||
s.mutex.RLock()
|
||||
cred, status := s.getFromCache(key)
|
||||
s.mutex.RUnlock()
|
||||
if status == 0 {
|
||||
return cred
|
||||
}
|
||||
|
||||
cred = deriveKey(secret, service, region, sigtime)
|
||||
|
||||
s.mutex.Lock()
|
||||
if status == -1 {
|
||||
delete(s.values, key)
|
||||
}
|
||||
s.values[key] = derivedKey{
|
||||
Date: sigtime.Time,
|
||||
Credential: cred,
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return cred
|
||||
}
|
||||
|
||||
// getFromCache returns s.values[key]. Second result is 1 if key was not found,
|
||||
// or -1 if the cached value has expired.
|
||||
func (s *derivedKeyCache) getFromCache(key string) ([]byte, int) {
|
||||
v, ok := s.values[key]
|
||||
if !ok {
|
||||
return nil, 1
|
||||
}
|
||||
// evict from cache if item is a day older than system time
|
||||
if s.nowFunc().Sub(v.Date) > 24*time.Hour {
|
||||
return nil, -1
|
||||
}
|
||||
return v.Credential, 0
|
||||
}
|
36
providers/bedrock/sigv4/sign_time.go
Normal file
36
providers/bedrock/sigv4/sign_time.go
Normal file
@ -0,0 +1,36 @@
|
||||
package sigv4
|
||||
|
||||
import (
|
||||
gotime "time"
|
||||
)
|
||||
|
||||
// Time wraps time.Time to cache its string format result.
|
||||
type Time struct {
|
||||
gotime.Time
|
||||
short string
|
||||
long string
|
||||
}
|
||||
|
||||
// NewTime creates a new signingTime with the specified time.Time.
|
||||
func NewTime(t gotime.Time) Time {
|
||||
return Time{Time: t.UTC()}
|
||||
}
|
||||
|
||||
// TimeFormat provides a time formatted in the X-Amz-Date format.
|
||||
func (m *Time) TimeFormat() string {
|
||||
return m.readOrFormat(&m.long, TimeFormat)
|
||||
}
|
||||
|
||||
// ShortTimeFormat provides a time formatted in short time format.
|
||||
func (m *Time) ShortTimeFormat() string {
|
||||
return m.readOrFormat(&m.short, ShortTimeFormat)
|
||||
}
|
||||
|
||||
func (m *Time) readOrFormat(target *string, format string) string {
|
||||
if len(*target) > 0 {
|
||||
return *target
|
||||
}
|
||||
v := m.Time.Format(format)
|
||||
*target = v
|
||||
return v
|
||||
}
|
428
providers/bedrock/sigv4/signer.go
Normal file
428
providers/bedrock/sigv4/signer.go
Normal file
@ -0,0 +1,428 @@
|
||||
package sigv4
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"hash"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HTTPSigner is an AWS SigV4 signer that can sign HTTP requests.
|
||||
type HTTPSigner interface {
|
||||
// Sign AWS v4 requests with the provided payload hash, service name, region
|
||||
// the request is made to, and time the request is signed at. Set sigtime
|
||||
// to the future to create a request that cannot be used until the future time.
|
||||
//
|
||||
// payloadHash is the hex encoded SHA-256 hash of the request payload, and must
|
||||
// not be empty, even if the request has no payload (aka body). If the request
|
||||
// has no payload, use the hex encoded SHA-256 of an empty string, or the constant
|
||||
// EmptyStringSHA256. You can use the utility function ContentSHA256Sum to
|
||||
// calculate the hash of a http.Request body.
|
||||
//
|
||||
// Some services such as Amazon S3 accept alternative values for the payload
|
||||
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
|
||||
// protected by sigv4. See https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html.
|
||||
//
|
||||
// Sign differs from Presign in that it will sign the request using HTTP headers.
|
||||
// The passed in request r will be modified in place: modified fields include
|
||||
// r.Host and r.Header.
|
||||
Sign(r *http.Request, payloadHash string, sigtime Time) error
|
||||
// Presign is like Sign, but does not modify request r. It returns a copy of
|
||||
// r.URL with additional query parameters that contains signing information.
|
||||
// The URL can be used to recreate an authenticated request without specifying
|
||||
// headers. It also returns http.Header as a second result, which must be
|
||||
// included in the reconstructed request.
|
||||
//
|
||||
// Header hoisting: use WithHeaderHoisting option function to specify whether
|
||||
// headers in request r should be added as query parameters. Some headers cannot
|
||||
// be hoisted, and are returned as the second result.
|
||||
//
|
||||
// Presign will not set the expires time of the presigned request automatically.
|
||||
// To specify the expire duration for a request, add the "X-Amz-Expires" query
|
||||
// parameter on the request with the value as the duration in seconds the
|
||||
// presigned URL should be considered valid for. This parameter is not used
|
||||
// by all AWS services, and is most notable used by Amazon S3 APIs.
|
||||
//
|
||||
// expires := 20*time.Minute
|
||||
// query := req.URL.Query()
|
||||
// query.Set("X-Amz-Expires", strconv.FormatInt(int64(expires/time.Second), 10)
|
||||
// req.URL.RawQuery = query.Encode()
|
||||
Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error)
|
||||
}
|
||||
|
||||
// HTTPSignerOption is an option parameter for HTTPSigner constructor function.
|
||||
type HTTPSignerOption func(HTTPSigner) error
|
||||
|
||||
// ErrInvalidOption means the option parameter is incompatible with the HTTPSigner.
|
||||
var ErrInvalidOption = errors.New("cannot apply option to HTTPSigner")
|
||||
|
||||
// httpV4Signer is the default implementation of HTTPSigner.
|
||||
type httpV4Signer struct {
|
||||
KeyDeriver keyDeriver
|
||||
AccessKey string
|
||||
Secret string
|
||||
SessionToken string
|
||||
Service string
|
||||
Region string
|
||||
HeaderHoisting bool
|
||||
EscapeURLPath bool
|
||||
}
|
||||
|
||||
// WithCredential sets HTTPSigner credential fields.
|
||||
func WithCredential(accessKey, secret, sessionToken string) HTTPSignerOption {
|
||||
return func(signer HTTPSigner) error {
|
||||
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||
sigv4.AccessKey = accessKey
|
||||
sigv4.Secret = secret
|
||||
sigv4.SessionToken = sessionToken
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidOption
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaderHoisting specifies whether HTTPSigner automatically hoist headers.
|
||||
// Default is enabled.
|
||||
func WithHeaderHoisting(enable bool) HTTPSignerOption {
|
||||
return func(signer HTTPSigner) error {
|
||||
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||
sigv4.HeaderHoisting = enable
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidOption
|
||||
}
|
||||
}
|
||||
|
||||
// WithEscapeURLPath specifies whether HTTPSigner automatically escapes URL paths.
|
||||
// Default is enabled.
|
||||
func WithEscapeURLPath(enable bool) HTTPSignerOption {
|
||||
return func(signer HTTPSigner) error {
|
||||
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||
sigv4.EscapeURLPath = enable
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidOption
|
||||
}
|
||||
}
|
||||
|
||||
// WithRegionService sets HTTPSigner region and service fields.
|
||||
func WithRegionService(region, service string) HTTPSignerOption {
|
||||
return func(signer HTTPSigner) error {
|
||||
if sigv4, ok := signer.(*httpV4Signer); ok {
|
||||
sigv4.Region = region
|
||||
sigv4.Service = service
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidOption
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a HTTPSigner.
|
||||
func New(opts ...HTTPSignerOption) (HTTPSigner, error) {
|
||||
sigv4 := &httpV4Signer{
|
||||
KeyDeriver: newKeyDeriver(),
|
||||
EscapeURLPath: true,
|
||||
HeaderHoisting: true,
|
||||
}
|
||||
for _, o := range opts {
|
||||
if o == nil {
|
||||
continue
|
||||
}
|
||||
if err := o(sigv4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return sigv4, nil
|
||||
}
|
||||
|
||||
// Sign implements HTTPSigner.
|
||||
func (s *httpV4Signer) Sign(r *http.Request, payloadHash string, sigtime Time) error {
|
||||
if payloadHash == "" {
|
||||
var err error
|
||||
payloadHash, err = ContentSHA256Sum(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// add mandatory headers to r.Header
|
||||
setRequiredSigningHeaders(r.Header, sigtime, s.SessionToken)
|
||||
// remove port in r.Host if any
|
||||
r.Host = sanitizeHostForHeader(r)
|
||||
|
||||
// parse URL query only once
|
||||
query := r.URL.Query()
|
||||
// sigBuf is used to act as a sha256 hash buffer
|
||||
sigBuf := make([]byte, 0, sha256.Size)
|
||||
|
||||
//hasher := &debugHasher{}
|
||||
hasher := sha256.New()
|
||||
reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, r.Header, query,
|
||||
r.Host, payloadHash, s.EscapeURLPath, false, sigBuf)
|
||||
|
||||
credentialScope := strings.Join([]string{
|
||||
sigtime.ShortTimeFormat(),
|
||||
s.Region,
|
||||
s.Service,
|
||||
"aws4_request",
|
||||
}, "/")
|
||||
|
||||
keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service,
|
||||
s.Region, sigtime)
|
||||
sigHasher := hmac.New(sha256.New, keyBytes)
|
||||
signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf)
|
||||
|
||||
writeAuthorizationHeader(r.Header, s.AccessKey+"/"+credentialScope,
|
||||
signedHeaderStr, signature)
|
||||
|
||||
// done
|
||||
return nil
|
||||
}
|
||||
|
||||
// Presign implements HTTPSigner.
|
||||
func (s *httpV4Signer) Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error) {
|
||||
if payloadHash == "" {
|
||||
var err error
|
||||
payloadHash, err = ContentSHA256Sum(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
setRequiredSigningQuery(query, sigtime, s.SessionToken)
|
||||
// sort each query key's values
|
||||
for key := range query {
|
||||
sort.Strings(query[key])
|
||||
}
|
||||
|
||||
credentialScope := strings.Join([]string{
|
||||
sigtime.ShortTimeFormat(),
|
||||
s.Region,
|
||||
s.Service,
|
||||
"aws4_request",
|
||||
}, "/")
|
||||
credentialStr := s.AccessKey + "/" + credentialScope
|
||||
query.Set(AmzCredentialKey, credentialStr)
|
||||
|
||||
var headersLeft http.Header
|
||||
if s.HeaderHoisting {
|
||||
headersLeft = make(http.Header, len(r.Header))
|
||||
for k, v := range r.Header {
|
||||
if isAllowQueryHoisting(k) {
|
||||
query[k] = v
|
||||
} else {
|
||||
headersLeft[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sigBuf is used to act as a sha256 hash buffer
|
||||
sigBuf := make([]byte, 0, sha256.Size)
|
||||
|
||||
hasher := sha256.New()
|
||||
reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, headersLeft,
|
||||
query, sanitizeHostForHeader(r), payloadHash, s.EscapeURLPath, true, sigBuf)
|
||||
|
||||
keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service,
|
||||
s.Region, sigtime)
|
||||
sigHasher := hmac.New(sha256.New, keyBytes)
|
||||
signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf)
|
||||
query.Set(AmzSignatureKey, signature)
|
||||
|
||||
u := cloneURL(r.URL)
|
||||
u.RawQuery = strings.Replace(query.Encode(), "+", "%20", -1)
|
||||
|
||||
// For the signed headers we canonicalize the header keys in the returned map.
|
||||
// This avoids situations where standard library can sometimes add double
|
||||
// headers. For example, the standard library will set the Host header,
|
||||
// even if it is present in lower-case form.
|
||||
signedHeader := strings.Split(signedHeaderStr, ";")
|
||||
canonHeader := make(http.Header, len(signedHeader))
|
||||
for _, k := range signedHeader {
|
||||
canonKey := textproto.CanonicalMIMEHeaderKey(k)
|
||||
switch k {
|
||||
case "host":
|
||||
canonHeader[canonKey] = []string{sanitizeHostForHeader(r)}
|
||||
case "content-length":
|
||||
canonHeader[canonKey] = []string{strconv.FormatInt(r.ContentLength, 10)}
|
||||
default:
|
||||
canonHeader[canonKey] = append(canonHeader[canonKey], headersLeft[http.CanonicalHeaderKey(k)]...)
|
||||
}
|
||||
}
|
||||
return u, canonHeader, nil
|
||||
}
|
||||
|
||||
// authorizationSignature returns `sig` as documented in step 4 of algorithm
|
||||
// documentation. key is hSig in step 4. It calculates the result of step 3
|
||||
// internally.
|
||||
func authorizationSignature(hasher hash.Hash, sigtime Time, credScope, requestHash string, buf []byte) string {
|
||||
w := bufio.NewWriterSize(hasher, sha256.BlockSize)
|
||||
|
||||
w.WriteString(SigningAlgorithm)
|
||||
w.WriteByte('\n')
|
||||
w.WriteString(sigtime.TimeFormat())
|
||||
w.WriteByte('\n')
|
||||
w.WriteString(credScope)
|
||||
w.WriteByte('\n')
|
||||
w.WriteString(requestHash)
|
||||
|
||||
w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer
|
||||
//hasher.Println()
|
||||
return hex.EncodeToString(hasher.Sum(buf[:0]))
|
||||
}
|
||||
|
||||
// canonicalRequestHash returns the hex-encoded sha256 sum of the canonical
|
||||
// request string. Refer to step 2 of algorithm documentation. Expect hasher to
|
||||
// be sha256.New.
|
||||
func canonicalRequestHash(
|
||||
hasher hash.Hash, r *http.Request, headers http.Header, query url.Values,
|
||||
hostname, hashcode string, escapeURL, isPresign bool, buf []byte,
|
||||
) (string, string) {
|
||||
w := bufio.NewWriterSize(hasher, sha256.BlockSize)
|
||||
|
||||
signedHeaders := make([]string, 0, len(headers)+2)
|
||||
signedHeaders = append(signedHeaders, "host")
|
||||
if r.ContentLength > 0 {
|
||||
signedHeaders = append(signedHeaders, "content-length")
|
||||
}
|
||||
for k := range headers {
|
||||
if strings.EqualFold(k, "content-length") || strings.EqualFold(k, "host") || isIgnoredHeader(k) {
|
||||
continue
|
||||
}
|
||||
signedHeaders = append(signedHeaders, strings.ToLower(k))
|
||||
}
|
||||
sort.Strings(signedHeaders)
|
||||
signedHeaderStr := strings.Join(signedHeaders, ";")
|
||||
|
||||
// for presigned requests, we need to add X-Amz-SignedHeaders to calculate the
|
||||
// correct hash
|
||||
if isPresign {
|
||||
query.Set(AmzSignedHeadersKey, signedHeaderStr)
|
||||
}
|
||||
|
||||
// <METHOD>\n<URI>\n<QUERY>\n<HEADERS>\n<SIGNED_HEADERS>\n<PAYLOAD_HASH>
|
||||
|
||||
// HTTP_METHOD
|
||||
w.WriteString(r.Method)
|
||||
w.WriteByte('\n')
|
||||
// CANONICAL_URI
|
||||
writeAWSURIPath(w, r.URL, false, !escapeURL)
|
||||
w.WriteByte('\n')
|
||||
// CANONICAL_QUERY_PARAMS
|
||||
writeCanonicalQueryParams(w, query)
|
||||
w.WriteByte('\n')
|
||||
// CANONICAL_HEADERS
|
||||
for _, head := range signedHeaders {
|
||||
switch head {
|
||||
case "host":
|
||||
w.WriteString(head)
|
||||
w.WriteByte(':')
|
||||
writeCanonicalString(w, hostname)
|
||||
w.WriteByte('\n')
|
||||
case "content-length":
|
||||
w.WriteString(head)
|
||||
w.WriteByte(':')
|
||||
w.WriteString(strconv.FormatInt(r.ContentLength, 10))
|
||||
w.WriteByte('\n')
|
||||
default:
|
||||
w.WriteString(head)
|
||||
w.WriteByte(':')
|
||||
values := headers[http.CanonicalHeaderKey(head)]
|
||||
for i, v := range values {
|
||||
if i != 0 {
|
||||
w.WriteByte(',')
|
||||
}
|
||||
writeCanonicalString(w, v)
|
||||
}
|
||||
w.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
w.WriteByte('\n')
|
||||
// SIGNED_HEADERS
|
||||
w.WriteString(signedHeaderStr)
|
||||
w.WriteByte('\n')
|
||||
// PAYLOAD_HASH
|
||||
w.WriteString(hashcode)
|
||||
|
||||
w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer
|
||||
//hasher.Println()
|
||||
return hex.EncodeToString(hasher.Sum(buf[:0])), signedHeaderStr
|
||||
}
|
||||
|
||||
// writeAuthorizationHeader writes the Authorization header into header:
|
||||
//
|
||||
// AWS4-HMAC-SHA256 Credential=<cred>, SignedHeaders=<signed_headers>, Signature=<sig>
|
||||
func writeAuthorizationHeader(headers http.Header, credentialStr, signedHeaders, signature string) {
|
||||
const credentialPrefix = "Credential="
|
||||
const signedHeadersPrefix = "SignedHeaders="
|
||||
const signaturePrefix = "Signature="
|
||||
const commaSpace = ", "
|
||||
|
||||
var parts strings.Builder
|
||||
parts.Grow(len(SigningAlgorithm) + 1 +
|
||||
len(credentialPrefix) + len(credentialStr) + 2 +
|
||||
len(signedHeadersPrefix) + len(signedHeaders) + 2 +
|
||||
len(signaturePrefix) + len(signature))
|
||||
|
||||
parts.WriteString(SigningAlgorithm)
|
||||
parts.WriteRune(' ')
|
||||
parts.WriteString(credentialPrefix)
|
||||
parts.WriteString(credentialStr)
|
||||
parts.WriteString(commaSpace)
|
||||
parts.WriteString(signedHeadersPrefix)
|
||||
parts.WriteString(signedHeaders)
|
||||
parts.WriteString(commaSpace)
|
||||
parts.WriteString(signaturePrefix)
|
||||
parts.WriteString(signature)
|
||||
|
||||
headers[authorizationHeader] = append(headers[authorizationHeader][:0],
|
||||
parts.String())
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
// sanitizeHostForHeader is like hostOrURLHost, but without port if port is the
|
||||
// default port for the scheme. For example, it removes ":80" suffix if the scheme
|
||||
// is "http".
|
||||
func sanitizeHostForHeader(r *http.Request) string {
|
||||
host := hostOrURLHost(r)
|
||||
port := parsePort(host)
|
||||
if port != "" && isDefaultPort(r.URL.Scheme, port) {
|
||||
return stripPort(host)
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// setRequiredSigningHeaders modifies headers: sets X-Amz-Date to sigtime, and
|
||||
// if credToken is non-empty, set X-Amz-Security-Token to credToken. This function
|
||||
// overwrites existing headers values with the same key.
|
||||
func setRequiredSigningHeaders(headers http.Header, sigtime Time, sessionToken string) {
|
||||
amzDate := sigtime.TimeFormat()
|
||||
headers[AmzDateKey] = append(headers[AmzDateKey][:0], amzDate)
|
||||
if sessionToken != "" {
|
||||
headers[AmzSecurityTokenKey] = append(headers[AmzSecurityTokenKey][:0],
|
||||
sessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
// setRequiredSigningQuery is like setRequiredSigningHeaders, but modifies
|
||||
// query values. This is used for presign requests.
|
||||
func setRequiredSigningQuery(query url.Values, sigtime Time, sessionToken string) {
|
||||
query.Set(AmzAlgorithmKey, SigningAlgorithm)
|
||||
|
||||
amzDate := sigtime.TimeFormat()
|
||||
query.Set(AmzDateKey, amzDate)
|
||||
|
||||
if sessionToken != "" {
|
||||
query.Set(AmzSecurityTokenKey, sessionToken)
|
||||
}
|
||||
}
|
26
providers/bedrock/sigv4/util.go
Normal file
26
providers/bedrock/sigv4/util.go
Normal file
@ -0,0 +1,26 @@
|
||||
package sigv4
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ContentSHA256Sum calculates the hex-encoded SHA256 checksum of r.Body. It returns
|
||||
// EmptyStringSHA256 if r.Body is nil, r.Method is TRACE or r.ContentLength is zero.
|
||||
// Returns non-nil error if r.Body cannot be read.
|
||||
func ContentSHA256Sum(r *http.Request) (string, error) {
|
||||
// We need to check r.Body is non-nil, because io.Copy(dst, nil) panics.
|
||||
// This is not documented in https://pkg.go.dev/io#Copy.
|
||||
if r.Method == http.MethodTrace || r.ContentLength == 0 || r.Body == nil {
|
||||
return EmptyStringSHA256, nil
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
_, err := io.Copy(h, r.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
128
providers/bedrock/stream_reader.go
Normal file
128
providers/bedrock/stream_reader.go
Normal file
@ -0,0 +1,128 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
|
||||
"github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
type streamReader[T any] struct {
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
|
||||
handlerPrefix requester.HandlerPrefix[T]
|
||||
|
||||
DataChan chan T
|
||||
ErrChan chan error
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
|
||||
go stream.processLines()
|
||||
|
||||
return stream.DataChan, stream.ErrChan
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (stream *streamReader[T]) processLines() {
|
||||
decode := eventstream.NewDecoder()
|
||||
payloadBuf := make([]byte, 0*1024)
|
||||
for {
|
||||
payloadBuf = payloadBuf[0:0]
|
||||
messgae, readErr := decode.Decode(stream.reader, payloadBuf)
|
||||
if readErr != nil {
|
||||
stream.ErrChan <- readErr
|
||||
return
|
||||
}
|
||||
|
||||
line, err := stream.deserializeEventMessage(&messgae)
|
||||
if err != nil {
|
||||
stream.ErrChan <- common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
stream.handlerPrefix(&line, stream.DataChan, stream.ErrChan)
|
||||
|
||||
if line == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(line, requester.StreamClosed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Close() {
|
||||
stream.response.Body.Close()
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) deserializeEventMessage(msg *eventstream.Message) ([]byte, error) {
|
||||
messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
|
||||
if messageType == nil {
|
||||
return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
|
||||
}
|
||||
|
||||
switch messageType.String() {
|
||||
case eventstreamapi.EventMessageType:
|
||||
var v BedrockResponseStream
|
||||
if err := json.Unmarshal(msg.Payload, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer, err := base64.StdEncoding.DecodeString(v.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buffer, nil
|
||||
|
||||
case eventstreamapi.ExceptionMessageType:
|
||||
exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader)
|
||||
return nil, errors.New("Exception message :" + exceptionType.String())
|
||||
|
||||
case eventstreamapi.ErrorMessageType:
|
||||
errorCode := "UnknownError"
|
||||
errorMessage := errorCode
|
||||
if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
|
||||
errorCode = header.String()
|
||||
}
|
||||
if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
|
||||
errorMessage = header.String()
|
||||
}
|
||||
return nil, &smithy.GenericAPIError{
|
||||
Code: errorCode,
|
||||
Message: errorMessage,
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, errors.New("bedrock stream unknown error")
|
||||
}
|
||||
}
|
||||
|
||||
func RequestStream[T any](resp *http.Response, handlerPrefix requester.HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
|
||||
// 如果返回的头是json格式 说明有错误
|
||||
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
|
||||
return nil, requester.HandleErrorResp(resp, requestErrorHandle)
|
||||
}
|
||||
|
||||
stream := &streamReader[T]{
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
handlerPrefix: handlerPrefix,
|
||||
|
||||
DataChan: make(chan T),
|
||||
ErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
11
providers/bedrock/type.go
Normal file
11
providers/bedrock/type.go
Normal file
@ -0,0 +1,11 @@
|
||||
package bedrock
|
||||
|
||||
const awsService = "bedrock"
|
||||
|
||||
type BedrockError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type BedrockResponseStream struct {
|
||||
Bytes string `json:"bytes"`
|
||||
}
|
@ -8,13 +8,17 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/common/image"
|
||||
"one-api/common/requester"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||
)
|
||||
|
||||
type claudeStreamHandler struct {
|
||||
type ClaudeStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
Prefix string
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -31,7 +35,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToChatOpenai(claudeResponse, request)
|
||||
return ConvertToChatOpenai(p, claudeResponse, request)
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
@ -47,12 +51,15 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &claudeStreamHandler{
|
||||
chatHandler := &ClaudeStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
Prefix: `data: {"type"`,
|
||||
}
|
||||
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
eventstream.NewDecoder()
|
||||
|
||||
return requester.RequestStream(p.Requester, resp, chatHandler.HandlerStream)
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -72,7 +79,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
claudeRequest, errWithCode := convertFromChatOpenai(request)
|
||||
claudeRequest, errWithCode := ConvertFromChatOpenai(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
@ -86,7 +93,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) {
|
||||
func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: request.Model,
|
||||
Messages: []Message{},
|
||||
@ -142,7 +149,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.Error)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
@ -182,21 +189,24 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
|
||||
openaiResponse.Usage.CompletionTokens = completionTokens
|
||||
openaiResponse.Usage.TotalTokens = promptTokens + completionTokens
|
||||
|
||||
*p.Usage = *openaiResponse.Usage
|
||||
usage := provider.GetUsage()
|
||||
*usage = *openaiResponse.Usage
|
||||
|
||||
return openaiResponse, nil
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), `data: {"type"`) {
|
||||
if !strings.HasPrefix(string(*rawLine), h.Prefix) {
|
||||
*rawLine = nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
if strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
}
|
||||
|
||||
var claudeResponse ClaudeStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &claudeResponse)
|
||||
@ -235,7 +245,7 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
|
||||
}
|
||||
}
|
||||
|
||||
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
|
||||
func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
|
||||
choice := types.ChatCompletionStreamChoice{
|
||||
Index: claudeResponse.Index,
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ type Message struct {
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"one-api/providers/baichuan"
|
||||
"one-api/providers/baidu"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/bedrock"
|
||||
"one-api/providers/claude"
|
||||
"one-api/providers/closeai"
|
||||
"one-api/providers/deepseek"
|
||||
@ -62,6 +63,7 @@ func init() {
|
||||
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
|
||||
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
||||
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
||||
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
|
@ -101,6 +101,12 @@ export const CHANNEL_OPTIONS = {
|
||||
value: 31,
|
||||
color: 'primary'
|
||||
},
|
||||
32: {
|
||||
key: 32,
|
||||
text: 'Amazon Bedrock',
|
||||
value: 32,
|
||||
color: 'orange'
|
||||
},
|
||||
24: {
|
||||
key: 24,
|
||||
text: 'Azure Speech',
|
||||
|
@ -60,8 +60,15 @@ const typeConfig = {
|
||||
},
|
||||
14: {
|
||||
input: {
|
||||
models: ['claude-instant-1.2', 'claude-2.0', 'claude-2.1', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'],
|
||||
test_model: 'claude-3-sonnet-20240229'
|
||||
models: [
|
||||
'claude-instant-1.2',
|
||||
'claude-2.0',
|
||||
'claude-2.1',
|
||||
'claude-3-opus-20240229',
|
||||
'claude-3-sonnet-20240229',
|
||||
'claude-3-haiku-20240307'
|
||||
],
|
||||
test_model: 'claude-3-haiku-20240307'
|
||||
},
|
||||
modelGroup: 'Anthropic'
|
||||
},
|
||||
@ -202,6 +209,23 @@ const typeConfig = {
|
||||
test_model: 'llama2-7b-2048'
|
||||
},
|
||||
modelGroup: 'Groq'
|
||||
},
|
||||
32: {
|
||||
input: {
|
||||
models: [
|
||||
'claude-instant-1.2',
|
||||
'claude-2.0',
|
||||
'claude-2.1',
|
||||
'claude-3-opus-20240229',
|
||||
'claude-3-sonnet-20240229',
|
||||
'claude-3-haiku-20240307'
|
||||
],
|
||||
test_model: 'claude-3-haiku-20240307'
|
||||
},
|
||||
prompt: {
|
||||
key: '按照如下格式输入:Region|AccessKeyID|SecretAccessKey|SessionToken 其中SessionToken可不填空'
|
||||
},
|
||||
modelGroup: 'Anthropic'
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user