Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
revert change on all gather
parent bf5990ee9e
commit 1a3aa08fa0
@@ -665,12 +665,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
 
         if self.model.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
-                world_logits, logits, group=self.process_group
-            )
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
         return logits, present
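The commit message does not say why the all_gather_into_tensor path was dropped, but the two collectives do lay the gathered data out differently: all_gather_into_tensor fills its output buffer in rank order (effectively a concatenation along the first dimension), while these logits are sharded along the vocab dimension and have to be stitched back together with cat(..., dim=1). The single-process sketch below only illustrates that layout difference; it simulates the collectives with plain tensor ops, and the sizes and variable names are illustrative assumptions rather than code from the repository.

    import torch

    # Illustrative sizes (assumptions, not taken from the models above).
    world_size = 2
    batch, shard_vocab = 3, 4

    # Pretend each tensor-parallel rank holds one column shard of the logits.
    shards = [
        torch.arange(batch * shard_vocab).reshape(batch, shard_vocab) + 100 * rank
        for rank in range(world_size)
    ]

    # Reverted-to path: all_gather into a list of per-rank tensors, then
    # concatenate explicitly along the vocab dimension.
    gathered_list = torch.cat(shards, dim=1)  # (batch, world_size * shard_vocab)

    # Removed path: all_gather_into_tensor writes the shards back to back in
    # rank order, i.e. as if they were concatenated along dim 0, so viewing the
    # result with the pre-allocated (batch, world_size * shard_vocab) shape
    # scrambles the rows.
    gathered_flat = torch.cat(shards, dim=0).reshape(batch, world_size * shard_vocab)

    print(torch.equal(gathered_list, gathered_flat))  # False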
@@ -741,12 +741,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 
         if self.gpt_neox.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
-                world_logits, logits, group=self.process_group
-            )
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
         return logits, present
@@ -581,12 +581,13 @@ class FlashSantacoderForCausalLM(nn.Module):
 
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.transformer.tp_world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
                 world_logits, logits, group=self.transformer.process_group
             )
+            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
 
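After the revert, all three forward passes share the same gather pattern. A hypothetical helper (not part of this commit) that factors it out might look like the sketch below; world_size and process_group stand in for the per-model attributes (self.world_size / self.process_group on the Llama and NeoX models, self.transformer.tp_world_size / self.transformer.process_group on Santacoder).

    from typing import Optional

    import torch
    import torch.distributed as dist

    def gather_sharded_logits(
        logits: torch.Tensor,
        world_size: int,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> torch.Tensor:
        # Gather every rank's column shard of the logits, then stitch them back
        # together along the vocab dimension, mirroring the reverted-to code path.
        world_logits = [torch.empty_like(logits) for _ in range(world_size)]
        dist.all_gather(world_logits, logits, group=process_group)
        return torch.cat(world_logits, dim=1)

Inside a forward pass it would then be called as, for example, world_logits = gather_sharded_logits(logits, self.world_size, self.process_group).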