Add warmup for logits processors (#107)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent 8504f9c41c
commit b45f648483
@@ -165,14 +165,14 @@ impl Client {
                 inputs: self.get_random_input(input_length, seq_bucket_size),
                 truncate: max_input_length,
                 parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
+                    temperature: 0.9,
+                    top_k: 10,
+                    top_p: 0.9,
+                    typical_p: 0.9,
+                    do_sample: true,
                     seed: 0,
-                    repetition_penalty: 1.0,
-                    watermark: false,
+                    repetition_penalty: 1.2,
+                    watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
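The warmup request above moves from neutral defaults to values that actually activate each logits processor (temperature, top-k, top-p, typical-p, repetition penalty, watermarking) and turns sampling on, so those code paths are exercised during warmup rather than first appearing on a live request. As a rough illustration only, the parameters correspond to processors like the Hugging Face transformers warpers sketched below; the tensor sizes are made up, the watermark processor is omitted, and the exact TGI server wiring differs:

import torch
from transformers import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)

# Dummy warmup batch: 2 sequences of 16 tokens over a 50-token vocabulary.
input_ids = torch.randint(0, 50, (2, 16))
scores = torch.randn(2, 50)

# Processors corresponding to the warmup parameters set in the hunk above.
processors = LogitsProcessorList([
    RepetitionPenaltyLogitsProcessor(penalty=1.2),  # repetition_penalty: 1.2
    TemperatureLogitsWarper(temperature=0.9),       # temperature: 0.9
    TopKLogitsWarper(top_k=10),                     # top_k: 10
    TopPLogitsWarper(top_p=0.9),                    # top_p: 0.9
    TypicalLogitsWarper(mass=0.9),                  # typical_p: 0.9
])
warped = processors(input_ids, scores)

# do_sample: true -> multinomial sampling instead of greedy argmax.
next_token_ids = torch.multinomial(torch.softmax(warped, dim=-1), num_samples=1)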
@@ -802,11 +802,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
+                batch.input_ids, logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
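The Python hunk drops the [:, :token_idx_scalar] slice and passes the full batch.input_ids to the next token chooser, while the logits are still reduced to the last real position in the Gaudi-optimized branch. A minimal sketch of that shape handling follows; the tensor sizes are made up, and the real code operates on batch.input_ids and the model's actual logits:

import torch

# Hypothetical sizes: batch of 2, 8 positions, vocabulary of 50 tokens.
logits = torch.randn(2, 8, 50)
input_length = 8

# Gaudi-optimized branch: the model returns logits for every position,
# so the last real position is selected before choosing the next token.
last_logits = logits[:, input_length - 1: input_length, :].squeeze(-2)  # shape (2, 50)

# Non-optimized branch: logits already covers a single position, (2, 1, 50).
single_logits = torch.randn(2, 1, 50).squeeze(-2)  # shape (2, 50)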