Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 03:14:53 +00:00
move all_input_ids_tensor to hpu to improve perf for large bs in sharded mode
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 249ccfc939
commit f728cf69f2
@@ -425,7 +425,9 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids

         # Create tensors on device
-        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )

         top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

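A minimal sketch of the pattern this hunk adopts, assuming a padded Python list of token ids and a caller-provided `device` (on Gaudi this is the HPU; the CPU fallback below only keeps the snippet runnable anywhere):

import torch

# Hypothetical inputs: ragged per-request token ids, padded out to max_length.
all_input_ids = [[101, 7592, 2088], [101, 2129]]
max_length = 4
padded = [ids + [0] * (max_length - len(ids)) for ids in all_input_ids]

# Use the HPU when the Habana backend is loaded; otherwise fall back to CPU.
device = (
    torch.device("hpu")
    if hasattr(torch, "hpu") and torch.hpu.is_available()
    else torch.device("cpu")
)

# Before: the padded matrix was materialized on the host, so each decode step
# that needed it had to copy it to the device first.
host_tensor = torch.tensor(padded, dtype=torch.int64)

# After: materialize it on the device once; later slicing and token scatter
# then stay device-resident, which matters for large batch sizes.
all_input_ids_tensor = torch.tensor(padded, dtype=torch.int64, device=device)
print(all_input_ids_tensor.device)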
@@ -1856,9 +1858,7 @@ class FlashCausalLM(Model):
                 accepted_ids,
                 speculative_ids,
             ) = batch.next_token_chooser(
-                _async_h2d_tensor_copy(
-                    batch.all_input_ids_tensor[:, : batch.max_current_length]
-                ),
+                batch.all_input_ids_tensor[:, : batch.max_current_length],
                 batch.next_token_logits,
                 speculate,
                 batch.speculative_ids,
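Since the ids now start out on the device, the slice handed to `batch.next_token_chooser` no longer needs the `_async_h2d_tensor_copy` wrapper. A toy illustration of the underlying point, with invented shapes:

import torch

device = (
    torch.device("hpu")
    if hasattr(torch, "hpu") and torch.hpu.is_available()
    else torch.device("cpu")
)

# Stand-in for batch.all_input_ids_tensor: [batch_size, max_total_tokens].
all_input_ids_tensor = torch.zeros(8, 1024, dtype=torch.int64, device=device)
max_current_length = 17

# Slicing a device-resident tensor yields a view on the same device, so no
# explicit host-to-device copy is needed before handing it to the chooser.
view = all_input_ids_tensor[:, :max_current_length]
assert view.device == all_input_ids_tensor.device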
@@ -1872,7 +1872,6 @@ class FlashCausalLM(Model):
                 accepted_ids,
             )
             if batch.valid_indices is not None:
-                next_input_ids = next_input_ids.cpu()
                 next_token_logprobs = next_token_logprobs.cpu()
                 accepted_ids = accepted_ids.cpu()
                 batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@@ -1922,7 +1921,6 @@ class FlashCausalLM(Model):
         accepted_ids = accepted_ids.cpu()
         cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
         torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        next_input_ids = next_input_ids.cpu()
         if batch.speculative_logits is not None:
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
@@ -1934,7 +1932,7 @@ class FlashCausalLM(Model):
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
         else:
             index = batch.cache_lengths_tensor + batch.input_lengths_tensor
-            index = index.to(batch.all_input_ids_tensor)
+            index = index.to(batch.all_input_ids_tensor.device)
             batch_idx = torch.arange(
                 0,
                 batch.all_input_ids_tensor.shape[0],
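The `.to()` change above is subtle: `index.to(other_tensor)` matches both the dtype and the device of `other_tensor`, while `index.to(other_tensor.device)` only moves `index` to that device and leaves its dtype alone (the later `.long()` call handles the dtype for `index_put_`). A small illustration with assumed dtypes:

import torch

all_input_ids_tensor = torch.zeros(4, 16, dtype=torch.int64)  # stand-in target
index = torch.tensor([3, 1, 7, 2], dtype=torch.int32)         # per-request positions

a = index.to(all_input_ids_tensor)         # int64, and on the target's device
b = index.to(all_input_ids_tensor.device)  # still int32, only the device changes

print(a.dtype, b.dtype)  # torch.int64 torch.int32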
@@ -1944,6 +1942,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor.index_put_(
                 (batch_idx, index.long()), next_input_ids
             )
+        next_input_ids = next_input_ids.cpu()
         batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
         batch.speculative_ids = speculative_ids
         if batch.position_ids.dim() == 2:
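The deleted `next_input_ids.cpu()` lines above and the one re-added after `index_put_` defer the device-to-host transfer of `next_input_ids` until after the new tokens have been written into `all_input_ids_tensor` on the device, so the large ids matrix never has to round-trip through the host. A self-contained sketch of that scatter, with invented shapes and CPU standing in for the HPU:

import torch

batch_size, max_total_tokens = 4, 16
device = torch.device("cpu")  # stand-in; the real tensors live on the HPU

all_input_ids_tensor = torch.zeros(
    batch_size, max_total_tokens, dtype=torch.int64, device=device
)
cache_lengths_tensor = torch.tensor([5, 3, 8, 0], device=device)
input_lengths_tensor = torch.tensor([1, 1, 1, 1], device=device)
next_input_ids = torch.tensor([11, 22, 33, 44], dtype=torch.int64, device=device)

# Write each request's freshly sampled token at its current position, on device.
index = cache_lengths_tensor + input_lengths_tensor
index = index.to(all_input_ids_tensor.device)
batch_idx = torch.arange(0, all_input_ids_tensor.shape[0], device=device)
all_input_ids_tensor.index_put_((batch_idx, index.long()), next_input_ids)

# Only the small per-request result is copied back to the host afterwards.
next_input_ids = next_input_ids.cpu()
print(all_input_ids_tensor)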
@@ -13,7 +13,7 @@ def get_hpu_free_memory(device, memory_fraction):
     free_memory = int(
         torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
     )
-    logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
+    logger.info(f"Free memory on device {device}: {free_memory} bytes.")
     return free_memory
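For reference, a standalone version of the free-memory arithmetic in `get_hpu_free_memory`, with made-up numbers standing in for the free-bytes value that `torch.hpu.mem_get_info()[0]` provides in the hunk above:

def hpu_free_memory_sketch(free_bytes: int, memory_fraction: float, graph_reserved_mem: float) -> int:
    # Mirror of the formula above: take the currently free bytes, keep only the
    # configured fraction, and hold back a share reserved for HPU graphs.
    return int(free_bytes * memory_fraction * (1 - graph_reserved_mem))

# Example with invented numbers: 64 GiB free, use 90%, reserve 10% for graphs.
print(hpu_free_memory_sketch(64 * 1024**3, 0.9, 0.1))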