mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-08-02 05:10:23 +00:00

[Torch.compile] Enable llama-2-7b (#157)
Co-authored-by: Jacek Czaja <jczaja@habana.ai>

This commit is contained in:
parent: 3bf8e8e466
commit: ef86232c94
@@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.past_key_values is None
     assert all(
         [
-            torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+            torch.equal(input_ids.to('cpu'), request.all_input_ids[:batch.input_length + 1, 0])
             for input_ids, request in zip(batch.input_ids, batch.requests)
         ]
     )
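
The updated assertion moves the device-side tensor to the host before comparing. A minimal sketch of the pattern, assuming batch.input_ids now lives on the HPU while request.all_input_ids remains a host tensor (the stand-in tensors below are illustrative, not from the test):

    import torch

    # Stand-ins: on Gaudi the first tensor would live on the 'hpu' device.
    device_tensor = torch.arange(4)
    host_tensor = torch.arange(4)
    # Bringing the device tensor to CPU first keeps torch.equal comparing tensors on one device.
    assert torch.equal(device_tensor.to('cpu'), host_tensor)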
@@ -11,12 +11,12 @@ import time
 from typing import Dict, List, Optional, Tuple, Type

 import torch
+import torch._dynamo
 from loguru import logger
 from opentelemetry import trace

 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
-from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -58,6 +58,13 @@ BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
+
+
+def torch_compile_for_eager(func):
+    if LAZY_MODE == 1:
+        return func
+    return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})


 def round_up(number, k):
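
The new torch_compile_for_eager helper is a no-op while PT_HPU_LAZY_MODE=1 and only wraps a function in torch.compile with the hpu_backend once lazy mode is disabled. A minimal usage sketch, assuming the Habana PyTorch bridge is installed (pad_right is an illustrative function, not one from this file):

    import os
    import torch

    LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))

    def torch_compile_for_eager(func):
        # In lazy mode the HPU bridge already accumulates graphs, so compilation is skipped.
        if LAZY_MODE == 1:
            return func
        return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})

    @torch_compile_for_eager
    def pad_right(tensor, amount, value):
        # Each decorated helper becomes its own compiled region when PT_HPU_LAZY_MODE=0.
        return torch.nn.functional.pad(tensor, (0, amount), value=value)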
@@ -65,7 +72,7 @@ def round_up(number, k):


 def to_tensor_indices(indices, device):
-    return torch.tensor(indices, dtype=torch.int32, device=device)
+    return torch.tensor(indices, dtype=torch.long, device=device)


 def calculate_chunks(offset):
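
to_tensor_indices now produces int64 (torch.long) indices instead of int32. A small sketch of why the wider dtype is the safe default, assuming the indices eventually feed indexing ops such as Tensor.index_copy_, whose index argument is documented as a LongTensor (the tensors below are illustrative):

    import torch

    def to_tensor_indices(indices, device):
        return torch.tensor(indices, dtype=torch.long, device=device)

    dst = torch.zeros(4, 3)
    src = torch.ones(2, 3)
    idx = to_tensor_indices([1, 3], device='cpu')
    dst.index_copy_(0, idx, src)   # index_copy_ expects int64 indices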
@@ -86,6 +93,7 @@ def biggest_single_chunk(offset):
         return 0


+@torch_compile_for_eager
 def grouped_pad(tensor_groups, dims, values):
     grouped_result = []
     for tensors, dim, value in zip(tensor_groups, dims, values):
@@ -101,6 +109,7 @@ def grouped_pad(tensor_groups, dims, values):
     return grouped_result


+@torch_compile_for_eager
 def roll(tensor, chunk, dim, merge_graphs):
     if dim is None:
         return tensor
@@ -110,6 +119,7 @@ def roll(tensor, chunk, dim, merge_graphs):
     return tensor


+@torch_compile_for_eager
 def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
     tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
     if merge_graphs:
@@ -117,6 +127,7 @@ def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
     return tensor_groups


+@torch_compile_for_eager
 def grouped_shift(tensor_groups, dims, offset, merge_graphs):
     chunks = calculate_chunks(offset)
     for c in chunks:
@@ -124,6 +135,7 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
     return tensor_groups


+@torch_compile_for_eager
 def move(dst_tensors, dst_indices, src_tensors):
     bs_dim = 0
     num_indices = dst_indices.size(0)
@@ -139,12 +151,14 @@ def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
         move(dst_tensors, dst_indices, src_tensors)


+@torch_compile_for_eager
 def extend_tensor(tensor, padding, dim):
     result = torch.cat([tensor, padding], dim=dim)
     htorch.core.mark_step()
     return result


+@torch_compile_for_eager
 def extend_batch(tensors, target_bs, dim):
     diff = target_bs - tensors[0].size(dim)
     # TODO: add support for shrinking bs
@@ -162,12 +176,14 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
     return tensor_groups


+@torch_compile_for_eager
 def merge(tensor_group):
     tensor_group = [torch.stack(tensor_group)]
     htorch.core.mark_step()
     return tensor_group


+@torch_compile_for_eager
 def split(tensor_group, clone_data):
     tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
     if clone_data:
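
merge and split, now also routed through torch_compile_for_eager, are inverses of each other: merge stacks a group of tensors into a single batched tensor, and split slices it back apart. A minimal round-trip sketch with illustrative shapes:

    import torch

    group = [torch.randn(2, 3) for _ in range(4)]
    merged = [torch.stack(group)]                                 # one tensor of shape [4, 2, 3]
    restored = [t.squeeze(0) for t in torch.split(merged[0], 1)]
    assert all(torch.equal(a, b) for a, b in zip(group, restored))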
@@ -638,7 +654,14 @@ class CausalLM(Model):
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
         model = remove_kv_cache_from_output(model)
         if self.enable_hpu_graph:
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        else:
+            if LAZY_MODE == 0:
+                # It is said that "keep_input_mutations" is safe for inference to be done
+                dbg_trace(
+                    "TORCH COMPILE", f'Torch compiling of model')
+                model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})

         model = self.setup_quantization(model)
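
With this hunk the model is either wrapped in an HPU graph or, when HPU graphs are off and PT_HPU_LAZY_MODE=0, its transformer stack (model.model) is compiled once with the hpu_backend. A compact sketch of that selection logic, assuming enable_hpu_graph is read from a launcher flag elsewhere in the file (wrap_model and its argument names are illustrative):

    import os
    import torch

    LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))

    def wrap_model(model, enable_hpu_graph):
        if enable_hpu_graph:
            # HPU graph capture records the forward pass and replays it on later calls.
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
            return wrap_in_hpu_graph(model, disable_tensor_cache=True)
        if LAZY_MODE == 0:
            # Eager mode without HPU graphs: compile only the transformer, keeping in-place
            # input mutations inside the compiled region.
            model.model = torch.compile(model.model, backend="hpu_backend",
                                        options={"keep_input_mutations": True})
        return model

The new branch is exercised only when HPU graphs are disabled and the server runs with PT_HPU_LAZY_MODE=0; with the default lazy mode nothing changes.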
@@ -826,7 +849,7 @@ class CausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past_key_values,
-            "token_idx": token_idx
+            "token_idx": token_idx,
         }

         if self.has_position_ids:
@@ -952,7 +975,7 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
         else:
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
@@ -963,7 +986,7 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )

         htorch.core.mark_step()