diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 1293124a..637c95df 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -665,9 +665,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
 
         if self.model.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.process_group
+            )
 
             return world_logits, present
         return logits, present
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index ae1465ab..8d93301e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -741,9 +741,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 
         if self.gpt_neox.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
+                world_logits, logits, group=self.process_group
+            )
 
             return world_logits, present
         return logits, present
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 20ad8385..597eaef1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -581,13 +581,12 @@ class FlashSantacoderForCausalLM(nn.Module):
 
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [
-                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
-            ]
-            torch.distributed.all_gather(
+            world_logits = logits.new_empty(
+                (logits.shape[0], logits.shape[1] * self.transformer.tp_world_size)
+            )
+            torch.distributed.all_gather_into_tensor(
                 world_logits, logits, group=self.transformer.process_group
             )
-            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ba318d14..76265217 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -521,11 +521,12 @@ class FlashCausalLM(Model):
             end_index = cumulative_length + input_length
 
             # Indices to copy present at the correct place in past_key_values
-            past_indices[start_index:end_index] = torch.arange(
+            torch.arange(
                 start_index + i,
                 end_index + i,
                 dtype=torch.int64,
                 device=self.device,
+                out=past_indices[start_index:end_index],
             )
 
             cumulative_length += input_length
@@ -632,11 +633,11 @@ class FlashCausalLM(Model):
                 prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
             )
             # GPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.squeeze(1).to("cpu").numpy()
+            prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
 
         # GPU <-> CPU sync
-        next_token_logprobs = next_token_logprobs.to("cpu").numpy()
-        next_token_ids = batch.input_ids.to("cpu").numpy()
+        next_token_logprobs = next_token_logprobs.tolist()
+        next_token_ids = batch.input_ids.tolist()
 
         cumulative_length = 0
 
@@ -709,7 +710,7 @@ class FlashCausalLM(Model):
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                     start_index : end_index - 1
-                ].tolist()
+                ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
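
The modeling-file hunks above all apply the same change: a Python list of per-rank buffers filled by torch.distributed.all_gather and then concatenated with torch.cat is replaced by a single preallocated output buffer filled in place by torch.distributed.all_gather_into_tensor, which avoids the per-rank temporaries and the extra concatenation copy. The sketch below is illustrative only and not code from this PR: the helper names are invented, the gather helpers assume an already-initialized process group (they are defined but not executed), and the gather is shown for the simple case of concatenation along the first dimension. It also demonstrates the arange(out=...) and tolist() patterns used in the flash_causal_lm.py hunks.

import torch
import torch.distributed as dist


def gather_list_and_cat(shard, group=None):
    # Old pattern: one temporary tensor per rank, then an extra allocation for the cat.
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(shards, shard, group=group)
    return torch.cat(shards, dim=0)


def gather_into_preallocated(shard, group=None):
    # New pattern: one preallocated output buffer, filled in place by the collective.
    world_size = dist.get_world_size(group)
    out = shard.new_empty((shard.shape[0] * world_size,) + tuple(shard.shape[1:]))
    dist.all_gather_into_tensor(out, shard, group=group)
    return out


if __name__ == "__main__":
    # The two gather helpers above need an initialized process group and are not
    # called here; the rest of the demo runs on CPU.

    # arange(..., out=...) writes directly into a slice of a preallocated index
    # tensor instead of building a temporary tensor and copying it in.
    past_indices = torch.empty(8, dtype=torch.int64)
    torch.arange(2, 6, dtype=torch.int64, out=past_indices[2:6])

    # tolist() pulls small result tensors to host Python lists in one sync,
    # without an intermediate trip through a NumPy array.
    logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
    next_token_ids = logprobs.argmax(dim=-1)
    print(past_indices[2:6].tolist(), next_token_ids.tolist())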