diff --git a/router/src/sagemaker.rs b/router/src/sagemaker.rs index 750ef222b..520831c88 100644 --- a/router/src/sagemaker.rs +++ b/router/src/sagemaker.rs @@ -67,16 +67,26 @@ pub(crate) async fn sagemaker_compatibility( default_return_full_text: Extension, infer: Extension, compute_type: Extension, + context: Extension>, info: Extension, Json(req): Json, ) -> Result)> { match req { SagemakerRequest::Generate(req) => { - compat_generate(default_return_full_text, infer, compute_type, Json(req)).await + compat_generate( + default_return_full_text, + infer, + compute_type, + context, + Json(req), + ) + .await + } + SagemakerRequest::Chat(req) => { + chat_completions(infer, compute_type, info, context, Json(req)).await } - SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await, SagemakerRequest::Completion(req) => { - completions(infer, compute_type, info, Json(req)).await + completions(infer, compute_type, info, context, Json(req)).await } } }