mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
small refactor
This commit is contained in:
parent
c5da6579dc
commit
c52e84fe10
@ -173,8 +173,9 @@ struct Args {
|
||||
/// for end users.
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
|
||||
/// The port to listen on.
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
|
||||
/// The name of the socket for gRPC communication between the webserver
|
||||
|
@ -113,7 +113,7 @@ impl Client {
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
inputs: "test".to_string().repeat(max_input_length as usize),
|
||||
inputs: "_test ".to_string().repeat(max_input_length as usize),
|
||||
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
|
@ -5,7 +5,7 @@ vllm:
|
||||
git clone https://github.com/OlivierDehaene/vllm.git
|
||||
|
||||
build-vllm: vllm
|
||||
cd vllm && git fetch && git checkout $(flash_att_commit)
|
||||
cd vllm && git fetch && git checkout $(vllm_commit)
|
||||
cd vllm && python setup.py build
|
||||
|
||||
install-vllm: build-vllm
|
||||
|
@ -19,7 +19,6 @@ class Cache:
|
||||
def delete(self, batch_id: int):
|
||||
batch = self.pop(batch_id)
|
||||
if batch is not None:
|
||||
batch.free()
|
||||
del batch
|
||||
|
||||
def clear(self):
|
||||
|
@ -63,24 +63,48 @@ class CacheManager:
|
||||
0, num_blocks * self.block_size, dtype=torch.int32
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def allocate(self, batch: "FlashCausalLMBatch"):
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
assert (
|
||||
len(free_block_indices) >= num_blocks
|
||||
), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
|
||||
len(free_block_indices) >= batch.blocks
|
||||
), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[: batch.blocks]
|
||||
block_indices = block_indices.flatten()
|
||||
|
||||
# Padded block tables
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(batch), batch.max_blocks), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Allocate paged attention blocks
|
||||
cumulative_blocks = 0
|
||||
slots = []
|
||||
block_tables = []
|
||||
for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
|
||||
# Get allocated blocks for this sequence
|
||||
allocated_blocks = block_indices[
|
||||
cumulative_blocks : cumulative_blocks + needed_blocks
|
||||
]
|
||||
# Get slots for the allocated blocks
|
||||
allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
|
||||
|
||||
slots.append(allocated_slots)
|
||||
block_tables.append(allocated_blocks.tolist())
|
||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||
|
||||
batch.needed_blocks_slots = None
|
||||
batch.block_tables = block_tables
|
||||
batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
|
||||
batch.slots = torch.concat(slots).to(batch.input_ids.device)
|
||||
|
||||
# Allocate the required number of blocks by setting the mask to 0
|
||||
block_indices = free_block_indices[:num_blocks]
|
||||
self.free_block_mask[block_indices] = 0
|
||||
|
||||
# Get slots for the allocated blocks
|
||||
slots = self.slots[block_indices].flatten()
|
||||
|
||||
return block_indices.flatten(), slots
|
||||
|
||||
def free(self, block_indices: Optional[List[int]]):
|
||||
if block_indices is not None:
|
||||
if block_indices is not None and block_indices:
|
||||
# Reset mask
|
||||
self.free_block_mask[block_indices] = 1
|
||||
|
||||
@ -448,12 +472,14 @@ class FlashCausalLMBatch(Batch):
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
|
||||
global CACHE_MANAGER
|
||||
block_indices_to_free = []
|
||||
# Iterate on all requests
|
||||
for i, r in enumerate(self.requests):
|
||||
# Filter requests that are not part of the new batch
|
||||
if r.id not in requests_idx_mapping.keys():
|
||||
block_indices_to_free.extend(self.block_tables[i])
|
||||
# Free blocks
|
||||
CACHE_MANAGER.free(self.block_tables[i])
|
||||
CACHE_MANAGER.free(block_indices_to_free)
|
||||
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
@ -643,7 +669,7 @@ class FlashCausalLMBatch(Batch):
|
||||
max_blocks=max_blocks,
|
||||
)
|
||||
|
||||
def free(self):
|
||||
def __del__(self):
|
||||
if self.block_tables is not None:
|
||||
global CACHE_MANAGER
|
||||
# Free blocks
|
||||
@ -703,10 +729,10 @@ class FlashCausalLM(Model):
|
||||
logger.exception(
|
||||
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
|
||||
f"prefill tokens. "
|
||||
f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
|
||||
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
|
||||
)
|
||||
raise e
|
||||
batch.free()
|
||||
del batch
|
||||
|
||||
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
|
||||
return self.tokenizer.decode(
|
||||
@ -749,33 +775,8 @@ class FlashCausalLM(Model):
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
|
||||
if batch.needed_blocks_slots:
|
||||
# Padded block tables
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(batch), batch.max_blocks), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Allocate paged attention blocks
|
||||
slots = []
|
||||
block_tables = []
|
||||
try:
|
||||
for i, (needed_blocks, needed_slots) in enumerate(
|
||||
batch.needed_blocks_slots
|
||||
):
|
||||
allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
|
||||
needed_blocks
|
||||
)
|
||||
slots.append(allocated_slots[:needed_slots])
|
||||
block_tables.append(allocated_blocks.tolist())
|
||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||
except Exception as e:
|
||||
for blocks in block_tables:
|
||||
CACHE_MANAGER.free(blocks)
|
||||
raise e
|
||||
|
||||
batch.needed_blocks_slots = None
|
||||
batch.block_tables = block_tables
|
||||
batch.block_tables_tensor = block_tables_tensor.to(self.device)
|
||||
batch.slots = torch.concat(slots).to(self.device)
|
||||
# Allocate blocks to this batch
|
||||
CACHE_MANAGER.allocate(batch)
|
||||
|
||||
out = self.forward(
|
||||
batch.input_ids,
|
||||
@ -990,7 +991,7 @@ class FlashCausalLM(Model):
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
|
||||
if stopped:
|
||||
batch.free()
|
||||
del batch
|
||||
# No need to return a batch if we know that all requests stopped
|
||||
return generations, None
|
||||
|
||||
|
@ -35,9 +35,6 @@ class Batch(ABC):
|
||||
def concatenate(cls, batches: List["Batch"]) -> "Batch":
|
||||
raise NotImplementedError
|
||||
|
||||
def free(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
@ -65,12 +65,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
try:
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
except Exception as e:
|
||||
batch.free()
|
||||
raise e
|
||||
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.PrefillResponse(
|
||||
@ -93,20 +88,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
raise ValueError("All batches are empty")
|
||||
|
||||
if len(batches) > 1:
|
||||
try:
|
||||
batch = self.model.batch_type.concatenate(batches)
|
||||
except Exception as e:
|
||||
[batch.free() for batch in batches]
|
||||
raise e
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
try:
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
except Exception as e:
|
||||
batch.free()
|
||||
raise e
|
||||
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.DecodeResponse(
|
||||
|
Loading…
Reference in New Issue
Block a user