Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 20:12:07 +00:00)
Disable watermark with FP8 quantization (#114)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Commit: bf5263b88b (parent: 56f00a552b)
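Summary: with FP8 quantization active (signalled by a non-empty QUANT_CONFIG environment variable), watermarking is unsupported. This change (1) makes the router reject requests that set `watermark = true` while quantization is enabled, and (2) threads a `quantization_enabled` flag through `HeterogeneousNextTokenChooser` on the Python server so the watermark logits processor is never built in that case.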
@@ -225,7 +225,7 @@ impl Client {
                 do_sample: true,
                 seed: 0,
                 repetition_penalty: 1.2,
-                watermark: true,
+                watermark: false,
             })
         };
         requests.push(Request {
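With the new validation rule in place, the hard-coded request built in the Rust client (`impl Client`, above) switches its `watermark` default to `false` so it remains valid when FP8 quantization is enabled.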
@@ -207,6 +207,14 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }
 
+        // TODO: enable watermark with fp8 quantization
+        let quantization_enabled = env::var("QUANT_CONFIG")
+            .ok()
+            .map_or(false, |value| !value.is_empty());
+        if watermark && quantization_enabled {
+            return Err(ValidationError::WatermarkWithQuantization)
+        }
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
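The router treats quantization as enabled whenever QUANT_CONFIG is set to a non-empty string. A minimal Python sketch of the same truthiness check (the function name `is_quantization_enabled` is illustrative, not from the patch):

    import os

    def is_quantization_enabled() -> bool:
        # Mirrors the Rust side:
        #   env::var("QUANT_CONFIG").ok().map_or(false, |value| !value.is_empty())
        # i.e. true only when QUANT_CONFIG is set AND non-empty.
        value = os.environ.get("QUANT_CONFIG")
        return value is not None and value != ""

    os.environ["QUANT_CONFIG"] = "./quant_config.json"  # hypothetical value
    assert is_quantization_enabled()
    os.environ["QUANT_CONFIG"] = ""
    assert not is_quantization_enabled()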
@@ -450,6 +458,8 @@ pub enum ValidationError {
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
+    #[error("`watermark` = true is not allowed with FP8 quantization.")]
+    WatermarkWithQuantization,
 }
 
 #[cfg(test)]
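The new `WatermarkWithQuantization` variant follows the enum's existing thiserror-style `#[error(...)]` attributes, so the quoted message is exactly what a rejected request reports back to the client.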
@@ -390,7 +390,8 @@ class CausalLMBatch(Batch):
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
-            batches[dst_batch_idx].next_token_chooser.device
+            batches[dst_batch_idx].next_token_chooser.device,
+            hq_env.is_quantization_enabled
         )
 
         input_ids = batches[dst_batch_idx].input_ids
@@ -445,7 +446,9 @@ class CausalLMBatch(Batch):
         #append the dummy parameters for dummy request
         parameters.append(parameters[0])
 
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            parameters, dtype, device, hq_env.is_quantization_enabled
+        )
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
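The two `CausalLMBatch` hunks above cover both call sites of `HeterogeneousNextTokenChooser.from_pb`: one when merging existing batches, one when building a batch padded with dummy requests. Each now forwards `hq_env.is_quantization_enabled` so the chooser knows whether FP8 quantization is active.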
@@ -153,9 +153,11 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
+        quantization_enabled: bool,
     ):
         warpers = []
 
+        # TODO: enable watermark with FP8 quantization
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -164,7 +166,7 @@ class HeterogeneousNextTokenChooser:
                     if do_watermark
                 }
             )
-            if any(watermark)
+            if any(watermark) and not quantization_enabled
             else None
         )
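The constructor change means the watermark logits processor is only instantiated when at least one request asks for watermarking and quantization is off. A self-contained sketch of that conditional-expression pattern (`DummyWatermarkProcessor` and `build_watermark_processor` are stand-ins for the real `HeterogeneousProcessorWrapper` construction, for illustration only):

    from typing import List, Optional

    class DummyWatermarkProcessor:
        # Stand-in for HeterogeneousProcessorWrapper; illustration only.
        def __init__(self, indices: List[int]):
            self.indices = indices  # batch positions that requested watermarking

    def build_watermark_processor(
        watermark: List[bool], quantization_enabled: bool
    ) -> Optional[DummyWatermarkProcessor]:
        # Same shape as the patched constructor: build a processor only for
        # the requests that asked for watermarking, and only when FP8
        # quantization is disabled.
        return (
            DummyWatermarkProcessor(
                [i for i, do_watermark in enumerate(watermark) if do_watermark]
            )
            if any(watermark) and not quantization_enabled
            else None
        )

    assert build_watermark_processor([True, False], quantization_enabled=True) is None
    assert build_watermark_processor([True, False], quantization_enabled=False).indices == [0]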
@@ -252,6 +254,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        quantization_enabled: bool,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -264,6 +267,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            quantization_enabled=quantization_enabled,
         )
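For the plumbing itself, a toy version of the `from_pb` pattern: protobuf-like parameter objects in, chooser out, with the new flag threaded through (`ToyParams` and `ToyChooser` are hypothetical stand-ins, not the real classes):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class ToyParams:
        # Hypothetical stand-in for generate_pb2.NextTokenChooserParameters.
        watermark: bool
        seed: int

    class ToyChooser:
        def __init__(self, watermark: List[bool], seeds: List[int], quantization_enabled: bool):
            # Watermarking is disabled for the whole batch under quantization.
            self.watermark_active = any(watermark) and not quantization_enabled
            self.seeds = seeds

        @classmethod
        def from_pb(cls, pb: List[ToyParams], quantization_enabled: bool) -> "ToyChooser":
            return cls(
                watermark=[p.watermark for p in pb],
                seeds=[p.seed for p in pb],
                quantization_enabled=quantization_enabled,
            )

    chooser = ToyChooser.from_pb([ToyParams(watermark=True, seed=0)], quantization_enabled=True)
    assert chooser.watermark_active is False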