diff --git a/common/constants.go b/common/constants.go index cdcec59b..87221b61 100644 --- a/common/constants.go +++ b/common/constants.go @@ -4,12 +4,3 @@ import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change - -var ( - // CtxKeyChannel is the key to store the channel in the context - CtxKeyChannel string = "channel_docu" - CtxKeyRequestModel string = "request_model" - CtxKeyRawRequest string = "raw_request" - CtxKeyConvertedRequest string = "converted_request" - CtxKeyOriginModel string = "origin_model" -) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go new file mode 100644 index 00000000..f7b702f6 --- /dev/null +++ b/common/ctxkey/key.go @@ -0,0 +1,8 @@ +package ctxkey + +var ( + Channel = "channel" + RequestModel = "request_model" + ConvertedRequest = "converted_request" + OriginModel = "origin_model" +) diff --git a/controller/relay.go b/controller/relay.go index 035e55ad..a5f77eed 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" @@ -54,7 +55,7 @@ func Relay(c *gin.Context) { lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") - originalModel := c.GetString(common.CtxKeyOriginModel) + originalModel := c.GetString(ctxkey.OriginModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes diff --git a/middleware/distributor.go b/middleware/distributor.go index f55e3947..f3a05f5d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -3,8 +3,8 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" @@ -59,12 +59,12 @@ func Distribute() func(c *gin.Context) { } func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { - c.Set(common.CtxKeyChannel, channel) + c.Set(ctxkey.Channel, channel) c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("model_mapping", channel.GetModelMapping()) - c.Set(common.CtxKeyOriginModel, modelName) // for retry + c.Set(ctxkey.OriginModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // this is for backward compatibility diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go index fec8df1d..7f064efe 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adapter.go @@ -1,12 +1,12 @@ package aws import ( + "github.com/songquanpeng/one-api/common/ctxkey" "io" "net/http" "github.com/gin-gonic/gin" "github.com/pkg/errors" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/meta" @@ -36,9 +36,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } claudeReq := anthropic.ConvertRequest(*request) - c.Set(common.CtxKeyRequestModel, request.Model) - c.Set(common.CtxKeyRawRequest, request) - c.Set(common.CtxKeyConvertedRequest, claudeReq) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, claudeReq) return claudeReq, nil } diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index e4acfa60..8d8345ef 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/ctxkey" "io" "net/http" "strings" @@ -68,10 +69,10 @@ func awsModelID(requestModel string) (string, error) { func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { var channel *model.Channel - if channeli, ok := c.Get(common.CtxKeyChannel); !ok { + if channel_, ok := c.Get(ctxkey.Channel); !ok { return wrapErr(errors.New("channel not found")), nil } else { - channel = channeli.(*model.Channel) + channel = channel_.(*model.Channel) } awsCli, err := newAwsClient(channel) @@ -79,7 +80,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return wrapErr(errors.Wrap(err, "newAwsClient")), nil } - awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel)) + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil } @@ -90,11 +91,11 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st ContentType: aws.String("application/json"), } - claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest) + claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) if !ok { return wrapErr(errors.New("request not found")), nil } - claudeReq := claudeReqi.(*anthropic.Request) + claudeReq := claudeReq_.(*anthropic.Request) awsClaudeReq := &Request{ AnthropicVersion: "bedrock-2023-05-31", } @@ -135,10 +136,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt createdTime := helper.GetTimestamp() var channel *model.Channel - if channeli, ok := c.Get(common.CtxKeyChannel); !ok { + if channel_, ok := c.Get(ctxkey.Channel); !ok { return wrapErr(errors.New("channel not found")), nil } else { - channel = channeli.(*model.Channel) + channel = channel_.(*model.Channel) } awsCli, err := newAwsClient(channel) @@ -146,7 +147,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt return wrapErr(errors.Wrap(err, "newAwsClient")), nil } - awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel)) + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil } @@ -157,11 +158,11 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt ContentType: aws.String("application/json"), } - claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest) + claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) if !ok { return wrapErr(errors.New("request not found")), nil } - claudeReq := claudeReqi.(*anthropic.Request) + claudeReq := claudeReq_.(*anthropic.Request) awsClaudeReq := &Request{ AnthropicVersion: "bedrock-2023-05-31", @@ -211,7 +212,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt return true } response.Id = id - response.Model = c.GetString(common.CtxKeyOriginModel) + response.Model = c.GetString(ctxkey.OriginModel) response.Created = createdTime jsonStr, err := json.Marshal(response) if err != nil {