[Torch.compile] Enable llama-2-7b (#157)

Co-authored-by: Jacek Czaja <jczaja@habana.ai>
Jacek Czaja 2024-06-14 15:56:23 +02:00 committed by GitHub
parent 3bf8e8e466
commit ef86232c94
2 changed files with 29 additions and 6 deletions
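
In short: torch.compile is gated on PT_HPU_LAZY_MODE. In lazy mode (the default) the new torch_compile_for_eager decorator is a no-op; with PT_HPU_LAZY_MODE=0 it compiles the decorated batch helpers with the Habana hpu_backend, and CausalLM additionally compiles model.model the same way when HPU graphs are disabled. Below is a minimal standalone sketch of that gating pattern (not part of the commit); it substitutes PyTorch's generic "eager" debug backend for "hpu_backend" so it can run without Habana hardware.

# Minimal sketch of the env-gated compile pattern added by this commit.
# Assumption: backend="eager" stands in for "hpu_backend"; the real code also
# passes options={"keep_input_mutations": True}.
import os
import torch

LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))

def torch_compile_for_eager(func):
    # Lazy mode builds HPU graphs on its own, so compilation is skipped there.
    if LAZY_MODE == 1:
        return func
    return torch.compile(func, backend="eager")

@torch_compile_for_eager
def pad_demo(tensor, value):
    # Stand-in for the grouped_pad/roll/move helpers decorated in the diff below.
    return torch.nn.functional.pad(tensor, (0, 1), value=value)

print(pad_demo(torch.ones(3), value=0.0))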

View File

@@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.past_key_values is None
assert all(
[
- torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+ torch.equal(input_ids.to('cpu'), request.all_input_ids[:batch.input_length + 1, 0])
for input_ids, request in zip(batch.input_ids, batch.requests)
]
)
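
The test change mirrors the runtime behavior: the batch's input_ids live on the HPU device while the request's all_input_ids reference stays on CPU, so the comparison is now done explicitly on CPU. A tiny sketch of that comparison, assuming habana_frameworks.torch is installed so the "hpu" device is registered:

# Sketch only; requires a Habana runtime. Importing habana_frameworks.torch
# (assumption) registers the "hpu" device with PyTorch.
import torch
import habana_frameworks.torch as htorch  # noqa: F401

device_ids = torch.tensor([5, 6, 7], device="hpu")  # batch tensor on the device
reference = torch.tensor([5, 6, 7])                 # CPU-side reference from the request
assert torch.equal(device_ids.to('cpu'), reference)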

View File

@@ -11,12 +11,12 @@ import time
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch._dynamo
from loguru import logger
from opentelemetry import trace
import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.utils import HabanaProfile
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from optimum.habana.checkpoint_utils import (
@@ -58,6 +58,13 @@ BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
def torch_compile_for_eager(func):
if LAZY_MODE == 1:
return func
return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
def round_up(number, k):
@@ -65,7 +72,7 @@ def round_up(number, k):
def to_tensor_indices(indices, device):
- return torch.tensor(indices, dtype=torch.int32, device=device)
+ return torch.tensor(indices, dtype=torch.long, device=device)
def calculate_chunks(offset):
@@ -86,6 +93,7 @@ def biggest_single_chunk(offset):
return 0
@torch_compile_for_eager
def grouped_pad(tensor_groups, dims, values):
grouped_result = []
for tensors, dim, value in zip(tensor_groups, dims, values):
@@ -101,6 +109,7 @@ def grouped_pad(tensor_groups, dims, values):
return grouped_result
@torch_compile_for_eager
def roll(tensor, chunk, dim, merge_graphs):
if dim is None:
return tensor
@@ -110,6 +119,7 @@ def roll(tensor, chunk, dim, merge_graphs):
return tensor
@torch_compile_for_eager
def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
if merge_graphs:
@@ -117,6 +127,7 @@ def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
return tensor_groups
@torch_compile_for_eager
def grouped_shift(tensor_groups, dims, offset, merge_graphs):
chunks = calculate_chunks(offset)
for c in chunks:
@@ -124,6 +135,7 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
return tensor_groups
@torch_compile_for_eager
def move(dst_tensors, dst_indices, src_tensors):
bs_dim = 0
num_indices = dst_indices.size(0)
@@ -139,12 +151,14 @@ def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
move(dst_tensors, dst_indices, src_tensors)
@torch_compile_for_eager
def extend_tensor(tensor, padding, dim):
result = torch.cat([tensor, padding], dim=dim)
htorch.core.mark_step()
return result
@torch_compile_for_eager
def extend_batch(tensors, target_bs, dim):
diff = target_bs - tensors[0].size(dim)
# TODO: add support for shrinking bs
@@ -162,12 +176,14 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
return tensor_groups
@torch_compile_for_eager
def merge(tensor_group):
tensor_group = [torch.stack(tensor_group)]
htorch.core.mark_step()
return tensor_group
@torch_compile_for_eager
def split(tensor_group, clone_data):
tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
if clone_data:
@@ -638,7 +654,14 @@ class CausalLM(Model):
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else:
if LAZY_MODE == 0:
# "keep_input_mutations" is reported to be safe for inference
dbg_trace(
"TORCH COMPILE", f'Torch compiling of model')
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
model = self.setup_quantization(model)
@@ -826,7 +849,7 @@ class CausalLM(Model):
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"token_idx": token_idx
"token_idx": token_idx,
}
if self.has_position_ids:
@@ -952,7 +975,7 @@ class CausalLM(Model):
batch.position_ids,
token_idx,
batch.past_key_values,
- bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+ bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
)
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
@@ -963,7 +986,7 @@ class CausalLM(Model):
batch.position_ids,
token_idx,
batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
)
htorch.core.mark_step()
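
At the model level the commit keeps two paths: when HPU graphs are enabled the model is wrapped with wrap_in_hpu_graph (now imported locally), otherwise, with PT_HPU_LAZY_MODE=0, the inner model.model is compiled with the hpu_backend. A condensed sketch of that branch follows; the ENABLE_HPU_GRAPH variable is a hypothetical stand-in for self.enable_hpu_graph, and running it requires an HPU host.

# Condensed sketch of the branch added in CausalLM.__init__. Assumptions:
# `model` is an already-loaded transformers model, ENABLE_HPU_GRAPH is a
# hypothetical stand-in for self.enable_hpu_graph, and the Habana bridge is installed.
import os
import torch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
ENABLE_HPU_GRAPH = os.getenv('ENABLE_HPU_GRAPH', 'true').lower() == 'true'

def prepare(model):
    if ENABLE_HPU_GRAPH:
        # HPU graph capture replaces torch.compile entirely.
        return wrap_in_hpu_graph(model, disable_tensor_cache=True)
    if LAZY_MODE == 0:
        # Eager mode: compile the inner transformer module rather than the full wrapper.
        model.model = torch.compile(model.model, backend="hpu_backend",
                                    options={"keep_input_mutations": True})
    return model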