diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index a42e23cb..8c3f6da4 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -225,7 +225,7 @@ impl Client {
                 do_sample: true,
                 seed: 0,
                 repetition_penalty: 1.2,
-                watermark: true,
+                watermark: false,
             })
         };
         requests.push(Request {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index e3df5300..aeaf463a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -207,6 +207,14 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }
 
+        // TODO: enable watermark with FP8 quantization
+        let quantization_enabled = env::var("QUANT_CONFIG")
+            .ok()
+            .map_or(false, |value| !value.is_empty());
+        if watermark && quantization_enabled {
+            return Err(ValidationError::WatermarkWithQuantization);
+        }
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
@@ -450,6 +458,8 @@ pub enum ValidationError {
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
+    #[error("`watermark` = true is not allowed with FP8 quantization.")]
+    WatermarkWithQuantization,
 }
 
 #[cfg(test)]
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index f8ef8050..7c5b93ee 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -390,7 +390,8 @@ class CausalLMBatch(Batch):
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
-            batches[dst_batch_idx].next_token_chooser.device
+            batches[dst_batch_idx].next_token_chooser.device,
+            hq_env.is_quantization_enabled
         )
 
         input_ids = batches[dst_batch_idx].input_ids
@@ -445,7 +446,9 @@ class CausalLMBatch(Batch):
         #append the dummy parameters for dummy request
         parameters.append(parameters[0])
 
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            parameters, dtype, device, hq_env.is_quantization_enabled
+        )
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index c4775a09..1ea4d9bc 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -153,9 +153,11 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
+        quantization_enabled: bool,
     ):
         warpers = []
 
+        # TODO: enable watermark with FP8 quantization
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -164,7 +166,7 @@ class HeterogeneousNextTokenChooser:
                     if do_watermark
                 }
             )
-            if any(watermark)
+            if any(watermark) and not quantization_enabled
             else None
         )
 
@@ -252,6 +254,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        quantization_enabled: bool,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -264,6 +267,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            quantization_enabled=quantization_enabled,
         )
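
Note for reviewers: the router-side gate keys off the `QUANT_CONFIG` environment variable, which is set when FP8 quantization is configured. A minimal Python sketch of the same semantics, for illustration only (the actual check is the Rust hunk in `router/src/validation.rs` above; function names here are illustrative):

```python
import os

def quantization_enabled() -> bool:
    """True when QUANT_CONFIG is set and non-empty, matching the Rust
    `env::var("QUANT_CONFIG").ok().map_or(false, |v| !v.is_empty())` check."""
    return bool(os.environ.get("QUANT_CONFIG", ""))

def validate_watermark(watermark: bool) -> None:
    """Python analogue of rejecting with ValidationError::WatermarkWithQuantization."""
    if watermark and quantization_enabled():
        raise ValueError("`watermark` = true is not allowed with FP8 quantization.")
```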
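The server side applies the same policy as a second line of defense. A standalone sketch of the new gating in `HeterogeneousNextTokenChooser.__init__` (placeholder strings stand in for the real `WatermarkLogitsProcessor` instances; this is not the repository's code):

```python
from typing import Dict, List, Optional

def build_watermark_processors(
    watermark: List[bool], quantization_enabled: bool
) -> Optional[Dict[int, str]]:
    # Mirrors the conditional in tokens.py: processors are only built when
    # at least one request asked for watermarking AND FP8 quantization is
    # off. Strings stand in for WatermarkLogitsProcessor instances.
    if any(watermark) and not quantization_enabled:
        return {
            i: f"WatermarkLogitsProcessor(batch_index={i})"
            for i, do_watermark in enumerate(watermark)
            if do_watermark
        }
    return None

# Quantization on: watermarking is disabled server-side (the router already
# rejects such requests with WatermarkWithQuantization).
assert build_watermark_processors([True, False], quantization_enabled=True) is None
# Quantization off: one processor per request that opted in.
assert build_watermark_processors([True, False], quantization_enabled=False) == {
    0: "WatermarkLogitsProcessor(batch_index=0)"
}
```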