Add warmup for logits processors (#107)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke committed 2024-03-18 15:17:47 +01:00 (via GitHub)
parent 8504f9c41c
commit b45f648483
2 changed files with 9 additions and 9 deletions


@@ -165,14 +165,14 @@ impl Client {
                 inputs: self.get_random_input(input_length, seq_bucket_size),
                 truncate: max_input_length,
                 parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
+                    temperature: 0.9,
+                    top_k: 10,
+                    top_p: 0.9,
+                    typical_p: 0.9,
+                    do_sample: true,
                     seed: 0,
-                    repetition_penalty: 1.0,
-                    watermark: false,
+                    repetition_penalty: 1.2,
+                    watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
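
Note: with these values the warmup request now runs with sampling enabled (do_sample: true) and non-neutral settings for every warper, so the logits processors for temperature, top-k, top-p, typical-p, repetition penalty, and watermarking are all exercised, and their HPU graphs compiled, before real traffic arrives; previously do_sample was false, so none of these paths ran during warmup and the first sampled request paid the compilation cost. As a rough illustration only, here is a minimal PyTorch sketch of the standard warping steps these parameters drive; it is not TGI's actual LogitsProcessor stack (typical-p and watermark are omitted), and the function name, defaults, and shapes below are made up:

import torch

def warp_logits(logits, input_ids, temperature=0.9, top_k=10,
                top_p=0.9, repetition_penalty=1.2):
    # Repetition penalty: down-weight every token already present in input_ids.
    scores = torch.gather(logits, 1, input_ids)
    scores = torch.where(scores < 0, scores * repetition_penalty,
                         scores / repetition_penalty)
    logits = logits.scatter(1, input_ids, scores)

    # Temperature: rescale logits (values below 1.0 sharpen the distribution).
    logits = logits / temperature

    # Top-k: mask everything scoring below the k-th best logit.
    kth_best = torch.topk(logits, top_k).values[..., -1, None]
    logits = logits.masked_fill(logits < kth_best, float("-inf"))

    # Top-p (nucleus): mask tokens outside the smallest set whose cumulative
    # probability exceeds top_p, shifting the mask so the best token survives.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_mask = cum_probs > top_p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = sorted_mask.scatter(1, sorted_idx, sorted_mask)
    return logits.masked_fill(mask, float("-inf"))

# do_sample: draw the next token from the warped distribution instead of argmax.
vocab_size, batch_size = 32000, 2  # hypothetical sizes
logits = torch.randn(batch_size, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, 16))
probs = torch.softmax(warp_logits(logits, input_ids), dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1)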


@@ -802,11 +802,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
+                batch.input_ids, logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
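
Note: dropping the [:, :token_idx_scalar] slice means next_token_chooser now sees the full padded batch.input_ids, whose shape is constant across decode steps, whereas the sliced view grows by one column per generated token. On Gaudi, where a new tensor shape generally triggers a fresh graph compilation, a constant shape is plausibly what lets the graphs compiled by the warmup above be reused. A shape-only sketch, with hypothetical batch size, padded length, and step values:

import torch

# Hypothetical padded decode batch: 2 sequences, padded to 128 positions.
input_ids = torch.zeros(2, 128, dtype=torch.long)

for token_idx_scalar in (16, 17, 18):
    sliced = input_ids[:, :token_idx_scalar]
    print(tuple(sliced.shape))   # (2, 16), (2, 17), (2, 18): a new shape every step
print(tuple(input_ids.shape))    # (2, 128): one static shape for the whole loop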