mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
more explicit
This commit is contained in:
parent
c969c8c091
commit
bf5990ee9e
@ -602,7 +602,7 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
# Decode mode
|
||||
# out is of shape [batch_size, vocab_size]
|
||||
logits = out[i].unsqueeze(0)
|
||||
logits = out[i].view(1, -1)
|
||||
|
||||
all_input_ids_tensor = batch.all_input_ids_tensor[i]
|
||||
|
||||
@ -612,7 +612,7 @@ class FlashCausalLM(Model):
|
||||
)
|
||||
|
||||
# Add to all_input_ids_tensor
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_id_squeezed = next_token_id.view(1)
|
||||
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
||||
|
||||
# Set values
|
||||
@ -630,10 +630,10 @@ class FlashCausalLM(Model):
|
||||
# Get prefill logprobs
|
||||
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
||||
prefill_logprobs = torch.gather(
|
||||
prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
|
||||
prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
|
||||
)
|
||||
# GPU <-> CPU sync
|
||||
prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
|
||||
prefill_logprobs = prefill_logprobs.view(-1).tolist()
|
||||
|
||||
# GPU <-> CPU sync
|
||||
next_token_logprobs = next_token_logprobs.tolist()
|
||||
|
Loading…
Reference in New Issue
Block a user