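# Microbenchmark for FlashLlamaAttention: measures how much faster a single
# batched QKV projection is than running the same requests one at a time.
# Weight loading is mocked with random tensors, so no model files are needed.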
import time
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Union

import torch

from text_generation_server.utils.weights import Weights
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaAttention,
)


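# In-memory stand-in for safetensors files: maps filename -> tensor name -> tensor.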
dummy_file_system = {
    "file1": {
        "prefix.self_attn.q_proj.weight": torch.rand(
            4096, 4096
        ),  # (hidden_size * num_heads, hidden_size)
        "prefix.self_attn.k_proj.weight": torch.rand(4096, 4096),
        "prefix.self_attn.v_proj.weight": torch.rand(4096, 4096),
        "prefix.self_attn.o_proj.weight": torch.rand(
            4096, 4096
        ),  # (hidden_size, hidden_size * num_heads)
        "prefix.self_attn.q_proj.bias": torch.rand(4096),
        "prefix.self_attn.k_proj.bias": torch.rand(4096),
        "prefix.self_attn.v_proj.bias": torch.rand(4096),
        "prefix.self_attn.o_proj.bias": torch.rand(4096),
    },
}


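# Minimal stand-in for the slice object safetensors returns: supports shape
# queries and tensor-style indexing, which is all Weights needs here.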
class MockSlice:
    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return self.tensor.shape

    def __getitem__(self, idx):
        return self.tensor[idx]


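# Look a tensor up in the dummy file system and wrap it in a MockSlice.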
def mock_get_slice(tensor_name, filename):
    tensor = dummy_file_system[filename][tensor_name]
    return MockSlice(tensor)


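# Mimic the file handle Weights expects; device and dtype are accepted but
# unused, since the dummy tensors are created up front at module load.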
def mock_handle(filename, device, dtype):
    return SimpleNamespace(
        get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)
    )


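# Context-manager mock of safetensors' safe_open, backed by the dummy file system.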
class MockSafeOpen:
    def __init__(self, filename, framework, dummy_fs):
        self.filename = filename
        self.framework = framework
        self.dummy_fs = dummy_fs

    def keys(self):
        return list(self.dummy_fs[self.filename].keys())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


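# Weights subclass that routes tensor lookups through the dummy file system
# instead of real safetensors files on disk.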
class MockWeights(Weights):
    def __init__(
        self,
        filenames: List[Union[Path, str]],
        device,
        dtype,
        process_group,
        dummy_fs,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        self.dummy_fs = dummy_fs
        for filename in filenames:
            with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename: Union[Path, str]):
        if filename in self._handles:
            return self._handles[filename]
        else:
            handle = mock_handle(filename, self.device, self.dtype)
            self._handles[filename] = handle
            return handle

    def get_shape(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).tensor


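# Single-process "group" (rank 0, world size 1), so weights are not sharded.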
dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)


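# Time attention.query_key_value on input_tensor, returning the mean latency
# in milliseconds across num_iterations runs.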
def run_test(attention, input_tensor, num_warmup=0, num_iterations=1):
    # Guard CUDA calls: the test falls back to CPU when no GPU is available,
    # and torch.cuda.synchronize() would raise in that case.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # warm-up runs, excluded from timing
    for _ in range(num_warmup):
        attention.query_key_value(input_tensor, None)
        if use_cuda:
            torch.cuda.synchronize()

    # timed runs
    times = []
    for _ in range(num_iterations):
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        attention.query_key_value(input_tensor, None)
        if use_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()
        times.append((end_time - start_time) * 1000)  # milliseconds

    return sum(times) / len(times)


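# Compare one batched QKV projection against the equivalent per-request
# projections; capsys.disabled() lets the timing table print under pytest.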
def test_qkv_batch_speedup(capsys):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    hidden_size = 4096
    num_heads = 32

    config = SimpleNamespace(
        num_attention_heads=num_heads,
        hidden_size=hidden_size,
        rope_theta=10000.0,
        num_key_value_heads=num_heads,
        model_type="llama",
        quantize=None,
    )

    weights = MockWeights(
        filenames=["file1"],
        device=device,
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )

    attention = FlashLlamaAttention(
        index=0,
        prefix="prefix.self_attn",
        config=config,
        weights=weights,
    )

    # sequences of various odd lengths
    sequence_lengths = [3, 35, 365, 3_501, 11_111]
    # skip batch size 0, which would only time an empty forward pass
    batch_sizes = list(range(1, 16))

    with capsys.disabled():
        # allow printing
        for sequence_length in sequence_lengths:
            print(f"Testing with sequence length: {sequence_length}")
            print(
                f"{'Batch Size':<10} {'Batch Time (ms)':<20} {'Sequential Time (ms)':<25} {'Speedup':<10}"
            )
            print("-" * 65)

            for batch_size in batch_sizes:
                # one forward pass over the whole batch
                batch_input = torch.rand(
                    sequence_length * batch_size, hidden_size, device=device
                )
                batch_time = run_test(attention, batch_input)

                # one forward pass per request, timed sequentially
                sequential_time = 0
                for index in range(batch_size):
                    single_input = batch_input[
                        index * sequence_length : (index + 1) * sequence_length
                    ]
                    sequential_time += run_test(attention, single_input)

                speedup = sequential_time / batch_time
                print(
                    f"{batch_size:<10} {batch_time:<20.2f} {sequential_time:<25.2f} {speedup:<10.2f}"
                )