feat: support amazon bedrock anthropic (#114)

* 🚧 WIP: bedrock

*  feat: support amazon bedrock anthropic
This commit is contained in:
Buer 2024-03-18 16:00:35 +08:00 committed by GitHub
parent 6f76007292
commit b81808e839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1617 additions and 29 deletions

View File

@ -198,6 +198,7 @@ const (
ChannelTypeMoonshot = 29
ChannelTypeMistral = 30
ChannelTypeGroq = 31
ChannelTypeBedrock = 32
)
var ChannelBaseURLs = []string{
@ -232,7 +233,8 @@ var ChannelBaseURLs = []string{
"https://api.deepseek.com", //28
"https://api.moonshot.cn", //29
"https://api.mistral.ai", //30
"https://api.groq.com/openai", //30
"https://api.groq.com/openai", //31
"", //32
}
const (

View File

@ -89,10 +89,14 @@ func init() {
// $0.80/million tokens $2.40/million tokens
"claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
// $8.00/million tokens $24.00/million tokens
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
// $15 / M $75 / M
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
// $3 / M $15 / M
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
"claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic},
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu},

9
go.mod
View File

@ -4,6 +4,7 @@ module one-api
go 1.18
require (
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@ -24,8 +25,10 @@ require (
gorm.io/gorm v1.25.0
)
require github.com/aws/smithy-go v1.20.1 // indirect
require (
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 // indirect
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@ -46,7 +49,7 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/joho/godotenv v1.5.1
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
@ -59,7 +62,7 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/net v0.19.0
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect

14
go.sum
View File

@ -1,5 +1,9 @@
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24 h1:1T7RcpzlldaJ3qpZi0lNg/lBsfPCK+8n8Wc+R8EhAkU=
github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.24/go.mod h1:kL1v4iIjlalwm3gCYGvF4NLa3hs+aKEfRkNJvj4aoDU=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 h1:gTK2uhtAPtFcdRRJilZPx8uJLL2J85xK11nKtWL0wfU=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1/go.mod h1:sxpLb+nZk7tIfCWChfd+h4QwHNUR57d8hA1cleTkjJo=
github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@ -49,8 +53,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@ -58,6 +60,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
@ -111,6 +115,7 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -161,8 +166,6 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -203,14 +206,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y=
gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=

127
providers/bedrock/base.go Normal file
View File

@ -0,0 +1,127 @@
package bedrock
import (
"bytes"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"time"
"one-api/providers/bedrock/category"
"one-api/providers/bedrock/sigv4"
)
type BedrockProviderFactory struct{}
// 创建 BedrockProvider
func (f BedrockProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
bedrockProvider := &BedrockProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
},
}
getKeyConfig(bedrockProvider)
return bedrockProvider
}
type BedrockProvider struct {
base.BaseProvider
Region string
AccessKeyID string
SecretAccessKey string
SessionToken string
Category *category.Category
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://bedrock-runtime.%s.amazonaws.com",
ChatCompletions: "/model/%s/invoke",
}
}
// 请求错误处理
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
bedrockError := &BedrockError{}
err := json.NewDecoder(resp.Body).Decode(bedrockError)
if err != nil {
return nil
}
return errorHandle(bedrockError)
}
// 错误处理
func errorHandle(bedrockError *BedrockError) *types.OpenAIError {
if bedrockError.Message == "" {
return nil
}
return &types.OpenAIError{
Message: bedrockError.Message,
Type: "Bedrock Error",
}
}
func (p *BedrockProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf(baseURL+requestURL, p.Region, modelName)
}
func (p *BedrockProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Accept"] = "*/*"
return headers
}
func getKeyConfig(bedrock *BedrockProvider) {
keys := strings.Split(bedrock.Channel.Key, "|")
if len(keys) < 3 {
return
}
bedrock.Region = keys[0]
bedrock.AccessKeyID = keys[1]
bedrock.SecretAccessKey = keys[2]
if len(keys) == 4 && keys[3] != "" {
bedrock.SessionToken = keys[3]
}
}
func (p *BedrockProvider) Sign(req *http.Request) error {
var body []byte
if req.Body == nil {
body = []byte("")
} else {
var err error
body, err = io.ReadAll(req.Body)
if err != nil {
return errors.New("error getting request body: " + err.Error())
}
req.Body = io.NopCloser(bytes.NewReader(body))
}
sig, err := sigv4.New(sigv4.WithCredential(p.AccessKeyID, p.SecretAccessKey, p.SessionToken), sigv4.WithRegionService(p.Region, awsService))
if err != nil {
return err
}
reqBodyHashHex := fmt.Sprintf("%x", sha256.Sum256(body))
sig.Sign(req, reqBodyHashHex, sigv4.NewTime(time.Now()))
return nil
}

View File

@ -0,0 +1,54 @@
package category
import (
"errors"
"net/http"
"one-api/common/requester"
"one-api/providers/base"
"one-api/types"
"strings"
)
var CategoryMap = map[string]Category{}
type Category struct {
ModelName string
ChatComplete ChatCompletionConvert
ResponseChatComplete ChatCompletionResponse
ResponseChatCompleteStrem ChatCompletionStreamResponse
}
func GetCategory(modelName string) (*Category, error) {
modelName = GetModelName(modelName)
// 点分割
provider := strings.Split(modelName, ".")[0]
if category, exists := CategoryMap[provider]; exists {
category.ModelName = modelName
return &category, nil
}
return nil, errors.New("category_not_found")
}
func GetModelName(modelName string) string {
bedrockMap := map[string]string{
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-2.1": "anthropic.claude-v2:1",
"claude-2.0": "anthropic.claude-v2",
"claude-instant-1.2": "anthropic.claude-instant-v1",
}
if value, exists := bedrockMap[modelName]; exists {
modelName = value
}
return modelName
}
type ChatCompletionConvert func(*types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode)
type ChatCompletionResponse func(base.ProviderInterface, *http.Response, *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
type ChatCompletionStreamResponse func(base.ProviderInterface, *types.ChatCompletionRequest) requester.HandlerPrefix[string]

View File

@ -0,0 +1,63 @@
package category
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/providers/base"
"one-api/providers/claude"
"one-api/types"
)
const anthropicVersion = "bedrock-2023-05-31"
type ClaudeRequest struct {
*claude.ClaudeRequest
AnthropicVersion string `json:"anthropic_version"`
}
func init() {
CategoryMap["anthropic"] = Category{
ChatComplete: ConvertClaudeFromChatOpenai,
ResponseChatComplete: ConvertClaudeToChatOpenai,
ResponseChatCompleteStrem: ClaudeChatCompleteStrem,
}
}
func ConvertClaudeFromChatOpenai(request *types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) {
rawRequest, err := claude.ConvertFromChatOpenai(request)
if err != nil {
return nil, err
}
claudeRequest := &ClaudeRequest{}
claudeRequest.ClaudeRequest = rawRequest
claudeRequest.AnthropicVersion = anthropicVersion
// 删除model字段
claudeRequest.Model = ""
claudeRequest.Stream = false
return claudeRequest, nil
}
func ConvertClaudeToChatOpenai(provider base.ProviderInterface, response *http.Response, request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
claudeResponse := &claude.ClaudeResponse{}
err := json.NewDecoder(response.Body).Decode(claudeResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
return claude.ConvertToChatOpenai(provider, claudeResponse, request)
}
func ClaudeChatCompleteStrem(provider base.ProviderInterface, request *types.ChatCompletionRequest) requester.HandlerPrefix[string] {
chatHandler := &claude.ClaudeStreamHandler{
Usage: provider.GetUsage(),
Request: request,
Prefix: `{"type"`,
}
return chatHandler.HandlerStream
}

81
providers/bedrock/chat.go Normal file
View File

@ -0,0 +1,81 @@
package bedrock
import (
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/providers/bedrock/category"
"one-api/types"
)
func (p *BedrockProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
// 发送请求
response, errWithCode := p.Send(request)
if errWithCode != nil {
return nil, errWithCode
}
defer response.Body.Close()
return p.Category.ResponseChatComplete(p, response, request)
}
func (p *BedrockProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
// 发送请求
response, errWithCode := p.Send(request)
if errWithCode != nil {
return nil, errWithCode
}
return RequestStream(response, p.Category.ResponseChatCompleteStrem(p, request))
}
func (p *BedrockProvider) Send(request *types.ChatCompletionRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
return p.Requester.SendRequestRaw(req)
}
func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
var err error
p.Category, err = category.GetCategory(request.Model)
if err != nil || p.Category.ChatComplete == nil || p.Category.ResponseChatComplete == nil {
return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError)
}
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
if request.Stream {
url += "-with-response-stream"
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url, p.Category.ModelName)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError)
}
headers := p.GetRequestHeaders()
bedrockRequest, errWithCode := p.Category.ChatComplete(request)
if errWithCode != nil {
return nil, errWithCode
}
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(bedrockRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
p.Sign(req)
return req, nil
}

View File

@ -0,0 +1,27 @@
Copyright (c) 2023 Macks C. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Macks C. nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,36 @@
package sigv4
const authorizationHeader = "Authorization"
// Signature Version 4 (SigV4) Constants
const (
// SigningAlgorithm is the name of the algorithm used in this package.
SigningAlgorithm = "AWS4-HMAC-SHA256"
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string.
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
// UnsignedPayload indicates that the request payload body is unsigned.
UnsignedPayload = "UNSIGNED-PAYLOAD"
// AmzAlgorithmKey indicates the signing algorithm.
AmzAlgorithmKey = "X-Amz-Algorithm"
// AmzSecurityTokenKey indicates the security token to be used with temporary
// credentials.
AmzSecurityTokenKey = "X-Amz-Security-Token"
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'.
AmzDateKey = "X-Amz-Date"
// AmzCredentialKey is the access key ID and credential scope.
AmzCredentialKey = "X-Amz-Credential"
// AmzSignedHeadersKey is the set of headers signed for the request.
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
// AmzSignatureKey is the query parameter to store the SigV4 signature.
AmzSignatureKey = "X-Amz-Signature"
// TimeFormat is the time format to be used in the X-Amz-Date header or query
// parameter.
TimeFormat = "20060102T150405Z"
// ShortTimeFormat is the shorten time format used in the credential scope.
ShortTimeFormat = "20060102"
// ContentSHAKey is the SHA256 of request body.
ContentSHAKey = "X-Amz-Content-Sha256"
// StreamingEventsPayload indicates that the request payload body is a signed
// event stream.
StreamingEventsPayload = "STREAMING-AWS4-HMAC-SHA256-EVENTS"
)

View File

@ -0,0 +1,91 @@
package sigv4
// ignoredHeaders is a list of headers that are always ignored during signing.
var ignoreHeaders = map[string]struct{}{
"Authorization": {},
"User-Agent": {},
"X-Amzn-Trace-Id": {},
// also include lower case canonical versions
"authorization": {},
"user-agent": {},
"x-amzn-trace-id": {},
}
// requiredHeaderPrefix are header name prefixes that are mandatory for signing.
// If a header has one of these prefixes, it is a mandatory header.
var requiredHeaderPrefix = []string{"X-Amz-Object-Lock-", "X-Amz-Meta-"}
// requiredHeaders is a list of headers that are mandatory for signing.
var requiredHeaders = map[string]struct{}{
"Cache-Control": {},
"Content-Disposition": {},
"Content-Encoding": {},
"Content-Language": {},
"Content-Md5": {},
"Content-Type": {},
"Expires": {},
"If-Match": {},
"If-Modified-Since": {},
"If-None-Match": {},
"If-Unmodified-Since": {},
"Range": {},
"X-Amz-Acl": {},
"X-Amz-Copy-Source": {},
"X-Amz-Copy-Source-If-Match": {},
"X-Amz-Copy-Source-If-Modified-Since": {},
"X-Amz-Copy-Source-If-None-Match": {},
"X-Amz-Copy-Source-If-Unmodified-Since": {},
"X-Amz-Copy-Source-Range": {},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": {},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": {},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": {},
"X-Amz-Grant-Full-control": {},
"X-Amz-Grant-Read": {},
"X-Amz-Grant-Read-Acp": {},
"X-Amz-Grant-Write": {},
"X-Amz-Grant-Write-Acp": {},
"X-Amz-Metadata-Directive": {},
"X-Amz-Mfa": {},
"X-Amz-Request-Payer": {},
"X-Amz-Server-Side-Encryption": {},
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": {},
"X-Amz-Server-Side-Encryption-Customer-Algorithm": {},
"X-Amz-Server-Side-Encryption-Customer-Key": {},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": {},
"X-Amz-Storage-Class": {},
"X-Amz-Website-Redirect-Location": {},
"X-Amz-Content-Sha256": {},
"X-Amz-Tagging": {},
}
// headerPredicate is a function that evaluates whether a header is of the
// specific type. For example, whether a header should be ignored during signing.
type headerPredicate func(header string) bool
// isIgnoredHeader returns true if header must be ignored during signing.
func isIgnoredHeader(header string) bool {
_, ok := ignoreHeaders[header]
return ok
}
// isRequiredHeader returns true if header is mandatory for signing.
func isRequiredHeader(header string) bool {
_, ok := requiredHeaders[header]
if ok {
return true
}
for _, v := range requiredHeaderPrefix {
if hasPrefixFold(header, v) {
return true
}
}
return false
}
// isAllowQueryHoisting is a allowed list for Build query headers.
func isAllowQueryHoisting(header string) bool {
if isRequiredHeader(header) {
return false
}
return hasPrefixFold(header, "X-Amz-")
}

View File

@ -0,0 +1,302 @@
package sigv4
import (
"bufio"
"crypto/hmac"
"crypto/sha256"
"fmt"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
var (
awsURLNoEscTable [256]bool
awsURLEscTable [256][2]byte
)
func init() {
for i := 0; i < len(awsURLNoEscTable); i++ {
// every char except these must be escaped
awsURLNoEscTable[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
// %<hex><hex>
encoded := fmt.Sprintf("%02X", i)
awsURLEscTable[i] = [2]byte{encoded[0], encoded[1]}
}
}
// hmacsha256 computes a HMAC-SHA256 of data given the provided key.
func hmacsha256(key, data, buf []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(buf)
}
// hasPrefixFold tests whether the string s begins with prefix, interpreted as
// UTF-8 strings, under Unicode case-folding.
func hasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) &&
strings.EqualFold(s[0:len(prefix)], prefix)
}
// isSameDay returns true if a and b are the same date (dd-mm-yyyy).
func isSameDay(a, b time.Time) bool {
xYear, xMonth, xDay := a.Date()
yYear, yMonth, yDay := b.Date()
if xYear != yYear || xMonth != yMonth {
return false
}
return xDay == yDay
}
// hostOrURLHost returns r.Host, or if empty, r.URL.Host.
func hostOrURLHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
return r.URL.Host
}
// parsePort returns the port part of u.Host, without the leading colon. Returns
// an empty string if u.Host doesn't contain port.
//
// Adapted from the Go 1.8 standard library (net/url).
func parsePort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 || colon == len(hostport)-1 {
return ""
}
// take care of ipv6 syntax: [a:b::]:<port>
const ipv6Sep = "]:"
if i := strings.Index(hostport, ipv6Sep); i != -1 {
return hostport[i+len(ipv6Sep):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+1:]
}
// stripPort returns Hostname portion of u.Host, i.e. without any port number.
//
// If hostport is an IPv6 literal with a port number, returns the IPv6 literal
// without the square brackets. IPv6 literals may include a zone identifier.
//
// Adapted from the Go 1.8 standard library (net/url).
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
// ipv6: remove the []
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// isDefaultPort returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs).
func isDefaultPort(scheme, port string) bool {
switch strings.ToLower(scheme) {
case "http":
return port == "80"
case "https":
return port == "443"
default:
return false
}
}
func cloneURL(u *url.URL) *url.URL {
if u == nil {
return nil
}
u2 := new(url.URL)
*u2 = *u
if u.User != nil {
u2.User = new(url.Userinfo)
*u2.User = *u.User
}
return u2
}
// writeAWSURIPath writes the escaped URI component from the specified URL (using
// AWS canonical URI specification) into w. URI component is path without query
// string.
func writeAWSURIPath(w *bufio.Writer, u *url.URL, encodeSep bool, isEscaped bool) {
const schemeSep, pathSep, queryStart = "//", "/", "?"
var p string
if u.Opaque == "" {
p = u.EscapedPath()
} else {
opaque := u.Opaque
// discard query string if any
if i := strings.Index(opaque, queryStart); i != -1 {
opaque = opaque[:i]
}
// if has scheme separator as prefix, discard it
if strings.HasPrefix(opaque, schemeSep) {
opaque = opaque[len(schemeSep):]
}
// everything after the first /, including the /
if i := strings.Index(opaque, pathSep); i != -1 {
p = opaque[i:]
}
}
if p == "" {
w.WriteByte('/')
return
}
if isEscaped {
w.WriteString(p)
return
}
// Loop thru first like in https://cs.opensource.google/go/go/+/refs/tags/go1.20.2:/src/net/url/url.go.
// It may add ~800ns but we save on memory alloc and catches cases where there
// is no need to escape.
plen := len(p)
strlen := plen
for i := 0; i < plen; i++ {
c := p[i]
if awsURLNoEscTable[c] || (c == '/' && !encodeSep) {
continue
}
strlen += 2
}
// path already canonical, no need to escape
if plen == strlen {
w.WriteString(p)
return
}
for i := 0; i < plen; i++ {
c := p[i]
if awsURLNoEscTable[c] || (c == '/' && !encodeSep) {
w.WriteByte(c)
continue
}
w.Write([]byte{'%', awsURLEscTable[c][0], awsURLEscTable[c][1]})
}
}
// writeCanonicalQueryParams builds the canonical form of query and writes to w.
//
// Side effect: query values are sorted after this function returns.
func writeCanonicalQueryParams(w *bufio.Writer, query url.Values) {
qlen := len(query)
if qlen == 0 {
return
}
keys := make([]string, 0, qlen)
for k := range query {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
keyEscaped := strings.Replace(url.QueryEscape(k), "+", "%20", -1)
vs := query[k]
if i != 0 {
w.WriteByte('&')
}
if len(vs) == 0 {
w.WriteString(keyEscaped)
w.WriteByte('=')
continue
}
sort.Strings(vs)
for j, v := range vs {
if j != 0 {
w.WriteByte('&')
}
w.WriteString(keyEscaped)
w.WriteByte('=')
if v != "" {
w.WriteString(strings.Replace(url.QueryEscape(v), "+", "%20", -1))
}
}
}
}
// writeCanonicalString removes leading and trailing whitespaces (as defined by Unicode)
// in s, replaces consecutive spaces (' ') in s with a single space, and then
// write the result to w.
func writeCanonicalString(w *bufio.Writer, s string) {
const dblSpace = " "
s = strings.TrimSpace(s)
// bail if str doesn't contain " "
j := strings.Index(s, dblSpace)
if j < 0 {
w.WriteString(s)
return
}
w.WriteString(s[:j])
// replace all " " with " " in a performant way
var lastIsSpace bool
for i, l := j, len(s); i < l; i++ {
if s[i] == ' ' {
if !lastIsSpace {
w.WriteByte(' ')
lastIsSpace = true
}
continue
}
lastIsSpace = false
w.WriteByte(s[i])
}
}
type debugHasher struct {
buf []byte
}
func (dh *debugHasher) Write(b []byte) (int, error) {
dh.buf = append(dh.buf, b...)
return len(b), nil
}
func (dh *debugHasher) Sum(b []byte) []byte {
return nil
}
func (dh *debugHasher) Reset() {
// do nothing
}
func (dh *debugHasher) Size() int {
return 0
}
func (dh *debugHasher) BlockSize() int {
return sha256.BlockSize
}
func (dh *debugHasher) Println() {
fmt.Printf("---%s---\n", dh.buf)
}

View File

@ -0,0 +1,125 @@
package sigv4
import (
"crypto/sha256"
"strings"
"sync"
"time"
gotime "time"
)
var credScopeSuffixBytes = []byte{'a', 'w', 's', '4', '_', 'r', 'e', 'q', 'u', 'e', 's', 't'}
// deriveKey calculates the signing key. See https://docs.aws.amazon.com/general/latest/gr/create-signed-request.html.
func deriveKey(secret, service, region string, t Time) []byte {
// enc(
// enc(
// enc(
// enc(AWS4<secret>, <short_time>),
// <region>),
// <service>),
// "aws4_request")
// https://en.wikipedia.org/wiki/HMAC
// HMAC_SHA256 produces 32 bytes output
f1 := len(secret) + 4
f2 := f1 + len(t.ShortTimeFormat())
f3 := f2 + len(region)
f4 := f3 + len(service)
qs := make([]byte, 0, f4)
qs = append(qs, "AWS4"...)
qs = append(qs, secret...)
qs = append(qs, t.ShortTimeFormat()...)
qs = append(qs, region...)
qs = append(qs, service...)
buf := make([]byte, 0, sha256.BlockSize)
buf = hmacsha256(qs[:f1], qs[f1:f2], buf)
buf = hmacsha256(buf, qs[f2:f3], buf[:0])
buf = hmacsha256(buf, qs[f3:], buf[:0])
return hmacsha256(buf, credScopeSuffixBytes, buf[:0])
}
// keyDeriver returns a signing key based on parameters such as credentials.
type keyDeriver interface {
DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte
}
// signingKeyDeriver is the default implementation of keyDerivator.
type signingKeyDeriver struct {
cache derivedKeyCache
}
// newKeyDeriver creates a keyDeriver using the default implementation. The
// signing key is cached per region/service, and updated when accessKey changes
// or signingTime is not on the same day for that region/service.
func newKeyDeriver() keyDeriver {
return &signingKeyDeriver{cache: newDerivedKeyCache()}
}
// DeriveKey returns a derived signing key from the given credentials to be used
// with SigV4 signing.
func (k *signingKeyDeriver) DeriveKey(accessKey, secret, service, region string, sigtime Time) []byte {
return k.cache.Get(accessKey, secret, service, region, sigtime)
}
type derivedKeyCache struct {
mutex sync.RWMutex
values map[string]derivedKey
nowFunc func() gotime.Time
}
type derivedKey struct {
Date gotime.Time
Credential []byte
}
func newDerivedKeyCache() derivedKeyCache {
return derivedKeyCache{
values: make(map[string]derivedKey),
nowFunc: gotime.Now,
}
}
// Get returns key from cache or creates a new one.
func (s *derivedKeyCache) Get(accessKey, secret, service, region string, sigtime Time) []byte {
// <accessKey>/<YYYYMMDD>/<region>/<service>
key := strings.Join([]string{accessKey, sigtime.ShortTimeFormat(), region, service}, "/")
s.mutex.RLock()
cred, status := s.getFromCache(key)
s.mutex.RUnlock()
if status == 0 {
return cred
}
cred = deriveKey(secret, service, region, sigtime)
s.mutex.Lock()
if status == -1 {
delete(s.values, key)
}
s.values[key] = derivedKey{
Date: sigtime.Time,
Credential: cred,
}
s.mutex.Unlock()
return cred
}
// getFromCache returns s.values[key]. Second result is 1 if key was not found,
// or -1 if the cached value has expired.
func (s *derivedKeyCache) getFromCache(key string) ([]byte, int) {
v, ok := s.values[key]
if !ok {
return nil, 1
}
// evict from cache if item is a day older than system time
if s.nowFunc().Sub(v.Date) > 24*time.Hour {
return nil, -1
}
return v.Credential, 0
}

View File

@ -0,0 +1,36 @@
package sigv4
import (
gotime "time"
)
// Time wraps time.Time to cache its string format result.
type Time struct {
gotime.Time
short string
long string
}
// NewTime creates a new signingTime with the specified time.Time.
func NewTime(t gotime.Time) Time {
return Time{Time: t.UTC()}
}
// TimeFormat provides a time formatted in the X-Amz-Date format.
func (m *Time) TimeFormat() string {
return m.readOrFormat(&m.long, TimeFormat)
}
// ShortTimeFormat provides a time formatted in short time format.
func (m *Time) ShortTimeFormat() string {
return m.readOrFormat(&m.short, ShortTimeFormat)
}
func (m *Time) readOrFormat(target *string, format string) string {
if len(*target) > 0 {
return *target
}
v := m.Time.Format(format)
*target = v
return v
}

View File

@ -0,0 +1,428 @@
package sigv4
import (
"bufio"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"errors"
"hash"
"net/http"
"net/textproto"
"net/url"
"sort"
"strconv"
"strings"
)
// HTTPSigner is an AWS SigV4 signer that can sign HTTP requests.
type HTTPSigner interface {
// Sign AWS v4 requests with the provided payload hash, service name, region
// the request is made to, and time the request is signed at. Set sigtime
// to the future to create a request that cannot be used until the future time.
//
// payloadHash is the hex encoded SHA-256 hash of the request payload, and must
// not be empty, even if the request has no payload (aka body). If the request
// has no payload, use the hex encoded SHA-256 of an empty string, or the constant
// EmptyStringSHA256. You can use the utility function ContentSHA256Sum to
// calculate the hash of a http.Request body.
//
// Some services such as Amazon S3 accept alternative values for the payload
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
// protected by sigv4. See https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html.
//
// Sign differs from Presign in that it will sign the request using HTTP headers.
// The passed in request r will be modified in place: modified fields include
// r.Host and r.Header.
Sign(r *http.Request, payloadHash string, sigtime Time) error
// Presign is like Sign, but does not modify request r. It returns a copy of
// r.URL with additional query parameters that contains signing information.
// The URL can be used to recreate an authenticated request without specifying
// headers. It also returns http.Header as a second result, which must be
// included in the reconstructed request.
//
// Header hoisting: use WithHeaderHoisting option function to specify whether
// headers in request r should be added as query parameters. Some headers cannot
// be hoisted, and are returned as the second result.
//
// Presign will not set the expires time of the presigned request automatically.
// To specify the expire duration for a request, add the "X-Amz-Expires" query
// parameter on the request with the value as the duration in seconds the
// presigned URL should be considered valid for. This parameter is not used
// by all AWS services, and is most notable used by Amazon S3 APIs.
//
// expires := 20*time.Minute
// query := req.URL.Query()
// query.Set("X-Amz-Expires", strconv.FormatInt(int64(expires/time.Second), 10)
// req.URL.RawQuery = query.Encode()
Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error)
}
// HTTPSignerOption is an option parameter for HTTPSigner constructor function.
type HTTPSignerOption func(HTTPSigner) error
// ErrInvalidOption means the option parameter is incompatible with the HTTPSigner.
var ErrInvalidOption = errors.New("cannot apply option to HTTPSigner")
// httpV4Signer is the default implementation of HTTPSigner.
type httpV4Signer struct {
KeyDeriver keyDeriver
AccessKey string
Secret string
SessionToken string
Service string
Region string
HeaderHoisting bool
EscapeURLPath bool
}
// WithCredential sets HTTPSigner credential fields.
func WithCredential(accessKey, secret, sessionToken string) HTTPSignerOption {
return func(signer HTTPSigner) error {
if sigv4, ok := signer.(*httpV4Signer); ok {
sigv4.AccessKey = accessKey
sigv4.Secret = secret
sigv4.SessionToken = sessionToken
return nil
}
return ErrInvalidOption
}
}
// WithHeaderHoisting specifies whether HTTPSigner automatically hoist headers.
// Default is enabled.
func WithHeaderHoisting(enable bool) HTTPSignerOption {
return func(signer HTTPSigner) error {
if sigv4, ok := signer.(*httpV4Signer); ok {
sigv4.HeaderHoisting = enable
return nil
}
return ErrInvalidOption
}
}
// WithEscapeURLPath specifies whether HTTPSigner automatically escapes URL paths.
// Default is enabled.
func WithEscapeURLPath(enable bool) HTTPSignerOption {
return func(signer HTTPSigner) error {
if sigv4, ok := signer.(*httpV4Signer); ok {
sigv4.EscapeURLPath = enable
return nil
}
return ErrInvalidOption
}
}
// WithRegionService sets HTTPSigner region and service fields.
func WithRegionService(region, service string) HTTPSignerOption {
return func(signer HTTPSigner) error {
if sigv4, ok := signer.(*httpV4Signer); ok {
sigv4.Region = region
sigv4.Service = service
return nil
}
return ErrInvalidOption
}
}
// New creates a HTTPSigner.
func New(opts ...HTTPSignerOption) (HTTPSigner, error) {
sigv4 := &httpV4Signer{
KeyDeriver: newKeyDeriver(),
EscapeURLPath: true,
HeaderHoisting: true,
}
for _, o := range opts {
if o == nil {
continue
}
if err := o(sigv4); err != nil {
return nil, err
}
}
return sigv4, nil
}
// Sign implements HTTPSigner.
func (s *httpV4Signer) Sign(r *http.Request, payloadHash string, sigtime Time) error {
if payloadHash == "" {
var err error
payloadHash, err = ContentSHA256Sum(r)
if err != nil {
return err
}
}
// add mandatory headers to r.Header
setRequiredSigningHeaders(r.Header, sigtime, s.SessionToken)
// remove port in r.Host if any
r.Host = sanitizeHostForHeader(r)
// parse URL query only once
query := r.URL.Query()
// sigBuf is used to act as a sha256 hash buffer
sigBuf := make([]byte, 0, sha256.Size)
//hasher := &debugHasher{}
hasher := sha256.New()
reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, r.Header, query,
r.Host, payloadHash, s.EscapeURLPath, false, sigBuf)
credentialScope := strings.Join([]string{
sigtime.ShortTimeFormat(),
s.Region,
s.Service,
"aws4_request",
}, "/")
keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service,
s.Region, sigtime)
sigHasher := hmac.New(sha256.New, keyBytes)
signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf)
writeAuthorizationHeader(r.Header, s.AccessKey+"/"+credentialScope,
signedHeaderStr, signature)
// done
return nil
}
// Presign implements HTTPSigner.
func (s *httpV4Signer) Presign(r *http.Request, payloadHash string, sigtime Time) (*url.URL, http.Header, error) {
if payloadHash == "" {
var err error
payloadHash, err = ContentSHA256Sum(r)
if err != nil {
return nil, nil, err
}
}
query := r.URL.Query()
setRequiredSigningQuery(query, sigtime, s.SessionToken)
// sort each query key's values
for key := range query {
sort.Strings(query[key])
}
credentialScope := strings.Join([]string{
sigtime.ShortTimeFormat(),
s.Region,
s.Service,
"aws4_request",
}, "/")
credentialStr := s.AccessKey + "/" + credentialScope
query.Set(AmzCredentialKey, credentialStr)
var headersLeft http.Header
if s.HeaderHoisting {
headersLeft = make(http.Header, len(r.Header))
for k, v := range r.Header {
if isAllowQueryHoisting(k) {
query[k] = v
} else {
headersLeft[k] = v
}
}
}
// sigBuf is used to act as a sha256 hash buffer
sigBuf := make([]byte, 0, sha256.Size)
hasher := sha256.New()
reqhash, signedHeaderStr := canonicalRequestHash(hasher, r, headersLeft,
query, sanitizeHostForHeader(r), payloadHash, s.EscapeURLPath, true, sigBuf)
keyBytes := s.KeyDeriver.DeriveKey(s.AccessKey, s.Secret, s.Service,
s.Region, sigtime)
sigHasher := hmac.New(sha256.New, keyBytes)
signature := authorizationSignature(sigHasher, sigtime, credentialScope, reqhash, sigBuf)
query.Set(AmzSignatureKey, signature)
u := cloneURL(r.URL)
u.RawQuery = strings.Replace(query.Encode(), "+", "%20", -1)
// For the signed headers we canonicalize the header keys in the returned map.
// This avoids situations where standard library can sometimes add double
// headers. For example, the standard library will set the Host header,
// even if it is present in lower-case form.
signedHeader := strings.Split(signedHeaderStr, ";")
canonHeader := make(http.Header, len(signedHeader))
for _, k := range signedHeader {
canonKey := textproto.CanonicalMIMEHeaderKey(k)
switch k {
case "host":
canonHeader[canonKey] = []string{sanitizeHostForHeader(r)}
case "content-length":
canonHeader[canonKey] = []string{strconv.FormatInt(r.ContentLength, 10)}
default:
canonHeader[canonKey] = append(canonHeader[canonKey], headersLeft[http.CanonicalHeaderKey(k)]...)
}
}
return u, canonHeader, nil
}
// authorizationSignature returns `sig` as documented in step 4 of algorithm
// documentation. key is hSig in step 4. It calculates the result of step 3
// internally.
func authorizationSignature(hasher hash.Hash, sigtime Time, credScope, requestHash string, buf []byte) string {
w := bufio.NewWriterSize(hasher, sha256.BlockSize)
w.WriteString(SigningAlgorithm)
w.WriteByte('\n')
w.WriteString(sigtime.TimeFormat())
w.WriteByte('\n')
w.WriteString(credScope)
w.WriteByte('\n')
w.WriteString(requestHash)
w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer
//hasher.Println()
return hex.EncodeToString(hasher.Sum(buf[:0]))
}
// canonicalRequestHash returns the hex-encoded sha256 sum of the canonical
// request string. Refer to step 2 of algorithm documentation. Expect hasher to
// be sha256.New.
func canonicalRequestHash(
hasher hash.Hash, r *http.Request, headers http.Header, query url.Values,
hostname, hashcode string, escapeURL, isPresign bool, buf []byte,
) (string, string) {
w := bufio.NewWriterSize(hasher, sha256.BlockSize)
signedHeaders := make([]string, 0, len(headers)+2)
signedHeaders = append(signedHeaders, "host")
if r.ContentLength > 0 {
signedHeaders = append(signedHeaders, "content-length")
}
for k := range headers {
if strings.EqualFold(k, "content-length") || strings.EqualFold(k, "host") || isIgnoredHeader(k) {
continue
}
signedHeaders = append(signedHeaders, strings.ToLower(k))
}
sort.Strings(signedHeaders)
signedHeaderStr := strings.Join(signedHeaders, ";")
// for presigned requests, we need to add X-Amz-SignedHeaders to calculate the
// correct hash
if isPresign {
query.Set(AmzSignedHeadersKey, signedHeaderStr)
}
// <METHOD>\n<URI>\n<QUERY>\n<HEADERS>\n<SIGNED_HEADERS>\n<PAYLOAD_HASH>
// HTTP_METHOD
w.WriteString(r.Method)
w.WriteByte('\n')
// CANONICAL_URI
writeAWSURIPath(w, r.URL, false, !escapeURL)
w.WriteByte('\n')
// CANONICAL_QUERY_PARAMS
writeCanonicalQueryParams(w, query)
w.WriteByte('\n')
// CANONICAL_HEADERS
for _, head := range signedHeaders {
switch head {
case "host":
w.WriteString(head)
w.WriteByte(':')
writeCanonicalString(w, hostname)
w.WriteByte('\n')
case "content-length":
w.WriteString(head)
w.WriteByte(':')
w.WriteString(strconv.FormatInt(r.ContentLength, 10))
w.WriteByte('\n')
default:
w.WriteString(head)
w.WriteByte(':')
values := headers[http.CanonicalHeaderKey(head)]
for i, v := range values {
if i != 0 {
w.WriteByte(',')
}
writeCanonicalString(w, v)
}
w.WriteByte('\n')
}
}
w.WriteByte('\n')
// SIGNED_HEADERS
w.WriteString(signedHeaderStr)
w.WriteByte('\n')
// PAYLOAD_HASH
w.WriteString(hashcode)
w.Flush() // VERY IMPORTANT! Don't forget to flush remaining buffer
//hasher.Println()
return hex.EncodeToString(hasher.Sum(buf[:0])), signedHeaderStr
}
// writeAuthorizationHeader writes the Authorization header into header:
//
// AWS4-HMAC-SHA256 Credential=<cred>, SignedHeaders=<signed_headers>, Signature=<sig>
func writeAuthorizationHeader(headers http.Header, credentialStr, signedHeaders, signature string) {
const credentialPrefix = "Credential="
const signedHeadersPrefix = "SignedHeaders="
const signaturePrefix = "Signature="
const commaSpace = ", "
var parts strings.Builder
parts.Grow(len(SigningAlgorithm) + 1 +
len(credentialPrefix) + len(credentialStr) + 2 +
len(signedHeadersPrefix) + len(signedHeaders) + 2 +
len(signaturePrefix) + len(signature))
parts.WriteString(SigningAlgorithm)
parts.WriteRune(' ')
parts.WriteString(credentialPrefix)
parts.WriteString(credentialStr)
parts.WriteString(commaSpace)
parts.WriteString(signedHeadersPrefix)
parts.WriteString(signedHeaders)
parts.WriteString(commaSpace)
parts.WriteString(signaturePrefix)
parts.WriteString(signature)
headers[authorizationHeader] = append(headers[authorizationHeader][:0],
parts.String())
}
// helpers
// sanitizeHostForHeader is like hostOrURLHost, but without port if port is the
// default port for the scheme. For example, it removes ":80" suffix if the scheme
// is "http".
func sanitizeHostForHeader(r *http.Request) string {
host := hostOrURLHost(r)
port := parsePort(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
return stripPort(host)
}
return host
}
// setRequiredSigningHeaders modifies headers: sets X-Amz-Date to sigtime, and
// if credToken is non-empty, set X-Amz-Security-Token to credToken. This function
// overwrites existing headers values with the same key.
func setRequiredSigningHeaders(headers http.Header, sigtime Time, sessionToken string) {
amzDate := sigtime.TimeFormat()
headers[AmzDateKey] = append(headers[AmzDateKey][:0], amzDate)
if sessionToken != "" {
headers[AmzSecurityTokenKey] = append(headers[AmzSecurityTokenKey][:0],
sessionToken)
}
}
// setRequiredSigningQuery is like setRequiredSigningHeaders, but modifies
// query values. This is used for presign requests.
func setRequiredSigningQuery(query url.Values, sigtime Time, sessionToken string) {
query.Set(AmzAlgorithmKey, SigningAlgorithm)
amzDate := sigtime.TimeFormat()
query.Set(AmzDateKey, amzDate)
if sessionToken != "" {
query.Set(AmzSecurityTokenKey, sessionToken)
}
}

View File

@ -0,0 +1,26 @@
package sigv4
import (
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
)
// ContentSHA256Sum calculates the hex-encoded SHA256 checksum of r.Body. It returns
// EmptyStringSHA256 if r.Body is nil, r.Method is TRACE or r.ContentLength is zero.
// Returns non-nil error if r.Body cannot be read.
func ContentSHA256Sum(r *http.Request) (string, error) {
// We need to check r.Body is non-nil, because io.Copy(dst, nil) panics.
// This is not documented in https://pkg.go.dev/io#Copy.
if r.Method == http.MethodTrace || r.ContentLength == 0 || r.Body == nil {
return EmptyStringSHA256, nil
}
h := sha256.New()
_, err := io.Copy(h, r.Body)
if err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}

View File

@ -0,0 +1,128 @@
package bedrock
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
"github.com/aws/smithy-go"
)
type streamReader[T any] struct {
reader *bufio.Reader
response *http.Response
handlerPrefix requester.HandlerPrefix[T]
DataChan chan T
ErrChan chan error
}
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
go stream.processLines()
return stream.DataChan, stream.ErrChan
}
//nolint:gocognit
func (stream *streamReader[T]) processLines() {
decode := eventstream.NewDecoder()
payloadBuf := make([]byte, 0*1024)
for {
payloadBuf = payloadBuf[0:0]
messgae, readErr := decode.Decode(stream.reader, payloadBuf)
if readErr != nil {
stream.ErrChan <- readErr
return
}
line, err := stream.deserializeEventMessage(&messgae)
if err != nil {
stream.ErrChan <- common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
return
}
stream.handlerPrefix(&line, stream.DataChan, stream.ErrChan)
if line == nil {
continue
}
if bytes.Equal(line, requester.StreamClosed) {
return
}
}
}
func (stream *streamReader[T]) Close() {
stream.response.Body.Close()
}
func (stream *streamReader[T]) deserializeEventMessage(msg *eventstream.Message) ([]byte, error) {
messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
if messageType == nil {
return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
}
switch messageType.String() {
case eventstreamapi.EventMessageType:
var v BedrockResponseStream
if err := json.Unmarshal(msg.Payload, &v); err != nil {
return nil, err
}
buffer, err := base64.StdEncoding.DecodeString(v.Bytes)
if err != nil {
return nil, err
}
return buffer, nil
case eventstreamapi.ExceptionMessageType:
exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader)
return nil, errors.New("Exception message :" + exceptionType.String())
case eventstreamapi.ErrorMessageType:
errorCode := "UnknownError"
errorMessage := errorCode
if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
errorCode = header.String()
}
if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
errorMessage = header.String()
}
return nil, &smithy.GenericAPIError{
Code: errorCode,
Message: errorMessage,
}
default:
return nil, errors.New("bedrock stream unknown error")
}
}
func RequestStream[T any](resp *http.Response, handlerPrefix requester.HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
// 如果返回的头是json格式 说明有错误
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
return nil, requester.HandleErrorResp(resp, requestErrorHandle)
}
stream := &streamReader[T]{
reader: bufio.NewReader(resp.Body),
response: resp,
handlerPrefix: handlerPrefix,
DataChan: make(chan T),
ErrChan: make(chan error),
}
return stream, nil
}

11
providers/bedrock/type.go Normal file
View File

@ -0,0 +1,11 @@
package bedrock
const awsService = "bedrock"
type BedrockError struct {
Message string `json:"message"`
}
type BedrockResponseStream struct {
Bytes string `json:"bytes"`
}

View File

@ -8,13 +8,17 @@ import (
"one-api/common"
"one-api/common/image"
"one-api/common/requester"
"one-api/providers/base"
"one-api/types"
"strings"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
)
type claudeStreamHandler struct {
type ClaudeStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
Prefix string
}
func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
@ -31,7 +35,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
return nil, errWithCode
}
return p.convertToChatOpenai(claudeResponse, request)
return ConvertToChatOpenai(p, claudeResponse, request)
}
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
@ -47,12 +51,15 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
return nil, errWithCode
}
chatHandler := &claudeStreamHandler{
chatHandler := &ClaudeStreamHandler{
Usage: p.Usage,
Request: request,
Prefix: `data: {"type"`,
}
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
eventstream.NewDecoder()
return requester.RequestStream(p.Requester, resp, chatHandler.HandlerStream)
}
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -72,7 +79,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
headers["Accept"] = "text/event-stream"
}
claudeRequest, errWithCode := convertFromChatOpenai(request)
claudeRequest, errWithCode := ConvertFromChatOpenai(request)
if errWithCode != nil {
return nil, errWithCode
}
@ -86,7 +93,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*
return req, nil
}
func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) {
func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest, *types.OpenAIErrorWithStatusCode) {
claudeRequest := ClaudeRequest{
Model: request.Model,
Messages: []Message{},
@ -142,7 +149,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest
return &claudeRequest, nil
}
func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(&response.Error)
if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
@ -182,21 +189,24 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
openaiResponse.Usage.CompletionTokens = completionTokens
openaiResponse.Usage.TotalTokens = promptTokens + completionTokens
*p.Usage = *openaiResponse.Usage
usage := provider.GetUsage()
*usage = *openaiResponse.Usage
return openaiResponse, nil
}
// 转换为OpenAI聊天流式请求体
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), `data: {"type"`) {
if !strings.HasPrefix(string(*rawLine), h.Prefix) {
*rawLine = nil
return
}
// 去除前缀
*rawLine = (*rawLine)[6:]
if strings.HasPrefix(string(*rawLine), "data: ") {
// 去除前缀
*rawLine = (*rawLine)[6:]
}
var claudeResponse ClaudeStreamResponse
err := json.Unmarshal(*rawLine, &claudeResponse)
@ -235,7 +245,7 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
}
}
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
choice := types.ChatCompletionStreamChoice{
Index: claudeResponse.Index,
}

View File

@ -32,7 +32,7 @@ type Message struct {
}
type ClaudeRequest struct {
Model string `json:"model"`
Model string `json:"model,omitempty"`
System string `json:"system,omitempty"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`

View File

@ -13,6 +13,7 @@ import (
"one-api/providers/baichuan"
"one-api/providers/baidu"
"one-api/providers/base"
"one-api/providers/bedrock"
"one-api/providers/claude"
"one-api/providers/closeai"
"one-api/providers/deepseek"
@ -62,6 +63,7 @@ func init() {
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
}

View File

@ -101,6 +101,12 @@ export const CHANNEL_OPTIONS = {
value: 31,
color: 'primary'
},
32: {
key: 32,
text: 'Amazon Bedrock',
value: 32,
color: 'orange'
},
24: {
key: 24,
text: 'Azure Speech',

View File

@ -60,8 +60,15 @@ const typeConfig = {
},
14: {
input: {
models: ['claude-instant-1.2', 'claude-2.0', 'claude-2.1', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'],
test_model: 'claude-3-sonnet-20240229'
models: [
'claude-instant-1.2',
'claude-2.0',
'claude-2.1',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307'
],
test_model: 'claude-3-haiku-20240307'
},
modelGroup: 'Anthropic'
},
@ -202,6 +209,23 @@ const typeConfig = {
test_model: 'llama2-7b-2048'
},
modelGroup: 'Groq'
},
32: {
input: {
models: [
'claude-instant-1.2',
'claude-2.0',
'claude-2.1',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307'
],
test_model: 'claude-3-haiku-20240307'
},
prompt: {
key: '按照如下格式输入Region|AccessKeyID|SecretAccessKey|SessionToken 其中SessionToken可不填空'
},
modelGroup: 'Anthropic'
}
};