mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
more explicit
This commit is contained in:
parent
c969c8c091
commit
bf5990ee9e
@ -602,7 +602,7 @@ class FlashCausalLM(Model):
|
|||||||
else:
|
else:
|
||||||
# Decode mode
|
# Decode mode
|
||||||
# out is of shape [batch_size, vocab_size]
|
# out is of shape [batch_size, vocab_size]
|
||||||
logits = out[i].unsqueeze(0)
|
logits = out[i].view(1, -1)
|
||||||
|
|
||||||
all_input_ids_tensor = batch.all_input_ids_tensor[i]
|
all_input_ids_tensor = batch.all_input_ids_tensor[i]
|
||||||
|
|
||||||
@ -612,7 +612,7 @@ class FlashCausalLM(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add to all_input_ids_tensor
|
# Add to all_input_ids_tensor
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
next_token_id_squeezed = next_token_id.view(1)
|
||||||
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
||||||
|
|
||||||
# Set values
|
# Set values
|
||||||
@ -630,10 +630,10 @@ class FlashCausalLM(Model):
|
|||||||
# Get prefill logprobs
|
# Get prefill logprobs
|
||||||
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
||||||
prefill_logprobs = torch.gather(
|
prefill_logprobs = torch.gather(
|
||||||
prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
|
prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
|
||||||
)
|
)
|
||||||
# GPU <-> CPU sync
|
# GPU <-> CPU sync
|
||||||
prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
|
prefill_logprobs = prefill_logprobs.view(-1).tolist()
|
||||||
|
|
||||||
# GPU <-> CPU sync
|
# GPU <-> CPU sync
|
||||||
next_token_logprobs = next_token_logprobs.tolist()
|
next_token_logprobs = next_token_logprobs.tolist()
|
||||||
|
Loading…
Reference in New Issue
Block a user