diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 9a0f789a..4217c17b 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -425,7 +425,9 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
         # Create tensors on device
-        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
 
         top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
 
@@ -1856,9 +1858,7 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            _async_h2d_tensor_copy(
-                batch.all_input_ids_tensor[:, : batch.max_current_length]
-            ),
+            batch.all_input_ids_tensor[:, : batch.max_current_length],
             batch.next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -1872,7 +1872,6 @@ class FlashCausalLM(Model):
             accepted_ids,
         )
         if batch.valid_indices is not None:
-            next_input_ids = next_input_ids.cpu()
             next_token_logprobs = next_token_logprobs.cpu()
             accepted_ids = accepted_ids.cpu()
             batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@@ -1922,7 +1921,6 @@ class FlashCausalLM(Model):
         accepted_ids = accepted_ids.cpu()
         cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
         torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        next_input_ids = next_input_ids.cpu()
         if batch.speculative_logits is not None:
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
@@ -1934,7 +1932,7 @@ class FlashCausalLM(Model):
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
         else:
             index = batch.cache_lengths_tensor + batch.input_lengths_tensor
-            index = index.to(batch.all_input_ids_tensor)
+            index = index.to(batch.all_input_ids_tensor.device)
             batch_idx = torch.arange(
                 0,
                 batch.all_input_ids_tensor.shape[0],
@@ -1944,6 +1942,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor.index_put_(
                 (batch_idx, index.long()), next_input_ids
             )
+        next_input_ids = next_input_ids.cpu()
         batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
         batch.speculative_ids = speculative_ids
         if batch.position_ids.dim() == 2:
diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
index 39156140..d25484d6 100644
--- a/backends/gaudi/server/text_generation_server/utils/import_utils.py
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -13,7 +13,7 @@ def get_hpu_free_memory(device, memory_fraction):
     free_memory = int(
         torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
     )
-    logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
+    logger.info(f"Free memory on device {device}: {free_memory} bytes.")
     return free_memory
 
 
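
Note (not part of the patch): a minimal standalone sketch of the device-placement pattern the change above adopts. It allocates all_input_ids_tensor directly on the target device, scatters the newly accepted tokens with index_put_ using device-resident indices, and only copies next_input_ids back to the host afterwards. The CPU device, shapes, and token values are illustrative stand-ins for the Gaudi/HPU setup.

import torch

# Stand-in for the HPU device used by the Gaudi backend; the pattern is the same.
device = torch.device("cpu")

# Toy batch: 4 sequences with room for 16 tokens each, allocated directly on
# `device` (mirrors passing device= to torch.tensor instead of building on CPU).
all_input_ids_tensor = torch.zeros(4, 16, dtype=torch.int64, device=device)

cache_lengths_tensor = torch.tensor([3, 5, 2, 7], device=device)
input_lengths_tensor = torch.tensor([1, 1, 1, 1], device=device)
next_input_ids = torch.tensor([11, 22, 33, 44], dtype=torch.int64, device=device)

# Each sequence's next token lands at position cache_length + input_length.
# The .to(...device) mirrors the corrected conversion in the patch (device, not dtype).
index = (cache_lengths_tensor + input_lengths_tensor).to(all_input_ids_tensor.device)
batch_idx = torch.arange(0, all_input_ids_tensor.shape[0], device=device)
all_input_ids_tensor.index_put_((batch_idx, index.long()), next_input_ids)

# Only after the device-side update is next_input_ids brought back to the host,
# matching the reordered .cpu() call in the patch.
next_input_ids = next_input_ids.cpu()
print(all_input_ids_tensor)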