Add warmup for logits processors (#107)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke, 2024-03-18 15:17:47 +01:00 (committed by GitHub)
parent 8504f9c41c
commit b45f648483
2 changed files with 9 additions and 9 deletions

@@ -165,14 +165,14 @@ impl Client {
             inputs: self.get_random_input(input_length, seq_bucket_size),
             truncate: max_input_length,
             parameters: Some(NextTokenChooserParameters {
-                temperature: 1.0,
-                top_k: 0,
-                top_p: 1.0,
-                typical_p: 1.0,
-                do_sample: false,
+                temperature: 0.9,
+                top_k: 10,
+                top_p: 0.9,
+                typical_p: 0.9,
+                do_sample: true,
                 seed: 0,
-                repetition_penalty: 1.0,
-                watermark: false,
+                repetition_penalty: 1.2,
+                watermark: true,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),

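Why the parameter change matters: the old warmup issued requests with neutral values (temperature 1.0, top_k 0, top_p 1.0, typical_p 1.0, do_sample false, repetition_penalty 1.0, watermark false), which short-circuit every logits processor, so warmup only ever exercised the greedy decoding path. The new non-neutral values route the warmup requests through the sampling and penalty processors as well, so those code paths are compiled before real traffic arrives. As a rough illustration (a minimal sketch in plain PyTorch, not the server's actual implementation; warp_logits is a made-up helper), this is the kind of logits warping those parameters switch on:

    # Illustrative only: a rough stand-in for the logits processors that the
    # non-neutral warmup parameters now pull into the compiled graph. With
    # neutral values (temperature=1.0, top_k=0, top_p=1.0, do_sample=False)
    # every step below is skipped, so warmup never touched these code paths.
    import torch

    def warp_logits(logits: torch.Tensor, temperature: float = 0.9,
                    top_k: int = 10, top_p: float = 0.9) -> torch.Tensor:
        # Temperature < 1.0 sharpens the distribution.
        logits = logits / temperature
        # Top-k: mask everything below the k-th largest logit.
        if top_k > 0:
            kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        # Top-p (nucleus): mask the low-probability tail of the sorted CDF.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            sorted_mask = cum_probs > top_p
            sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
            sorted_mask[..., 0] = False  # always keep the most likely token
            mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
            logits = logits.masked_fill(mask, float("-inf"))
        return logits

    # With do_sample=true the server then draws from the warped distribution,
    # e.g.: torch.multinomial(torch.softmax(warp_logits(logits), -1), 1)
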
@@ -802,11 +802,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
+                batch.input_ids, logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
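
On the Python side, the [:, :token_idx_scalar] slice is dropped when handing input_ids to the next-token chooser. A plausible reading of the diff (an inference, not stated in the commit message): processors such as repetition penalty index into input_ids, and a slice whose length changes every decode step yields dynamic tensor shapes, which would force recompilation and make the graphs compiled during warmup useless; the full padded batch.input_ids keeps a static shape. A hedged sketch of the kind of processor involved (apply_repetition_penalty is a hypothetical stand-in modeled on the usual HF-style formulation):

    import torch

    def apply_repetition_penalty(logits: torch.Tensor,
                                 input_ids: torch.Tensor,
                                 penalty: float = 1.2) -> torch.Tensor:
        # Gather the current scores of every token id in input_ids. Because
        # input_ids keeps its full padded shape, the gather/scatter below have
        # static shapes and can live inside a pre-compiled HPU graph.
        score = torch.gather(logits, 1, input_ids)
        # Standard repetition penalty: shrink positive scores, grow negative.
        score = torch.where(score < 0, score * penalty, score / penalty)
        # (A real implementation would also avoid penalizing pad tokens.)
        return logits.scatter(1, input_ids, score)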