Disable watermark with FP8 quantization (#114)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
Karol Damaszke 2024-03-27 13:32:20 +01:00 committed by GitHub
parent 56f00a552b
commit bf5263b88b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 4 deletions

View File

@ -225,7 +225,7 @@ impl Client {
do_sample: true, do_sample: true,
seed: 0, seed: 0,
repetition_penalty: 1.2, repetition_penalty: 1.2,
watermark: true, watermark: false,
}) })
}; };
requests.push(Request { requests.push(Request {

View File

@ -207,6 +207,14 @@ impl Validation {
return Err(ValidationError::RepetitionPenalty); return Err(ValidationError::RepetitionPenalty);
} }
// TODO: enable watermark with fp8 quantization
let quantization_enabled = env::var("QUANT_CONFIG")
.ok()
.map_or(false, |value| !value.is_empty());
if watermark && quantization_enabled {
return Err(ValidationError::WatermarkWithQuantization)
}
// Different because the proto default value is not a valid value // Different because the proto default value is not a valid value
// for the user // for the user
let top_p = top_p let top_p = top_p
@ -450,6 +458,8 @@ pub enum ValidationError {
StopSequence(usize, usize), StopSequence(usize, usize),
#[error("tokenizer error {0}")] #[error("tokenizer error {0}")]
Tokenizer(String), Tokenizer(String),
#[error("`watermark` = true is not allowed with FP8 quantization.")]
WatermarkWithQuantization,
} }
#[cfg(test)] #[cfg(test)]

View File

@ -390,7 +390,8 @@ class CausalLMBatch(Batch):
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
parameters, parameters,
batches[dst_batch_idx].next_token_chooser.dtype, batches[dst_batch_idx].next_token_chooser.dtype,
batches[dst_batch_idx].next_token_chooser.device batches[dst_batch_idx].next_token_chooser.device,
hq_env.is_quantization_enabled
) )
input_ids = batches[dst_batch_idx].input_ids input_ids = batches[dst_batch_idx].input_ids
@ -445,7 +446,9 @@ class CausalLMBatch(Batch):
#append the dummy parameters for dummy request #append the dummy parameters for dummy request
parameters.append(parameters[0]) parameters.append(parameters[0])
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device) next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
parameters, dtype, device, hq_env.is_quantization_enabled
)
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
[r.data.inputs for r in requests] + dummy_inputs, [r.data.inputs for r in requests] + dummy_inputs,
return_tensors="pt", return_tensors="pt",

View File

@ -153,9 +153,11 @@ class HeterogeneousNextTokenChooser:
typical_p: List[float], typical_p: List[float],
do_sample: List[bool], do_sample: List[bool],
seeds: List[int], seeds: List[int],
quantization_enabled: bool,
): ):
warpers = [] warpers = []
# TODO: enable watermark with FP8 quantization
self.watermark_processor = ( self.watermark_processor = (
HeterogeneousProcessorWrapper( HeterogeneousProcessorWrapper(
{ {
@ -164,7 +166,7 @@ class HeterogeneousNextTokenChooser:
if do_watermark if do_watermark
} }
) )
if any(watermark) if any(watermark) and not quantization_enabled
else None else None
) )
@ -252,6 +254,7 @@ class HeterogeneousNextTokenChooser:
pb: List[generate_pb2.NextTokenChooserParameters], pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
quantization_enabled: bool,
) -> "HeterogeneousNextTokenChooser": ) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser( return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb], watermark=[pb_.watermark for pb_ in pb],
@ -264,6 +267,7 @@ class HeterogeneousNextTokenChooser:
seeds=[pb_.seed for pb_ in pb], seeds=[pb_.seed for pb_ in pb],
device=device, device=device,
dtype=dtype, dtype=dtype,
quantization_enabled=quantization_enabled,
) )