# text-generation-inference/server/tests/utils/test_qkv_batch_speedup.py
#
# 204 lines
# 6.0 KiB
# Python

import torch
from types import SimpleNamespace
from typing import List, Union, Dict, Optional
import time
from pathlib import Path
from text_generation_server.utils.weights import Weights
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaAttention,
)
# In-memory fake "file system": filename -> {tensor name -> tensor}.
# The tensors stand in for one attention layer under the "prefix.self_attn"
# prefix. test_qkv_batch_speedup uses hidden_size=4096 and num_heads=32
# (so head_dim=128), with as many key/value heads as query heads.
# Shapes therefore are:
#   q/k/v_proj.weight: (num_heads * head_dim, hidden_size) = (4096, 4096)
#   o_proj.weight:     (hidden_size, num_heads * head_dim) = (4096, 4096)
#   *.bias:            (4096,)
# Weights are created first, then biases, each in q, k, v, o order.
dummy_file_system = {
    "file1": {
        **{
            f"prefix.self_attn.{proj}.weight": torch.rand(4096, 4096)
            for proj in ("q_proj", "k_proj", "v_proj", "o_proj")
        },
        **{
            f"prefix.self_attn.{proj}.bias": torch.rand(4096)
            for proj in ("q_proj", "k_proj", "v_proj", "o_proj")
        },
    },
}
class MockSlice:
    """Wrap an in-memory tensor behind a small slice-like interface.

    Instances are what the mock handles' ``get_slice`` returns. They support
    ``get_shape()`` and ``[...]`` indexing, and expose the wrapped tensor as
    ``.tensor``.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        """Return the shape of the wrapped tensor."""
        shape = self.tensor.shape
        return shape

    def __getitem__(self, idx):
        """Index or slice the wrapped tensor with ``idx``."""
        selected = self.tensor[idx]
        return selected
def mock_get_slice(tensor_name, filename):
    """Fetch ``tensor_name`` from ``filename`` in the module-level dummy file
    system and wrap it in a ``MockSlice``.

    Raises ``KeyError`` if either the file or the tensor name is unknown.
    """
    file_contents = dummy_file_system[filename]
    return MockSlice(file_contents[tensor_name])
def mock_handle(filename, device, dtype):
    """Build a fake open-file handle for ``filename``.

    The handle's only attribute is ``get_slice(tensor_name)``, which defers
    to ``mock_get_slice`` when it is called. ``device`` and ``dtype`` are
    accepted but not used.
    """

    def get_slice(tensor_name):
        return mock_get_slice(tensor_name, filename)

    return SimpleNamespace(get_slice=get_slice)
class MockSafeOpen:
    """Context manager that mimics opening a weights file.

    Nothing is read from disk. ``filename`` is looked up in ``dummy_fs``,
    a mapping of filename -> {tensor name -> tensor}. ``framework`` is
    stored but otherwise unused.
    """

    def __init__(self, filename, framework, dummy_fs):
        self.filename, self.framework, self.dummy_fs = filename, framework, dummy_fs

    def keys(self):
        """Return the tensor names stored under this filename, in insertion order."""
        return [name for name in self.dummy_fs[self.filename]]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release. Returning None lets any exception propagate.
        return None
class MockWeights(Weights):
    """In-memory ``Weights`` replacement backed by a dummy file system.

    ``dummy_fs`` maps filename -> {tensor name -> tensor}. The constructor
    builds a tensor-name -> filename routing table from the given
    ``filenames`` and rejects names that appear in more than one file. All
    later tensor lookups read from that same ``dummy_fs``.
    """

    def __init__(
        self,
        filenames: List[Union[Path, str]],
        device,
        dtype,
        process_group,
        dummy_fs,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        # Weights.__init__ is deliberately not called (presumably because it
        # opens real files); the attributes it would set are assigned below.
        routing = {}
        self.dummy_fs = dummy_fs
        for filename in filenames:
            with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename: Union[Path, str]):
        """Return a cached fake handle whose ``get_slice`` reads ``self.dummy_fs``.

        Reading from ``self.dummy_fs`` rather than the module-level dummy
        file system keeps tensor lookups consistent with the routing table
        built in ``__init__``.
        """
        handle = self._handles.get(filename)
        if handle is None:
            tensors = self.dummy_fs[filename]
            handle = SimpleNamespace(
                get_slice=lambda tensor_name: MockSlice(tensors[tensor_name])
            )
            self._handles[filename] = handle
        return handle

    def get_shape(self, tensor_name: str):
        """Return the shape of ``tensor_name`` without copying the tensor."""
        # get_filename (inherited from Weights) returns (filename, name); the
        # second element is presumably the name the tensor is actually stored
        # under after prefix/alias resolution -- TODO confirm against the base
        # class. Look that name up, not the caller's raw name.
        filename, stored_name = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(stored_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device: bool = True):
        """Return the full tensor for ``tensor_name``.

        Floating-point tensors are cast to ``self.dtype`` when it is set.
        When ``to_device`` is true and ``self.device`` is set, the tensor is
        also moved to ``self.device``, so it matches the layer it feeds.
        """
        filename, stored_name = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        tensor = handle.get_slice(stored_name).tensor
        # Integer tensors keep their dtype; only float data is cast.
        if self.dtype is not None and tensor.is_floating_point():
            tensor = tensor.to(dtype=self.dtype)
        if to_device and self.device is not None:
            tensor = tensor.to(device=self.device)
        return tensor
# Minimal stand-in for a process group: it only provides rank() and size(),
# reporting rank 0 in a group of size 1.
dummy_process_group = SimpleNamespace(
    rank=lambda: 0,
    size=lambda: 1,
)
def run_test(attention, input_tensor, num_warmup=0, num_iterations=1):
    """Return the mean wall-clock time, in milliseconds, of
    ``attention.query_key_value(input_tensor, None)``.

    Args:
        attention: object with a ``query_key_value`` method, called as
            ``query_key_value(input_tensor, None)``.
        input_tensor: tensor passed unchanged to ``query_key_value``.
        num_warmup: number of untimed calls made before measuring.
        num_iterations: number of timed calls to average over; must be >= 1.

    Returns:
        The average duration of one timed call, in milliseconds.

    Raises:
        ValueError: if ``num_iterations`` is less than 1, since the mean
            would be undefined.

    If ``input_tensor`` is on a CUDA device, the device is synchronized
    around every timed call so asynchronously launched kernels are included
    in the measurement. For CPU tensors no CUDA API is touched, so the helper
    also works on machines without a GPU.
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

    on_cuda = input_tensor.is_cuda

    def sync():
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only
        # call it for CUDA inputs.
        if on_cuda:
            torch.cuda.synchronize(input_tensor.device)

    if on_cuda:
        torch.cuda.empty_cache()
    sync()

    # warm-up
    for _ in range(num_warmup):
        attention.query_key_value(input_tensor, None)
    sync()

    # timed
    times_ms = []
    for _ in range(num_iterations):
        sync()
        start_time = time.perf_counter()
        attention.query_key_value(input_tensor, None)
        sync()
        times_ms.append((time.perf_counter() - start_time) * 1000)
    return sum(times_ms) / len(times_ms)
def test_qkv_batch_speedup(capsys):
    """Benchmark one batched QKV projection against per-sequence projections.

    Builds a ``FlashLlamaAttention`` layer from the mock weights. For every
    sequence length and batch size it prints three numbers:

    * the time of one ``query_key_value`` call over ``batch_size``
      back-to-back sequences;
    * the summed time of calling it once per sequence on the same rows;
    * their ratio (speedup > 1 means the batched call was faster).

    Nothing is asserted. pytest's ``capsys`` fixture is used only to let the
    table reach the terminal.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    hidden_size = 4096
    num_heads = 32
    config = SimpleNamespace(
        num_attention_heads=num_heads,
        hidden_size=hidden_size,
        rope_theta=10000.0,
        num_key_value_heads=num_heads,
        model_type="llama",
        quantize=None,
    )
    weights = MockWeights(
        filenames=["file1"],
        device=device,
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    attention = FlashLlamaAttention(
        index=0,
        prefix="prefix.self_attn",
        config=config,
        weights=weights,
    )
    # sequence with various odd lengths
    sequence_lengths = [3, 35, 365, 3_501, 11_111]
    # Start at 1: an empty batch has no per-sequence counterpart, so its row
    # could only ever report a meaningless 0.00 speedup.
    batch_sizes = range(1, 16)

    # One untimed call so one-off first-use costs (e.g. device/library
    # initialisation) are not charged to the first measured row.
    attention.query_key_value(torch.rand(1, hidden_size, device=device), None)

    with capsys.disabled():
        # allow printing
        for sequence_length in sequence_lengths:
            print(f"Testing with sequence length: {sequence_length}")
            print(
                f"{'Batch Size':<10} {'Batch Time (ms)':<20} {'Sequential Time (ms)':<25} {'Speedup':<10}"
            )
            print("-" * 65)
            for batch_size in batch_sizes:
                # batch: batch_size sequences stacked back to back along dim 0
                batch_input = torch.rand(
                    sequence_length * batch_size, hidden_size, device=device
                )
                batch_time = run_test(attention, batch_input)
                # sequential: the same rows, one sequence_length chunk at a time
                sequential_time = sum(
                    run_test(attention, chunk)
                    for chunk in batch_input.split(sequence_length)
                )
                # Guard against a zero reading from a coarse timer.
                speedup = (
                    sequential_time / batch_time if batch_time > 0 else float("inf")
                )
                print(
                    f"{batch_size:<10} {batch_time:<20.2f} {sequential_time:<25.2f} {speedup:<10.2f}"
                )