diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go index e0dd545d..fec8df1d 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adapter.go @@ -62,15 +62,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met return } -// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -var ModelList = []string{ - "claude-3-haiku-20240307", - "claude-3-sonnet-20240229", - "claude-3-opus-20240229", -} +func (a *Adaptor) GetModelList() (models []string) { + for n := range awsModelIDMap { + models = append(models, n) + } -func (a *Adaptor) GetModelList() []string { - return ModelList + return } func (a *Adaptor) GetChannelName() string { diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 6cd21f7f..3e3b7804 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -49,23 +49,21 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode { } // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var awsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", +} + func awsModelID(requestModel string) (string, error) { - switch requestModel { - case "claude-instant-1.2": - return "anthropic.claude-instant-v1", nil - case "claude-2.0": - return "anthropic.claude-v2", nil - case "claude-2.1": - return "anthropic.claude-v2:1", nil - case "claude-3-sonnet-20240229": - return "anthropic.claude-3-sonnet-20240229-v1:0", nil - case "claude-3-opus-20240229": - return "anthropic.claude-3-opus-20240229-v1:0", nil - case "claude-3-haiku-20240307": - return "anthropic.claude-3-haiku-20240307-v1:0", nil - default: - return "", errors.Errorf("unknown model: %s", requestModel) + if awsModelID, ok := awsModelIDMap[requestModel]; ok { + return awsModelID, nil } + + return "", errors.Errorf("model %s not found", requestModel) } func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {