Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-09 19:34:53 +00:00)

commit c969c8c091 (parent 1cbc5c633e)

    faster
@@ -665,9 +665,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):

         if self.model.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.process_group
+            )

             return world_logits, present
         return logits, present
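The pattern in this hunk (repeated for the GPT-NeoX and Santacoder models below) replaces a list-based torch.distributed.all_gather followed by torch.cat with a single all_gather_into_tensor into a pre-allocated buffer, saving one allocation and one copy per forward pass. The sketch below contrasts the two strategies; it is a minimal illustration, assuming an initialized process group with a backend that supports all_gather_into_tensor (e.g. NCCL via torchrun), and it gathers along the leading dimension, where the fused call matches torch.cat(..., dim=0) exactly. The function names are illustrative, not from the repo.

import torch
import torch.distributed as dist

def gather_logits_list(logits: torch.Tensor) -> torch.Tensor:
    # Old path: one buffer per rank, gather, then concatenate.
    # torch.cat allocates a fresh tensor and copies every shard again.
    world_logits = [torch.empty_like(logits) for _ in range(dist.get_world_size())]
    dist.all_gather(world_logits, logits)
    return torch.cat(world_logits, dim=0)

def gather_logits_fused(logits: torch.Tensor) -> torch.Tensor:
    # New path: pre-allocate one output buffer and let the backend write
    # every rank's shard into it directly, skipping the extra cat/copy.
    world_logits = logits.new_empty(
        (logits.shape[0] * dist.get_world_size(),) + logits.shape[1:]
    )
    dist.all_gather_into_tensor(world_logits, logits)
    return world_logits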
@@ -741,9 +741,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):

         if self.gpt_neox.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.process_group
+            )

             return world_logits, present
         return logits, present
@@ -581,13 +581,12 @@ class FlashSantacoderForCausalLM(nn.Module):

         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [
-                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
-            ]
-            torch.distributed.all_gather(
-                world_logits, logits, group=self.transformer.process_group
-            )
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.transformer.tp_world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.transformer.process_group
+            )

             return world_logits, present

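All three models pre-allocate the gather buffer with Tensor.new_empty, which returns an uninitialized tensor of the requested size that inherits the dtype and device of the source tensor, so the buffer automatically matches the sharded logits on every rank. A small standalone illustration (the shapes and world size are made up):

import torch

logits = torch.randn(4, 1024, dtype=torch.float16)
world_size = 2  # hypothetical tensor-parallel degree

# Uninitialized buffer sized for every rank's vocab shard side by side,
# with the same dtype and device as `logits`.
world_logits = logits.new_empty((logits.shape[0], logits.shape[1] * world_size))
assert world_logits.dtype == logits.dtype
assert world_logits.device == logits.device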
@@ -521,11 +521,12 @@ class FlashCausalLM(Model):
             end_index = cumulative_length + input_length

             # Indices to copy present at the correct place in past_key_values
-            past_indices[start_index:end_index] = torch.arange(
+            torch.arange(
                 start_index + i,
                 end_index + i,
                 dtype=torch.int64,
                 device=self.device,
+                out=past_indices[start_index:end_index],
             )
             cumulative_length += input_length

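Here the assignment form past_indices[start_index:end_index] = torch.arange(...) first materializes a temporary tensor and then copies it into the slice; passing the slice as out= lets arange write into it directly. A slice of a tensor is a view, so filling it in place mutates the underlying past_indices. A standalone check with toy values:

import torch

past_indices = torch.zeros(8, dtype=torch.int64)
start_index, end_index, i = 2, 6, 1

# Writes 3, 4, 5, 6 straight into the view; no temporary tensor is allocated.
torch.arange(
    start_index + i,
    end_index + i,
    dtype=torch.int64,
    out=past_indices[start_index:end_index],
)
print(past_indices)  # tensor([0, 0, 3, 4, 5, 6, 0, 0])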
@@ -632,11 +633,11 @@ class FlashCausalLM(Model):
                 prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
             )
             # GPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.squeeze(1).to("cpu").numpy()
+            prefill_logprobs = prefill_logprobs.squeeze(1).tolist()

         # GPU <-> CPU sync
-        next_token_logprobs = next_token_logprobs.to("cpu").numpy()
-        next_token_ids = batch.input_ids.to("cpu").numpy()
+        next_token_logprobs = next_token_logprobs.tolist()
+        next_token_ids = batch.input_ids.tolist()

         cumulative_length = 0

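Both .to("cpu").numpy() and .tolist() force the same device-to-host synchronization; the likely win is that .tolist() yields built-in Python scalars directly, so the per-token bookkeeping that follows skips the NumPy detour and the numpy-scalar-to-float conversions. A CPU-only comparison of the two forms (toy values):

import torch

next_token_logprobs = torch.tensor([-0.11, -2.30, -0.69])

as_numpy = next_token_logprobs.to("cpu").numpy()  # numpy.ndarray of numpy scalars
as_list = next_token_logprobs.tolist()            # list of built-in Python floats

assert isinstance(as_list[0], float)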
@@ -709,7 +710,7 @@ class FlashCausalLM(Model):
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                     start_index : end_index - 1
-                ].tolist()
+                ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
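This last hunk is a knock-on fix: after the .tolist() change above, prefill_logprobs is already a plain Python list, and slicing a list yields a list, so the trailing .tolist() (a tensor method) would now raise AttributeError. With illustrative values:

prefill_logprobs = [-0.5, -0.4, -0.3, -0.2]  # already a list after .tolist()
start_index, end_index = 0, 4

# nan stands in for the first prompt token; the generated token is trimmed.
request_prefill_logprobs = [float("nan")] + prefill_logprobs[start_index : end_index - 1]
print(request_prefill_logprobs)  # [nan, -0.5, -0.4, -0.3]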