OlivierDehaene 2023-05-05 17:26:52 +02:00
parent 1cbc5c633e
commit c969c8c091
4 changed files with 22 additions and 16 deletions

View File

@@ -665,9 +665,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if self.model.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.process_group
+            )
             return world_logits, present
         return logits, present
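
The same replacement of `torch.distributed.all_gather` with `torch.distributed.all_gather_into_tensor` appears in the GPT-NeoX and Santacoder hunks below. For reference, here is a minimal standalone sketch of the two patterns, not taken from this repository: the list-based collective allocates one temporary tensor per rank and then pays an extra copy in `torch.cat`, while `all_gather_into_tensor` writes every rank's shard directly into a single preallocated buffer, laid out as a concatenation along the first dimension. The shapes, the NCCL/one-GPU-per-process setup, and the launch command are assumptions for illustration only.

```python
# Standalone sketch, not from this commit. Run with e.g.
# `torchrun --nproc_per_node=2 gather_sketch.py` on a machine with 2 GPUs.
import torch
import torch.distributed as dist


def gather_old(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # List-based gather: world_size temporary tensors plus a torch.cat copy.
    shards = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(shards, shard)
    return torch.cat(shards, dim=0)


def gather_new(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # Preallocated gather: each rank's shard lands directly in one buffer,
    # concatenated along the first dimension.
    out = shard.new_empty((world_size * shard.shape[0], shard.shape[1]))
    dist.all_gather_into_tensor(out, shard)
    return out


if __name__ == "__main__":
    # Assumes one GPU per process and an NCCL-capable environment.
    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)
    shard = torch.full((2, 3), float(rank), device="cuda")
    assert torch.equal(gather_old(shard, world_size), gather_new(shard, world_size))
    dist.destroy_process_group()
```
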

View File

@@ -741,9 +741,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         if self.gpt_neox.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.process_group
+            )
             return world_logits, present
         return logits, present

View File

@@ -581,13 +581,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [
-                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
-            ]
-            torch.distributed.all_gather(
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.transformer.tp_world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
                 world_logits, logits, group=self.transformer.process_group
             )
-            world_logits = torch.cat(world_logits, dim=1)
             return world_logits, present

View File

@@ -521,11 +521,12 @@ class FlashCausalLM(Model):
             end_index = cumulative_length + input_length
             # Indices to copy present at the correct place in past_key_values
-            past_indices[start_index:end_index] = torch.arange(
+            torch.arange(
                 start_index + i,
                 end_index + i,
                 dtype=torch.int64,
                 device=self.device,
+                out=past_indices[start_index:end_index],
             )
             cumulative_length += input_length
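
A small standalone sketch (not taken from this repository) of the `torch.arange` change above: instead of materializing a fresh index tensor and then copying it into the slice, `out=` lets `arange` write into the preallocated slice view directly. The sizes below are made up for illustration.

```python
import torch

# Hypothetical sizes, just to exercise the two patterns.
past_indices = torch.zeros(10, dtype=torch.int64)
start_index, end_index, i = 2, 6, 3

# Old pattern: allocate a new tensor, then copy it into the slice.
past_indices[start_index:end_index] = torch.arange(
    start_index + i, end_index + i, dtype=torch.int64
)

# New pattern: arange writes straight into the (contiguous) slice view,
# skipping the intermediate allocation and copy.
torch.arange(
    start_index + i,
    end_index + i,
    dtype=torch.int64,
    out=past_indices[start_index:end_index],
)

print(past_indices)  # tensor([0, 0, 5, 6, 7, 8, 0, 0, 0, 0])
```
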
@@ -632,11 +633,11 @@ class FlashCausalLM(Model):
                 prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
             )
             # GPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.squeeze(1).to("cpu").numpy()
+            prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
         # GPU <-> CPU sync
-        next_token_logprobs = next_token_logprobs.to("cpu").numpy()
-        next_token_ids = batch.input_ids.to("cpu").numpy()
+        next_token_logprobs = next_token_logprobs.tolist()
+        next_token_ids = batch.input_ids.tolist()
         cumulative_length = 0
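
A short standalone sketch (again not from this repository) of the host-transfer change above: both forms trigger the GPU-to-CPU sync the comments mention, but `tolist()` yields Python scalars directly instead of going through an intermediate NumPy array that is only read back as Python values anyway. The same switch is why the slice in the next hunk no longer needs `.tolist()`: `prefill_logprobs` is already a Python list there.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
next_token_logprobs = torch.randn(4, device=device)

# Old pattern: device -> host copy, wrapped in a NumPy array.
as_numpy = next_token_logprobs.to("cpu").numpy()

# New pattern: device -> host copy, straight to a Python list of floats.
as_list = next_token_logprobs.tolist()

# Same values either way; only the intermediate container differs.
assert [float(x) for x in as_numpy] == as_list

# Slicing a Python list already yields a list, so no further conversion
# (e.g. `.tolist()`) is needed downstream.
assert isinstance(as_list[1:3], list)
```
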
@@ -709,7 +710,7 @@ class FlashCausalLM(Model):
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                     start_index : end_index - 1
-                ].tolist()
+                ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,