Mirror of https://github.com/huggingface/text-generation-inference.git
move all_input_ids_tensor to hpu to improve perf for large bs in sharded mode
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 249ccfc939
commit f728cf69f2
@@ -425,7 +425,9 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids

         # Create tensors on device
-        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )

         top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
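The first hunk replaces a host-side allocation with one made directly on the target device, so the padded prompt ids no longer need a separate host-to-device copy on every decode step. A minimal, self-contained sketch of the pattern (made-up ids, CPU device so it runs anywhere; not the repository's code):

import torch

# Hypothetical padded prompt ids for a batch of two requests (illustration only).
all_input_ids = [[101, 7592, 2088], [101, 2054, 2003, 1996]]
max_len = max(len(ids) for ids in all_input_ids)
padded = [ids + [0] * (max_len - len(ids)) for ids in all_input_ids]

# Use "hpu" on a Gaudi machine with the Habana PyTorch plugin; "cpu" here keeps
# the sketch runnable everywhere.
device = "cpu"

# Before the change the tensor lived on the host and had to be copied to the
# accelerator; after it, it is allocated on the device once, at batch build time.
all_input_ids_tensor = torch.tensor(padded, dtype=torch.int64, device=device)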
@@ -1856,9 +1858,7 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            _async_h2d_tensor_copy(
-                batch.all_input_ids_tensor[:, : batch.max_current_length]
-            ),
+            batch.all_input_ids_tensor[:, : batch.max_current_length],
             batch.next_token_logits,
             speculate,
             batch.speculative_ids,
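This hunk drops the `_async_h2d_tensor_copy(...)` wrapper around the ids passed to the token chooser: once `all_input_ids_tensor` already lives on the HPU there is nothing left to copy. The helper's definition is not shown in this diff; the following is only a plausible sketch of what such a host-to-device copy utility does, not the repository's implementation:

import torch

def async_h2d_tensor_copy(tensor: torch.Tensor, device: str = "hpu") -> torch.Tensor:
    # Assumed behaviour of a helper like _async_h2d_tensor_copy: a non-blocking
    # host-to-device transfer so the copy can overlap with other host work.
    # If the tensor is already on the target device it is returned unchanged,
    # which is effectively what this commit makes permanent for the input ids.
    if tensor.device.type == device:
        return tensor
    return tensor.to(device, non_blocking=True)

# Trivial usage with CPU-only tensors (returns the input unchanged).
ids = torch.zeros((2, 4), dtype=torch.int64)
same = async_h2d_tensor_copy(ids, device="cpu")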
@@ -1872,7 +1872,6 @@ class FlashCausalLM(Model):
             accepted_ids,
         )
         if batch.valid_indices is not None:
-            next_input_ids = next_input_ids.cpu()
             next_token_logprobs = next_token_logprobs.cpu()
             accepted_ids = accepted_ids.cpu()
             batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@@ -1922,7 +1921,6 @@ class FlashCausalLM(Model):
         accepted_ids = accepted_ids.cpu()
         cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
         torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        next_input_ids = next_input_ids.cpu()
         if batch.speculative_logits is not None:
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
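This hunk keeps the accepted-id bookkeeping on the host but no longer pulls `next_input_ids` back to the CPU before the speculative branch. For reference, a small self-contained example of the `cu_accepted_ids` offsets computed just above (CPU tensors, made-up counts):

import torch

# accepted_ids[i] is how many tokens request i accepted this step; its exclusive
# cumulative sum gives each request's slice into the flat next_input_ids buffer.
accepted_ids = torch.tensor([1, 3, 2])
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
# cu_accepted_ids is now [0, 1, 4, 6]; request i's tokens are
# next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].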
@@ -1934,7 +1932,7 @@ class FlashCausalLM(Model):
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
         else:
             index = batch.cache_lengths_tensor + batch.input_lengths_tensor
-            index = index.to(batch.all_input_ids_tensor)
+            index = index.to(batch.all_input_ids_tensor.device)
             batch_idx = torch.arange(
                 0,
                 batch.all_input_ids_tensor.shape[0],
@@ -1944,6 +1942,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor.index_put_(
                 (batch_idx, index.long()), next_input_ids
             )
+        next_input_ids = next_input_ids.cpu()
         batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
         batch.speculative_ids = speculative_ids
         if batch.position_ids.dim() == 2:
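In the non-speculative path the new token for each request is scattered into the on-device `all_input_ids_tensor` at position `cache_length + input_length`, and only afterwards is `next_input_ids` brought back to the CPU for the host-side `batch.input_ids` update. A runnable CPU sketch of the same scatter (made-up lengths and token ids):

import torch

# Three requests, a context window of 8, and one freshly sampled token each.
all_input_ids_tensor = torch.zeros((3, 8), dtype=torch.int64)
cache_lengths_tensor = torch.tensor([2, 0, 4])
input_lengths_tensor = torch.tensor([1, 3, 1])
next_input_ids = torch.tensor([11, 22, 33])

# Each row i receives its token at column cache_lengths[i] + input_lengths[i].
index = cache_lengths_tensor + input_lengths_tensor
batch_idx = torch.arange(0, all_input_ids_tensor.shape[0])
all_input_ids_tensor.index_put_((batch_idx, index.long()), next_input_ids)
# Row 0 gets 11 at column 3, row 1 gets 22 at column 3, row 2 gets 33 at column 5.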
@@ -13,7 +13,7 @@ def get_hpu_free_memory(device, memory_fraction):
     free_memory = int(
         torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
     )
-    logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
+    logger.info(f"Free memory on device {device}: {free_memory} bytes.")
     return free_memory
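The second file only fixes the log line, but the surrounding computation is worth spelling out: the server is granted `memory_fraction` of the free bytes reported by `torch.hpu.mem_get_info()`, minus the share reserved for HPU graphs. A sketch with made-up numbers standing in for the real device query (assumed values, illustration only):

# Pretend torch.hpu.mem_get_info()[0] reported 64 GiB of free device memory.
free_bytes_reported = 64 * 1024**3
memory_fraction = 0.9        # fraction of free memory the server may use
graph_reserved_mem = 0.1     # assumed fraction held back for HPU graphs

free_memory = int(free_bytes_reported * memory_fraction * (1 - graph_reserved_mem))
print(f"Free memory: {free_memory} bytes.")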