Move all_input_ids_tensor to HPU to improve performance for large batch sizes in sharded mode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A <yi.a.wang@intel.com>
Date:   2025-05-12 00:55:04 -07:00
Parent: 249ccfc939
Commit: f728cf69f2

2 changed files with 7 additions and 8 deletions
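The change follows a standard accelerator pattern: allocate the per-request token buffer directly on the device instead of building it on the host and paying a host-to-device copy on every decode step. Below is a minimal sketch of the before/after allocation, not the actual TGI code; the shapes are made up, and the device guard assumes the Habana torch plugin may not be loaded, falling back to CPU so the snippet still runs.

    import torch

    # Use the HPU when the Habana plugin exposes torch.hpu; otherwise run the sketch on CPU.
    device = "hpu" if getattr(torch, "hpu", None) and torch.hpu.is_available() else "cpu"

    prompts = [[1, 2, 3], [4, 5]]  # ragged per-request token ids (illustrative)
    max_len = 8
    padded = [ids + [0] * (max_len - len(ids)) for ids in prompts]

    # Before: the buffer lives on the host, so later per-step writes cross the host/device boundary.
    host_buffer = torch.tensor(padded, dtype=torch.int64)

    # After: the buffer is created on the accelerator and the decode loop can update it in place.
    device_buffer = torch.tensor(padded, dtype=torch.int64, device=device)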

@@ -425,7 +425,9 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
         # Create tensors on device
-        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
         top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@@ -1856,9 +1858,7 @@ class FlashCausalLM(Model):
                 accepted_ids,
                 speculative_ids,
             ) = batch.next_token_chooser(
-                _async_h2d_tensor_copy(
-                    batch.all_input_ids_tensor[:, : batch.max_current_length]
-                ),
+                batch.all_input_ids_tensor[:, : batch.max_current_length],
                 batch.next_token_logits,
                 speculate,
                 batch.speculative_ids,
@@ -1872,7 +1872,6 @@ class FlashCausalLM(Model):
                 accepted_ids,
             )
             if batch.valid_indices is not None:
-                next_input_ids = next_input_ids.cpu()
                 next_token_logprobs = next_token_logprobs.cpu()
                 accepted_ids = accepted_ids.cpu()
                 batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@@ -1922,7 +1921,6 @@ class FlashCausalLM(Model):
             accepted_ids = accepted_ids.cpu()
             cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
             torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-            next_input_ids = next_input_ids.cpu()
             if batch.speculative_logits is not None:
                 for i in range(len(batch)):
                     batch.all_input_ids_tensor[
@@ -1934,7 +1932,7 @@ class FlashCausalLM(Model):
                     ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
-                index = index.to(batch.all_input_ids_tensor)
+                index = index.to(batch.all_input_ids_tensor.device)
                 batch_idx = torch.arange(
                     0,
                     batch.all_input_ids_tensor.shape[0],
@@ -1944,6 +1942,7 @@ class FlashCausalLM(Model):
                 batch.all_input_ids_tensor.index_put_(
                     (batch_idx, index.long()), next_input_ids
                 )
+            next_input_ids = next_input_ids.cpu()
             batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
             batch.speculative_ids = speculative_ids
             if batch.position_ids.dim() == 2:
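For the non-speculative path, the hunks above keep the scatter of newly sampled tokens on the device and copy next_input_ids to the host only once, after the in-place update. A small CPU-only sketch of that index_put_ pattern, with made-up shapes standing in for the real batch tensors:

    import torch

    all_input_ids_tensor = torch.zeros((4, 8), dtype=torch.int64)  # [batch, max_total_tokens]
    index = torch.tensor([3, 1, 5, 2])                # write position per request
    next_input_ids = torch.tensor([11, 22, 33, 44])   # one sampled token per request

    # Vectorized per-row write: row i receives next_input_ids[i] at column index[i].
    batch_idx = torch.arange(0, all_input_ids_tensor.shape[0], dtype=torch.long)
    all_input_ids_tensor.index_put_((batch_idx, index.long()), next_input_ids)

    # Only after the on-device update do the sampled tokens need to reach the host.
    next_input_ids = next_input_ids.cpu()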

@@ -13,7 +13,7 @@ def get_hpu_free_memory(device, memory_fraction):
     free_memory = int(
         torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
     )
-    logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
+    logger.info(f"Free memory on device {device}: {free_memory} bytes.")
     return free_memory