[Torch.compile] Enable llama-2-7b (#157)
Co-authored-by: Jacek Czaja <jczaja@habana.ai>
parent 3bf8e8e466
commit ef86232c94
@@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.past_key_values is None
     assert all(
         [
-            torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+            torch.equal(input_ids.to('cpu'), request.all_input_ids[:batch.input_length + 1, 0])
             for input_ids, request in zip(batch.input_ids, batch.requests)
         ]
     )
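Note on the test change above: on Gaudi the batch's input_ids tensor lives on the HPU device while request.all_input_ids stays on CPU, so the assertion now copies the device tensor back to CPU before comparing. A minimal, hypothetical Python sketch of the same pattern (plain CPU tensors stand in for HPU ones; ids_match is an illustrative helper, not from the diff):

import torch

# Sketch: the batch tensor normally comes back from the accelerator, so it is
# moved to CPU before being compared with the CPU-side reference tensor.
def ids_match(batch_input_ids: torch.Tensor, all_input_ids: torch.Tensor, input_length: int) -> bool:
    return torch.equal(batch_input_ids.to("cpu"), all_input_ids[: input_length + 1, 0])

# Toy usage with CPU tensors standing in for HPU ones.
all_input_ids = torch.arange(6).reshape(6, 1)
batch_input_ids = all_input_ids[:4, 0].clone()  # pretend this came back from the device
assert ids_match(batch_input_ids, all_input_ids, input_length=3)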
@@ -11,12 +11,12 @@ import time
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+import torch._dynamo
 from loguru import logger
 from opentelemetry import trace
 
 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
-from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -58,6 +58,13 @@ BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
+
+
+def torch_compile_for_eager(func):
+    if LAZY_MODE == 1:
+        return func
+    return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
 
 
 def round_up(number, k):
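Note on the hunk above: torch_compile_for_eager is a no-op under lazy mode (PT_HPU_LAZY_MODE=1, the default) and only wraps the function in torch.compile with the Habana hpu_backend when eager mode is selected. A runnable sketch of the same pattern, with an assumed fallback to the default torch.compile backend (and no backend-specific options) so it also runs without the Gaudi stack:

import os

import torch
import torch._dynamo

# Sketch of the decorator added above. The real code always targets "hpu_backend";
# the fallback branch here is an assumption for non-HPU machines, not part of the diff.
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
HPU_BACKEND_REGISTERED = "hpu_backend" in torch._dynamo.list_backends()


def torch_compile_for_eager(func):
    if LAZY_MODE == 1:
        # Lazy mode (the default): the Habana bridge already builds graphs,
        # so the function is returned unchanged.
        return func
    if HPU_BACKEND_REGISTERED:
        return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
    return torch.compile(func)  # generic fallback for non-HPU environments


@torch_compile_for_eager
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


print(double(torch.ones(3)))  # tensor([2., 2., 2.])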
@@ -65,7 +72,7 @@ def round_up(number, k):
 
 
 def to_tensor_indices(indices, device):
-    return torch.tensor(indices, dtype=torch.int32, device=device)
+    return torch.tensor(indices, dtype=torch.long, device=device)
 
 
 def calculate_chunks(offset):
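Note on the dtype change above: index tensors produced by to_tensor_indices are now int64 (torch.long) instead of int32, matching what index-based tensor ops typically expect. A small illustrative sketch; the index_copy_ usage is an assumption for demonstration, not taken from this diff:

import torch

# Sketch of the changed helper: indices are created as torch.long (int64),
# the dtype that ops such as Tensor.index_copy_ generally expect.
def to_tensor_indices(indices, device):
    return torch.tensor(indices, dtype=torch.long, device=device)


dst = torch.zeros(4, 2)
src = torch.ones(2, 2)
idx = to_tensor_indices([1, 3], device="cpu")
dst.index_copy_(0, idx, src)  # int32 indices may be rejected on some paths
print(dst)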
@@ -86,6 +93,7 @@ def biggest_single_chunk(offset):
         return 0
 
 
+@torch_compile_for_eager
 def grouped_pad(tensor_groups, dims, values):
     grouped_result = []
     for tensors, dim, value in zip(tensor_groups, dims, values):
@@ -101,6 +109,7 @@ def grouped_pad(tensor_groups, dims, values):
     return grouped_result
 
 
+@torch_compile_for_eager
 def roll(tensor, chunk, dim, merge_graphs):
     if dim is None:
         return tensor
@@ -110,6 +119,7 @@ def roll(tensor, chunk, dim, merge_graphs):
     return tensor
 
 
+@torch_compile_for_eager
 def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
     tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
     if merge_graphs:
@@ -117,6 +127,7 @@ def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
     return tensor_groups
 
 
+@torch_compile_for_eager
 def grouped_shift(tensor_groups, dims, offset, merge_graphs):
     chunks = calculate_chunks(offset)
     for c in chunks:
@@ -124,6 +135,7 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
     return tensor_groups
 
 
+@torch_compile_for_eager
 def move(dst_tensors, dst_indices, src_tensors):
     bs_dim = 0
     num_indices = dst_indices.size(0)
@@ -139,12 +151,14 @@ def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
         move(dst_tensors, dst_indices, src_tensors)
 
 
+@torch_compile_for_eager
 def extend_tensor(tensor, padding, dim):
     result = torch.cat([tensor, padding], dim=dim)
     htorch.core.mark_step()
     return result
 
 
+@torch_compile_for_eager
 def extend_batch(tensors, target_bs, dim):
     diff = target_bs - tensors[0].size(dim)
     # TODO: add support for shrinking bs
@@ -162,12 +176,14 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
     return tensor_groups
 
 
+@torch_compile_for_eager
 def merge(tensor_group):
     tensor_group = [torch.stack(tensor_group)]
     htorch.core.mark_step()
     return tensor_group
 
 
+@torch_compile_for_eager
 def split(tensor_group, clone_data):
     tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
     if clone_data:
@@ -638,7 +654,14 @@ class CausalLM(Model):
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
         model = remove_kv_cache_from_output(model)
         if self.enable_hpu_graph:
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        else:
+            if LAZY_MODE == 0:
+                # It is said that "keep_input_mutations" is safe for inference to be done
+                dbg_trace(
+                    "TORCH COMPILE", f'Torch compiling of model')
+                model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
         model = self.setup_quantization(model)
 
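Note on the CausalLM changes above: wrap_in_hpu_graph is now imported locally and used only when HPU graphs are enabled; otherwise, if lazy mode is off (PT_HPU_LAZY_MODE=0), only the inner model.model is compiled with torch.compile on the hpu_backend. A condensed, hypothetical restatement of that selection logic (wrap_model_for_hpu is an illustrative helper and the dbg_trace call is omitted; this is not the actual __init__ code):

import os

import torch

LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))


def wrap_model_for_hpu(model, enable_hpu_graph: bool):
    if enable_hpu_graph:
        # Imported lazily so eager/torch.compile deployments never pull in HPU graphs.
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
        return wrap_in_hpu_graph(model, disable_tensor_cache=True)
    if LAZY_MODE == 0:
        # Compile only the underlying transformer (model.model), leaving the
        # generation wrapper untouched, as in the diff above.
        model.model = torch.compile(
            model.model, backend="hpu_backend", options={"keep_input_mutations": True}
        )
    return model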
@@ -826,7 +849,7 @@ class CausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past_key_values,
-            "token_idx": token_idx
+            "token_idx": token_idx,
         }
 
         if self.has_position_ids:
@@ -952,7 +975,7 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
         else:
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
@@ -963,7 +986,7 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
 
         htorch.core.mark_step()