diff --git a/server/tests/utils/test_qkv_batch_speedup.py b/server/tests/utils/test_qkv_batch_speedup.py
new file mode 100644
index 00000000..d20212d4
--- /dev/null
+++ b/server/tests/utils/test_qkv_batch_speedup.py
@@ -0,0 +1,203 @@
+import time
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Dict, List, Optional, Union
+
+import torch
+
+from text_generation_server.utils.weights import Weights
+from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+    FlashLlamaAttention,
+)
+
+
+# In-memory stand-in for a set of safetensors files: filename -> tensor name -> tensor.
+dummy_file_system = {
+    "file1": {
+        "prefix.self_attn.q_proj.weight": torch.rand(
+            4096, 4096
+        ),  # (num_heads * head_dim, hidden_size)
+        "prefix.self_attn.k_proj.weight": torch.rand(4096, 4096),
+        "prefix.self_attn.v_proj.weight": torch.rand(4096, 4096),
+        "prefix.self_attn.o_proj.weight": torch.rand(
+            4096, 4096
+        ),  # (hidden_size, num_heads * head_dim)
+        "prefix.self_attn.q_proj.bias": torch.rand(4096),
+        "prefix.self_attn.k_proj.bias": torch.rand(4096),
+        "prefix.self_attn.v_proj.bias": torch.rand(4096),
+        "prefix.self_attn.o_proj.bias": torch.rand(4096),
+    },
+}
+
+
+class MockSlice:
+    # Minimal stand-in for a safetensors slice backed by an in-memory tensor.
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def get_shape(self):
+        return self.tensor.shape
+
+    def __getitem__(self, idx):
+        return self.tensor[idx]
+
+
+def mock_get_slice(tensor_name, filename):
+    tensor = dummy_file_system[filename][tensor_name]
+    return MockSlice(tensor)
+
+
+def mock_handle(filename, device, dtype):
+    return SimpleNamespace(
+        get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)
+    )
+
+
+class MockSafeOpen:
+    # Mimics the safetensors safe_open context manager over the dummy file system.
+    def __init__(self, filename, framework, dummy_fs):
+        self.filename = filename
+        self.framework = framework
+        self.dummy_fs = dummy_fs
+
+    def keys(self):
+        return list(self.dummy_fs[self.filename].keys())
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+class MockWeights(Weights):
+    def __init__(
+        self,
+        filenames: List[Union[Path, str]],
+        device,
+        dtype,
+        process_group,
+        dummy_fs,
+        aliases: Optional[Dict[str, List[str]]] = None,
+        prefix: Optional[str] = None,
+    ):
+        routing = {}
+        self.dummy_fs = dummy_fs
+        for filename in filenames:
+            with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
+                for k in f.keys():
+                    if k in routing:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename
+        if aliases is None:
+            aliases = {}
+        self.aliases = aliases
+        self.routing = routing
+        self.device = device
+        self.dtype = dtype
+        self.process_group = process_group
+        self.prefix = prefix
+        self._handles = {}
+
+    def _get_handle(self, filename: Union[Path, str]):
+        if filename not in self._handles:
+            self._handles[filename] = mock_handle(filename, self.device, self.dtype)
+        return self._handles[filename]
+
+    def get_shape(self, tensor_name: str):
+        filename, _ = self.get_filename(tensor_name)
+        handle = self._get_handle(filename)
+        return handle.get_slice(tensor_name).get_shape()
+
+    def get_tensor(self, tensor_name: str):
+        filename, _ = self.get_filename(tensor_name)
+        handle = self._get_handle(filename)
+        return handle.get_slice(tensor_name).tensor
+
+
+dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)
+
+
+def run_test(attention, input_tensor, num_warmup=0, num_iterations=1):
+    # Guard CUDA-only calls so the benchmark can also run on CPU.
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    # warm-up
+    for _ in range(num_warmup):
+        attention.query_key_value(input_tensor, None)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    # timed
+    times = []
+    for _ in range(num_iterations):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        start_time = time.perf_counter()
+        attention.query_key_value(input_tensor, None)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        end_time = time.perf_counter()
+        times.append((end_time - start_time) * 1000)  # milliseconds
+
+    return sum(times) / len(times)
+
+
+def test_qkv_batch_speedup(capsys):
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    hidden_size = 4096
+    num_heads = 32
+
+    config = SimpleNamespace(
+        num_attention_heads=num_heads,
+        hidden_size=hidden_size,
+        rope_theta=10000.0,
+        num_key_value_heads=num_heads,
+        model_type="llama",
+        quantize=None,
+    )
+
+    weights = MockWeights(
+        filenames=["file1"],
+        device=device,
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    attention = FlashLlamaAttention(
+        index=0,
+        prefix="prefix.self_attn",
+        config=config,
+        weights=weights,
+    )
+
+    # deliberately odd sequence lengths
+    sequence_lengths = [3, 35, 365, 3_501, 11_111]
+    # start at 1: a batch size of 0 would make the speedup ratio meaningless
+    batch_sizes = [*range(1, 16)]
+
+    with capsys.disabled():  # allow printing from within pytest
+        for sequence_length in sequence_lengths:
+            print(f"Testing with sequence length: {sequence_length}")
+            print(
+                f"{'Batch Size':<10} {'Batch Time (ms)':<20} {'Sequential Time (ms)':<25} {'Speedup':<10}"
+            )
+            print("-" * 65)
+
+            for batch_size in batch_sizes:
+                # one batched forward pass over all sequences at once
+                batch_input = torch.rand(
+                    sequence_length * batch_size, hidden_size, device=device
+                )
+                batch_time = run_test(attention, batch_input)
+
+                # the same sequences, one forward pass each
+                sequential_time = 0
+                for index in range(batch_size):
+                    single_input = batch_input[
+                        index * sequence_length : (index + 1) * sequence_length
+                    ]
+                    sequential_time += run_test(attention, single_input)
+
+                speedup = sequential_time / batch_time
+                print(
+                    f"{batch_size:<10} {batch_time:<20.2f} {sequential_time:<25.2f} {speedup:<10.2f}"
+                )
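A note on the timing harness in `run_test`: CUDA kernel launches are asynchronous, so each forward pass is bracketed with `torch.cuda.synchronize()` before reading `time.perf_counter()`; without that, the timer would capture only the host-side launch overhead rather than the kernel runtime. Below is a minimal, self-contained sketch of the same pattern, using a plain `torch.nn.Linear` as an illustrative stand-in for the attention projection (the module, sizes, and `time_forward_ms` name are placeholders, not part of the diff above):

import time

import torch


def time_forward_ms(module, x, iterations=10):
    # Average wall-clock time of module(x) in milliseconds. Synchronizing
    # on both sides of the call ensures the measurement covers the kernel
    # execution itself, not just the asynchronous launch.
    times = []
    for _ in range(iterations):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        module(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000)
    return sum(times) / len(times)


if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Placeholder module standing in for the q/k/v projection under test.
    linear = torch.nn.Linear(4096, 4096).to(device)
    x = torch.rand(128, 4096, device=device)
    print(f"forward: {time_forward_ms(linear, x):.2f} ms")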