Disable watermark with FP8 quantization (#114)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Author: Karol Damaszke, 2024-03-27 13:32:20 +01:00 (committed by GitHub)
Parent: 56f00a552b
Commit: bf5263b88b

4 changed files with 21 additions and 4 deletions

View File

@@ -225,7 +225,7 @@ impl Client {
                     do_sample: true,
                     seed: 0,
                     repetition_penalty: 1.2,
-                    watermark: true,
+                    watermark: false,
                 })
             };
             requests.push(Request {

View File

@@ -207,6 +207,14 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }

+        // TODO: enable watermark with fp8 quantization
+        let quantization_enabled = env::var("QUANT_CONFIG")
+            .ok()
+            .map_or(false, |value| !value.is_empty());
+        if watermark && quantization_enabled {
+            return Err(ValidationError::WatermarkWithQuantization)
+        }
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
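
Note: the QUANT_CONFIG convention used above is worth spelling out. As a minimal sketch (not part of the commit, written in Python for brevity), the rule the new Rust code implements is: treat FP8 quantization as enabled whenever QUANT_CONFIG is set to a non-empty value, and reject any request that asks for watermarking while that is the case.

    import os

    def fp8_quantization_enabled() -> bool:
        # Same convention as the router check above: a non-empty
        # QUANT_CONFIG environment variable means FP8 quantization is on.
        return bool(os.environ.get("QUANT_CONFIG", ""))

    def validate_watermark(watermark: bool) -> None:
        # Mirrors the rejection performed via ValidationError::WatermarkWithQuantization.
        if watermark and fp8_quantization_enabled():
            raise ValueError("`watermark` = true is not allowed with FP8 quantization.")
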
@@ -450,6 +458,8 @@ pub enum ValidationError {
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
+    #[error("`watermark` = true is not allowed with FP8 quantization.")]
+    WatermarkWithQuantization,
 }

 #[cfg(test)]
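
From a client's perspective, the new error surfaces as a rejected generate request. The sketch below assumes the usual text-generation-inference HTTP API (POST /generate with a "parameters" object); the endpoint, port, and exact response shape are assumptions, not shown in this diff. With QUANT_CONFIG set on the server, a request that enables watermarking should now fail validation with the message added above.

    import requests

    resp = requests.post(
        "http://localhost:8080/generate",
        json={
            "inputs": "Hello",
            "parameters": {"watermark": True, "max_new_tokens": 16},
        },
    )
    # Expected while FP8 quantization is active: a validation error whose
    # message is "`watermark` = true is not allowed with FP8 quantization."
    print(resp.status_code, resp.text)
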

View File

@@ -390,7 +390,8 @@ class CausalLMBatch(Batch):
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
-            batches[dst_batch_idx].next_token_chooser.device
+            batches[dst_batch_idx].next_token_chooser.device,
+            hq_env.is_quantization_enabled
         )

         input_ids = batches[dst_batch_idx].input_ids
@@ -445,7 +446,9 @@ class CausalLMBatch(Batch):
                 #append the dummy parameters for dummy request
                 parameters.append(parameters[0])

-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            parameters, dtype, device, hq_env.is_quantization_enabled
+        )
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
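
Both call sites above now pass hq_env.is_quantization_enabled, which is not defined in this diff. Presumably it reflects the same QUANT_CONFIG convention as the router-side check; the stand-in below is a hypothetical sketch, not the repo's actual helper.

    import os

    class HabanaQuantEnvStub:
        """Hypothetical stand-in for the hq_env object used above."""

        @property
        def is_quantization_enabled(self) -> bool:
            # Assumed to follow the router convention: a non-empty
            # QUANT_CONFIG means FP8 quantization is enabled.
            return bool(os.environ.get("QUANT_CONFIG", ""))

    hq_env = HabanaQuantEnvStub()
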

View File

@@ -153,9 +153,11 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
+        quantization_enabled: bool,
     ):
         warpers = []

+        # TODO: enable watermark with FP8 quantization
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -164,7 +166,7 @@
                     if do_watermark
                 }
             )
-            if any(watermark)
+            if any(watermark) and not quantization_enabled
             else None
         )

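
The conditional above is the behavioral core of the change: even if some requests ask for watermarking, no watermark processor is constructed while quantization is enabled, so those flags are effectively ignored. A minimal, self-contained illustration (stand-in objects, not the repo's classes):

    def make_watermark_processor(watermark, quantization_enabled):
        # Stand-in for the HeterogeneousProcessorWrapper construction above.
        return (
            {i: f"watermark-processor-{i}" for i, w in enumerate(watermark) if w}
            if any(watermark) and not quantization_enabled
            else None
        )

    print(make_watermark_processor([True, False], quantization_enabled=False))  # {0: 'watermark-processor-0'}
    print(make_watermark_processor([True, False], quantization_enabled=True))   # None
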
@@ -252,6 +254,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        quantization_enabled: bool,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -264,6 +267,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            quantization_enabled=quantization_enabled,
         )
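
Because quantization_enabled has no default in either __init__ or from_pb, every caller now has to pass it explicitly, which is exactly what the CausalLMBatch changes above do. A call-site fragment for reference (names taken from the diff; not runnable on its own):

    next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
        parameters,                      # List[generate_pb2.NextTokenChooserParameters]
        dtype,                           # e.g. torch.bfloat16
        device,                          # e.g. torch.device("hpu")
        hq_env.is_quantization_enabled,  # new required flag
    )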