From 4ba562c63d84072a48bb08410bd3fa22b2e38e0a Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Fri, 19 Apr 2024 02:08:27 +0000 Subject: [PATCH] fix: Refactor AWS Adapter for Improved Model Mapping and Error Handling * Refactor AWS adapter to improve model management - Replace hardcoded model list in `adapter.go` with a function to get models from `awsModelIDMap` - Update `GetModelList` function to return model list directly - Add `GetChannelName` function to get channel name from `Adaptor` object * Improve error handling and code organization in main.go - Replace switch statement with a map to map AWS model IDs to OpenAI model IDs - Return an error if the model is not found in the map - Use a single return statement instead of wrapping multiple return statements in the `awsModelID` function - Add a new error message for when the model is not found in the map in the `Handler` function --- relay/adaptor/aws/adapter.go | 13 +++++-------- relay/adaptor/aws/main.go | 28 +++++++++++++--------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go index e0dd545d..fec8df1d 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adapter.go @@ -62,15 +62,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met return } -// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -var ModelList = []string{ - "claude-3-haiku-20240307", - "claude-3-sonnet-20240229", - "claude-3-opus-20240229", -} +func (a *Adaptor) GetModelList() (models []string) { + for n := range awsModelIDMap { + models = append(models, n) + } -func (a *Adaptor) GetModelList() []string { - return ModelList + return } func (a *Adaptor) GetChannelName() string { diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 6cd21f7f..3e3b7804 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -49,23 +49,21 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode { } // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var awsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", +} + func awsModelID(requestModel string) (string, error) { - switch requestModel { - case "claude-instant-1.2": - return "anthropic.claude-instant-v1", nil - case "claude-2.0": - return "anthropic.claude-v2", nil - case "claude-2.1": - return "anthropic.claude-v2:1", nil - case "claude-3-sonnet-20240229": - return "anthropic.claude-3-sonnet-20240229-v1:0", nil - case "claude-3-opus-20240229": - return "anthropic.claude-3-opus-20240229-v1:0", nil - case "claude-3-haiku-20240307": - return "anthropic.claude-3-haiku-20240307-v1:0", nil - default: - return "", errors.Errorf("unknown model: %s", requestModel) + if awsModelID, ok := awsModelIDMap[requestModel]; ok { + return awsModelID, nil } + + return "", errors.Errorf("model %s not found", requestModel) } func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {