more explicit

OlivierDehaene 2023-05-05 17:37:17 +02:00
parent c969c8c091
commit bf5990ee9e


@@ -602,7 +602,7 @@ class FlashCausalLM(Model):
         else:
             # Decode mode
             # out is of shape [batch_size, vocab_size]
-            logits = out[i].unsqueeze(0)
+            logits = out[i].view(1, -1)
         all_input_ids_tensor = batch.all_input_ids_tensor[i]
@@ -612,7 +612,7 @@ class FlashCausalLM(Model):
         )
         # Add to all_input_ids_tensor
-        next_token_id_squeezed = next_token_id.squeeze()
+        next_token_id_squeezed = next_token_id.view(1)
         all_input_ids_tensor[input_length] = next_token_id_squeezed
         # Set values
@@ -630,10 +630,10 @@ class FlashCausalLM(Model):
         # Get prefill logprobs
         prefill_logprobs_tensor = torch.log_softmax(out, -1)
         prefill_logprobs = torch.gather(
-            prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
+            prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
         )
         # GPU <-> CPU sync
-        prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
+        prefill_logprobs = prefill_logprobs.view(-1).tolist()
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
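
The diff swaps implicit squeeze/unsqueeze calls for explicit view reshapes. A minimal standalone sketch of the shape behavior, using hypothetical tensor sizes that are not taken from the commit:

import torch

# Hypothetical shapes chosen only to illustrate the reshapes in the diff.
vocab_size = 8
out_row = torch.randn(vocab_size)          # one row of `out`, shape [vocab_size]

# Decode mode: both calls produce the same [1, vocab_size] view of the data.
assert torch.equal(out_row.unsqueeze(0), out_row.view(1, -1))

# Sampled token id: `.view(1)` pins the result to shape [1], whereas
# `.squeeze()` on a [1, 1] tensor collapses it to a 0-d scalar tensor.
next_token_id = torch.tensor([[3]])
print(next_token_id.squeeze().shape)   # torch.Size([])
print(next_token_id.view(1).shape)     # torch.Size([1])

# Prefill logprobs: gather along dim 1 needs an index of shape [n, 1];
# `.view(-1, 1)` states that explicitly instead of relying on unsqueeze(1).
prefill_tokens_indices = torch.randint(vocab_size, (5,))
logprobs = torch.log_softmax(torch.randn(5, vocab_size), -1)
gathered = torch.gather(logprobs, 1, prefill_tokens_indices.view(-1, 1))
flat = gathered.view(-1).tolist()      # same values as .squeeze(1).tolist()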