# text-generation-inference/server/tests/utils/test_qkv_batch_speedup.py

import torch
from types import SimpleNamespace
from typing import List, Union, Dict, Optional
import time
from pathlib import Path
from text_generation_server.utils.weights import Weights
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaAttention,
)
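

# Fake "file system" standing in for safetensors files on disk: maps each
# mock filename to the tensors it contains.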
dummy_file_system = {
    "file1": {
        "prefix.self_attn.q_proj.weight": torch.rand(
            4096, 4096
        ),  # (hidden_size * num_heads, hidden_size)
        "prefix.self_attn.k_proj.weight": torch.rand(4096, 4096),
        "prefix.self_attn.v_proj.weight": torch.rand(4096, 4096),
        "prefix.self_attn.o_proj.weight": torch.rand(
            4096, 4096
        ),  # (hidden_size, hidden_size * num_heads)
        "prefix.self_attn.q_proj.bias": torch.rand(4096),
        "prefix.self_attn.k_proj.bias": torch.rand(4096),
        "prefix.self_attn.v_proj.bias": torch.rand(4096),
        "prefix.self_attn.o_proj.bias": torch.rand(4096),
    },
}


class MockSlice:
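    """Minimal stand-in for the slice object returned by safetensors."""
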
    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return self.tensor.shape

    def __getitem__(self, idx):
        return self.tensor[idx]


def mock_get_slice(tensor_name, filename):
    tensor = dummy_file_system[filename][tensor_name]
    return MockSlice(tensor)
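

# Mimics the per-file handle that Weights obtains from safetensors; only
# get_slice is exercised by the code paths this test touches.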
def mock_handle(filename, device, dtype):
    return SimpleNamespace(
        get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)
    )


class MockSafeOpen:
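    """Context-manager mock mimicking safetensors' safe_open."""
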
    def __init__(self, filename, framework, dummy_fs):
        self.filename = filename
        self.framework = framework
        self.dummy_fs = dummy_fs

    def keys(self):
        return list(self.dummy_fs[self.filename].keys())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class MockWeights(Weights):
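    """Weights subclass backed by the in-memory dummy file system instead of
    real safetensors files on disk."""
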
    def __init__(
        self,
        filenames: List[Union[Path, str]],
        device,
        dtype,
        process_group,
        dummy_fs,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        self.dummy_fs = dummy_fs
        for filename in filenames:
            with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename: Union[Path, str]):
        if filename in self._handles:
            return self._handles[filename]
        else:
            handle = mock_handle(filename, self.device, self.dtype)
            self._handles[filename] = handle
            return handle

    def get_shape(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).tensor


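# Single-process stand-in for a torch.distributed process group (rank 0 of 1).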
dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)


def run_test(attention, input_tensor, num_warmup=0, num_iterations=1):
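    """Run attention.query_key_value on input_tensor and return the mean
    wall-clock time per call in milliseconds."""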
    # Only touch the CUDA APIs when the input actually lives on a GPU;
    # torch.cuda.synchronize() raises on CUDA-less machines, which would
    # defeat the CPU fallback used by the test below.
    on_gpu = input_tensor.is_cuda
    if on_gpu:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    # warm-up
    for _ in range(num_warmup):
        attention.query_key_value(input_tensor, None)
    if on_gpu:
        torch.cuda.synchronize()
    # timed
    times = []
    for _ in range(num_iterations):
        if on_gpu:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        attention.query_key_value(input_tensor, None)
        if on_gpu:
            torch.cuda.synchronize()
        end_time = time.perf_counter()
        times.append((end_time - start_time) * 1000)  # milliseconds
    return sum(times) / len(times)


def test_qkv_batch_speedup(capsys):
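    """Benchmark one fused QKV projection over a whole batch against the same
    projection applied sequentially per sequence, and print the speedup."""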
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    hidden_size = 4096
    num_heads = 32
    config = SimpleNamespace(
        num_attention_heads=num_heads,
        hidden_size=hidden_size,
        rope_theta=10000.0,
        num_key_value_heads=num_heads,
        model_type="llama",
        quantize=None,
    )
    weights = MockWeights(
        filenames=["file1"],
        device=device,
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    attention = FlashLlamaAttention(
        index=0,
        prefix="prefix.self_attn",
        config=config,
        weights=weights,
    )
    # sequences with various odd lengths
    sequence_lengths = [3, 35, 365, 3_501, 11_111]
    # batch sizes 1..15; a batch size of 0 would make the speedup ratio meaningless
    batch_sizes = list(range(1, 16))
    with capsys.disabled():
        # allow printing
        for sequence_length in sequence_lengths:
            print(f"Testing with sequence length: {sequence_length}")
            print(
                f"{'Batch Size':<10} {'Batch Time (ms)':<20} {'Sequential Time (ms)':<25} {'Speedup':<10}"
            )
            print("-" * 65)
            for batch_size in batch_sizes:
                # batch: one fused call over all sequences at once
                batch_input = torch.rand(
                    sequence_length * batch_size, hidden_size, device=device
                )
                batch_time = run_test(attention, batch_input)
                # sequential: one call per sequence
                sequential_time = 0
                for index in range(batch_size):
                    single_input = batch_input[
                        index * sequence_length : (index + 1) * sequence_length
                    ]
                    sequential_time += run_test(attention, single_input)
                speedup = sequential_time / batch_time
                print(
                    f"{batch_size:<10} {batch_time:<20.2f} {sequential_time:<25.2f} {speedup:<10.2f}"
                )
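

# A minimal sketch of how one might run this benchmark directly, assuming
# pytest is installed and the command is issued from the server/ directory:
#
#   pytest tests/utils/test_qkv_batch_speedup.py::test_qkv_batch_speedup
#
# The timing tables print even under pytest's output capture because the test
# wraps its prints in capsys.disabled().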