small refactor

OlivierDehaene 2023-06-30 16:32:23 +02:00
parent c5da6579dc
commit c52e84fe10
7 changed files with 51 additions and 67 deletions

View File

@@ -173,8 +173,9 @@ struct Args {
     /// for end users.
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
-    #[clap(default_value = "3000", long, short, env)]
+
     /// The port to listen on.
+    #[clap(default_value = "3000", long, short, env)]
     port: u16,

     /// The name of the socket for gRPC communication between the webserver

View File

@@ -113,7 +113,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
-                inputs: "test".to_string().repeat(max_input_length as usize),
+                inputs: "_test ".to_string().repeat(max_input_length as usize),
                 truncate: min(max_input_length, max_prefill_tokens - n_tokens),
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {

View File

@@ -5,7 +5,7 @@ vllm:
 	git clone https://github.com/OlivierDehaene/vllm.git

 build-vllm: vllm
-	cd vllm && git fetch && git checkout $(flash_att_commit)
+	cd vllm && git fetch && git checkout $(vllm_commit)
 	cd vllm && python setup.py build

 install-vllm: build-vllm

View File

@@ -19,7 +19,6 @@ class Cache:
     def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
-            batch.free()
             del batch

     def clear(self):

View File

@@ -63,24 +63,48 @@ class CacheManager:
             0, num_blocks * self.block_size, dtype=torch.int32
         ).view(num_blocks, self.block_size)

-    def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def allocate(self, batch: "FlashCausalLMBatch"):
         # Get free blocks indices by finding values in mask that are not set to 0
         free_block_indices = self.free_block_mask.nonzero()
         assert (
-            len(free_block_indices) >= num_blocks
-        ), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
+            len(free_block_indices) >= batch.blocks
+        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
+
+        # Slice by the number of required blocks
+        block_indices = free_block_indices[: batch.blocks]
+        block_indices = block_indices.flatten()
+
+        # Padded block tables
+        block_tables_tensor = torch.zeros(
+            (len(batch), batch.max_blocks), dtype=torch.int32
+        )
+
+        # Allocate paged attention blocks
+        cumulative_blocks = 0
+        slots = []
+        block_tables = []
+        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
+            # Get allocated blocks for this sequence
+            allocated_blocks = block_indices[
+                cumulative_blocks : cumulative_blocks + needed_blocks
+            ]
+            # Get slots for the allocated blocks
+            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
+
+            slots.append(allocated_slots)
+            block_tables.append(allocated_blocks.tolist())
+            block_tables_tensor[i, :needed_blocks] = allocated_blocks
+            cumulative_blocks += needed_blocks
+
+        batch.needed_blocks_slots = None
+        batch.block_tables = block_tables
+        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
+        batch.slots = torch.concat(slots).to(batch.input_ids.device)

         # Allocate the required number of blocks by setting the mask to 0
-        block_indices = free_block_indices[:num_blocks]
         self.free_block_mask[block_indices] = 0

-        # Get slots for the allocated blocks
-        slots = self.slots[block_indices].flatten()
-
-        return block_indices.flatten(), slots
-
     def free(self, block_indices: Optional[List[int]]):
-        if block_indices is not None:
+        if block_indices is not None and block_indices:
             # Reset mask
             self.free_block_mask[block_indices] = 1
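
With this change, `CacheManager.allocate` receives the whole `FlashCausalLMBatch` and writes `block_tables`, `block_tables_tensor` and `slots` back onto it, so callers only pass the batch (see the `generate_token` hunk further down). Below is a rough, self-contained sketch of the same block/slot bookkeeping; `ToyCacheManager` is a hypothetical stand-in for the real class, not the repository's implementation.

# Simplified illustration of the block/slot bookkeeping (ToyCacheManager is hypothetical).
import torch
from typing import List, Optional, Tuple


class ToyCacheManager:
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        # Slot ids laid out as [num_blocks, block_size]
        self.slots = torch.arange(0, num_blocks * block_size, dtype=torch.int32).view(
            num_blocks, block_size
        )
        # 1 = free, 0 = allocated
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32)

    def allocate(
        self, needed_blocks_slots: List[Tuple[int, int]]
    ) -> Tuple[List[List[int]], torch.Tensor]:
        free = self.free_block_mask.nonzero().flatten()
        total = sum(blocks for blocks, _ in needed_blocks_slots)
        assert len(free) >= total, "out of cache blocks"
        block_indices = free[:total]

        block_tables, slots, offset = [], [], 0
        for needed_blocks, needed_slots in needed_blocks_slots:
            blocks = block_indices[offset : offset + needed_blocks]
            # Take exactly the slots this sequence needs from its blocks
            slots.append(self.slots[blocks].flatten()[:needed_slots])
            block_tables.append(blocks.tolist())
            offset += needed_blocks

        # Mark the chosen blocks as used
        self.free_block_mask[block_indices] = 0
        return block_tables, torch.concat(slots)

    def free(self, block_indices: Optional[List[int]]):
        if block_indices:
            self.free_block_mask[block_indices] = 1


# Example: two sequences needing (2 blocks, 7 slots) and (1 block, 3 slots)
tables, slot_ids = ToyCacheManager(num_blocks=8, block_size=4).allocate([(2, 7), (1, 3)])
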
@@ -448,12 +472,14 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))

         global CACHE_MANAGER
+        block_indices_to_free = []
         # Iterate on all requests
         for i, r in enumerate(self.requests):
             # Filter requests that are not part of the new batch
             if r.id not in requests_idx_mapping.keys():
-                # Free blocks
-                CACHE_MANAGER.free(self.block_tables[i])
+                block_indices_to_free.extend(self.block_tables[i])
+        # Free blocks
+        CACHE_MANAGER.free(block_indices_to_free)

         # Index into tensors
         input_ids = self.input_ids[indices]
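
The filter hunk above now gathers the block indices of every evicted request and frees them with a single call instead of one call per request. A tiny illustrative sketch of that aggregation; the `free_mask` tensor, `block_tables` list and `kept` set are made-up example data.

import torch

free_mask = torch.zeros(16, dtype=torch.int32)
block_tables = [[0, 1], [4, 5, 6], [9]]
kept = {1}  # indices of requests kept in the filtered batch

# Collect the blocks of every dropped request, then free them in one shot
to_free = []
for i, table in enumerate(block_tables):
    if i not in kept:
        to_free.extend(table)
free_mask[to_free] = 1  # single index assignment instead of one per request
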
@@ -643,7 +669,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
         )

-    def free(self):
+    def __del__(self):
         if self.block_tables is not None:
             global CACHE_MANAGER
             # Free blocks
@@ -703,10 +729,10 @@ class FlashCausalLM(Model):
             logger.exception(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
-                f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
+                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
             )
             raise e
-        batch.free()
+        del batch

     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -749,33 +775,8 @@ class FlashCausalLM(Model):
         prefill_logprobs = batch.prefill_next_token_indices is not None

         if batch.needed_blocks_slots:
-            # Padded block tables
-            block_tables_tensor = torch.zeros(
-                (len(batch), batch.max_blocks), dtype=torch.int32
-            )
-            # Allocate paged attention blocks
-            slots = []
-            block_tables = []
-            try:
-                for i, (needed_blocks, needed_slots) in enumerate(
-                    batch.needed_blocks_slots
-                ):
-                    allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
-                        needed_blocks
-                    )
-                    slots.append(allocated_slots[:needed_slots])
-                    block_tables.append(allocated_blocks.tolist())
-                    block_tables_tensor[i, :needed_blocks] = allocated_blocks
-            except Exception as e:
-                for blocks in block_tables:
-                    CACHE_MANAGER.free(blocks)
-                raise e
-
-            batch.needed_blocks_slots = None
-            batch.block_tables = block_tables
-            batch.block_tables_tensor = block_tables_tensor.to(self.device)
-            batch.slots = torch.concat(slots).to(self.device)
+            # Allocate blocks to this batch
+            CACHE_MANAGER.allocate(batch)

         out = self.forward(
             batch.input_ids,
@@ -990,7 +991,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids

         if stopped:
-            batch.free()
+            del batch
             # No need to return a batch if we know that all requests stopped
             return generations, None
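Freeing is now tied to the batch's lifetime: `FlashCausalLMBatch.__del__` returns its blocks to the cache manager, so call sites simply `del batch` (or let it go out of scope) instead of calling an explicit `free()`, which also removes the try/except cleanup in the server hunks below. A minimal sketch of that `__del__` pattern; `ToyBatch` and `release_blocks` are hypothetical stand-ins for the real batch class and `CACHE_MANAGER.free`.

# Minimal sketch of __del__-based cleanup (ToyBatch and release_blocks are illustrative only).
released = []


def release_blocks(block_indices):
    # Stand-in for returning blocks to the cache manager
    released.extend(block_indices)


class ToyBatch:
    def __init__(self, block_tables):
        self.block_tables = block_tables

    def __del__(self):
        # Blocks go back to the pool when the batch object is collected
        if self.block_tables is not None:
            for table in self.block_tables:
                release_blocks(table)
            self.block_tables = None


batch = ToyBatch([[0, 1], [2, 3]])
del batch  # drops the last reference; __del__ runs and frees the blocks
assert released == [0, 1, 2, 3]

The trade-off is that cleanup now depends on the object actually being collected: in CPython, `del` on the last reference triggers `__del__` immediately, which is what the refactor relies on.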

View File

@@ -35,9 +35,6 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError

-    def free(self):
-        pass
-
     @abstractmethod
     def __len__(self):
         raise NotImplementedError

View File

@@ -65,12 +65,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

-        try:
-            generations, next_batch = self.model.generate_token(batch)
-        except Exception as e:
-            batch.free()
-            raise e
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
@@ -93,20 +88,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
-            try:
-                batch = self.model.batch_type.concatenate(batches)
-            except Exception as e:
-                [batch.free() for batch in batches]
-                raise e
+            batch = self.model.batch_type.concatenate(batches)
         else:
             batch = batches[0]

-        try:
-            generations, next_batch = self.model.generate_token(batch)
-        except Exception as e:
-            batch.free()
-            raise e
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(