diff --git a/common/constants.go b/common/constants.go index e4466a57..cdcec59b 100644 --- a/common/constants.go +++ b/common/constants.go @@ -11,4 +11,5 @@ var ( CtxKeyRequestModel string = "request_model" CtxKeyRawRequest string = "raw_request" CtxKeyConvertedRequest string = "converted_request" + CtxKeyOriginModel string = "origin_model" ) diff --git a/controller/relay.go b/controller/relay.go index 3c2d4340..035e55ad 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -54,7 +54,7 @@ func Relay(c *gin.Context) { lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") - originalModel := c.GetString("original_model") + originalModel := c.GetString(common.CtxKeyOriginModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes diff --git a/go.sum b/go.sum index 27a205b7..038ea249 100644 --- a/go.sum +++ b/go.sum @@ -129,7 +129,6 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -229,14 +228,9 @@ gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= -gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= -gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= -gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= -gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= -gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8= gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/middleware/distributor.go b/middleware/distributor.go index f6c106c4..f55e3947 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -64,7 +64,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("model_mapping", channel.GetModelMapping()) - c.Set("original_model", modelName) // for retry + c.Set(common.CtxKeyOriginModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // this is for backward compatibility diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 0c0643ed..a110e7c4 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -81,7 +81,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return wrapErr(errors.Wrap(err, "newAwsClient")), nil } - awsModelId, err := awsModelID(channel.Models) + awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil } @@ -148,7 +148,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt return wrapErr(errors.Wrap(err, "newAwsClient")), nil } - awsModelId, err := awsModelID(channel.Models) + awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil } @@ -211,7 +211,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt return true } response.Id = id - response.Model = c.GetString("original_model") + response.Model = c.GetString(common.CtxKeyOriginModel) response.Created = createdTime jsonStr, err := json.Marshal(response) if err != nil {