more explicit

OlivierDehaene 2023-05-05 17:37:17 +02:00
parent c969c8c091
commit bf5990ee9e


@@ -602,7 +602,7 @@ class FlashCausalLM(Model):
             else:
                 # Decode mode
                 # out is of shape [batch_size, vocab_size]
-                logits = out[i].unsqueeze(0)
+                logits = out[i].view(1, -1)

             all_input_ids_tensor = batch.all_input_ids_tensor[i]
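
A note on this first hunk: both calls produce the same [1, vocab_size] tensor when out[i] is a 1-D row, but view(1, -1) pins the result to rank 2 and raises if the data cannot be laid out that way, while unsqueeze(0) happily prepends a dimension to a tensor of any rank. A minimal standalone sketch with made-up shapes, not taken from the repository:

import torch

vocab_size = 8
out = torch.randn(4, vocab_size)  # stand-in for the decode-mode logits

row = out[2]                                      # shape [vocab_size]
assert row.unsqueeze(0).shape == (1, vocab_size)
assert row.view(1, -1).shape == (1, vocab_size)
# The two results are identical here; view merely states the target rank.
assert torch.equal(row.unsqueeze(0), row.view(1, -1))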
@@ -612,7 +612,7 @@ class FlashCausalLM(Model):
             )

             # Add to all_input_ids_tensor
-            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_id_squeezed = next_token_id.view(1)
             all_input_ids_tensor[input_length] = next_token_id_squeezed

             # Set values
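
The second hunk swaps squeeze() for view(1), which acts as a shape assertion: it returns a 1-D, one-element tensor and raises unless numel() == 1, whereas squeeze() would yield a 0-d scalar from a [1, 1] input and silently return a multi-element tensor unchanged. A hedged sketch with assumed shapes:

import torch

next_token_id = torch.tensor([[42]])        # assumed shape [1, 1]
assert next_token_id.squeeze().shape == ()  # 0-d scalar
assert next_token_id.view(1).shape == (1,)  # always 1-D with one element

# view(1) fails loudly on anything that is not a single element:
try:
    torch.tensor([1, 2]).view(1)
except RuntimeError:
    pass  # shape '[1]' is invalid for input of size 2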
@@ -630,10 +630,10 @@ class FlashCausalLM(Model):
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(
-                prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
+                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
             )
             # GPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
+            prefill_logprobs = prefill_logprobs.view(-1).tolist()

         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
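
In the prefill hunk, torch.gather requires the index tensor to have the same rank as its input, so the flat prefill_tokens_indices of shape [n] must become [n, 1]; view(-1, 1) states that target shape directly, and view(-1) flattens the gathered column back to [n] before the single .tolist() call pays the GPU <-> CPU sync. A self-contained sketch where n and vocab_size are made up:

import torch

n, vocab_size = 5, 8
out = torch.randn(n, vocab_size)                  # stand-in prefill logits
prefill_tokens_indices = torch.randint(vocab_size, (n,))

prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
    prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
)
assert prefill_logprobs.shape == (n, 1)

# Flatten to [n] and pay the GPU <-> CPU sync once.
assert len(prefill_logprobs.view(-1).tolist()) == n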