diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs index 1f94fd3a..5e7a2098 100644 --- a/backends/trtllm/src/backend.rs +++ b/backends/trtllm/src/backend.rs @@ -20,7 +20,9 @@ use tracing::{instrument, Level, span}; use text_generation_router::{FinishReason, Token}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; -use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest}; +use text_generation_router::validation::{ + Chunk, ValidationError, ValidGenerateRequest, ValidParameters, +}; use text_generation_router::validation::ValidationError::UnsupportedModality; use crate::errors::TensorRtLlmBackendError; @@ -121,6 +123,11 @@ impl TensorRtLlmBackend { )); } + // TODO: Is it really needed? How can it be validated before? + if request.parameters.grammar.is_some() { + return Err(InferError::ValidationError(ValidationError::Grammar)); + } + match request.inputs.len() { 0 => Err(InferError::ValidationError(ValidationError::EmptyInput)), 2.. => Err(InferError::GenerationError( @@ -308,6 +315,8 @@ impl Backend for TensorRtLlmBackend { params.top_k, params.top_p, params.temperature, + params.repetition_penalty, + params.frequency_penalty, params.seed, );