Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
small refactor

Commit: c52e84fe10
Parent: c5da6579dc
@@ -173,8 +173,9 @@ struct Args {
     /// for end users.
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
-    #[clap(default_value = "3000", long, short, env)]
+
     /// The port to listen on.
+    #[clap(default_value = "3000", long, short, env)]
     port: u16,

     /// The name of the socket for gRPC communication between the webserver
@@ -113,7 +113,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
-                inputs: "test".to_string().repeat(max_input_length as usize),
+                inputs: "_test ".to_string().repeat(max_input_length as usize),
                 truncate: min(max_input_length, max_prefill_tokens - n_tokens),
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
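The only functional change in this hunk is the warmup input: "test" repeated max_input_length times becomes "_test " repeated the same number of times. A plausible reading (the commit message does not say) is that the unseparated string merges into far fewer tokens than repetitions, so the warmup request might not reach the intended length after server-side truncation, while the space-separated form yields a more predictable, roughly one-token-per-repeat count. The snippet below is only a rough way to compare the two strings; it assumes the transformers package and the public "gpt2" tokenizer, which is not necessarily the tokenizer the server actually runs with.

# Rough, hedged comparison of how the two warmup strings tokenize.
# Assumes `transformers` is installed and the public "gpt2" tokenizer is
# reachable; the served model may tokenize differently.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
max_input_length = 64  # stand-in value for the CLI argument

old_input = "test" * max_input_length      # previous warmup string, no separators
new_input = "_test " * max_input_length    # new warmup string, one word per repeat

print("old:", len(tokenizer(old_input)["input_ids"]), "tokens")
print("new:", len(tokenizer(new_input)["input_ids"]), "tokens")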
@@ -5,7 +5,7 @@ vllm:
	git clone https://github.com/OlivierDehaene/vllm.git

build-vllm: vllm
-	cd vllm && git fetch && git checkout $(flash_att_commit)
+	cd vllm && git fetch && git checkout $(vllm_commit)
	cd vllm && python setup.py build

install-vllm: build-vllm
@@ -19,7 +19,6 @@ class Cache:
     def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
-            batch.free()
             del batch

     def clear(self):
@@ -63,24 +63,48 @@ class CacheManager:
             0, num_blocks * self.block_size, dtype=torch.int32
         ).view(num_blocks, self.block_size)

-    def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def allocate(self, batch: "FlashCausalLMBatch"):
         # Get free blocks indices by finding values in mask that are not set to 0
         free_block_indices = self.free_block_mask.nonzero()
         assert (
-            len(free_block_indices) >= num_blocks
-        ), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
+            len(free_block_indices) >= batch.blocks
+        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"

+        # Slice by the number of required blocks
+        block_indices = free_block_indices[: batch.blocks]
+        block_indices = block_indices.flatten()
+
+        # Padded block tables
+        block_tables_tensor = torch.zeros(
+            (len(batch), batch.max_blocks), dtype=torch.int32
+        )
+
+        # Allocate paged attention blocks
+        cumulative_blocks = 0
+        slots = []
+        block_tables = []
+        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
+            # Get allocated blocks for this sequence
+            allocated_blocks = block_indices[
+                cumulative_blocks : cumulative_blocks + needed_blocks
+            ]
+            # Get slots for the allocated blocks
+            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
+
+            slots.append(allocated_slots)
+            block_tables.append(allocated_blocks.tolist())
+            block_tables_tensor[i, :needed_blocks] = allocated_blocks
+
+        batch.needed_blocks_slots = None
+        batch.block_tables = block_tables
+        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
+        batch.slots = torch.concat(slots).to(batch.input_ids.device)
+
         # Allocate the required number of blocks by setting the mask to 0
-        block_indices = free_block_indices[:num_blocks]
         self.free_block_mask[block_indices] = 0

-        # Get slots for the allocated blocks
-        slots = self.slots[block_indices].flatten()
-
-        return block_indices.flatten(), slots
-
     def free(self, block_indices: Optional[List[int]]):
-        if block_indices is not None:
+        if block_indices is not None and block_indices:
             # Reset mask
             self.free_block_mask[block_indices] = 1
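The core of the refactor is this hunk: CacheManager.allocate now takes the whole FlashCausalLMBatch, slices the free blocks once for the entire batch, builds the padded block tables and flat slot list per sequence, and writes them back onto the batch, instead of returning (block_indices, slots) one sequence at a time. The sketch below is a minimal, self-contained reconstruction of that scheme, not the project's code: SimpleBatch and SimpleCacheManager are hypothetical stand-ins, the demo values are invented, and the explicit `cumulative_blocks += needed_blocks` advance is an assumption added so the example hands each sequence distinct blocks (the hunk above initializes cumulative_blocks but does not show an increment).

import torch
from typing import List, Optional, Tuple


class SimpleBatch:
    """Hypothetical stand-in for FlashCausalLMBatch, holding only the fields
    the allocator reads and writes."""

    def __init__(self, needed_blocks_slots: List[Tuple[int, int]]):
        # One (needed_blocks, needed_slots) pair per sequence in the batch.
        self.needed_blocks_slots = needed_blocks_slots
        self.blocks = sum(b for b, _ in needed_blocks_slots)
        self.max_blocks = max(b for b, _ in needed_blocks_slots)
        self.block_tables: Optional[List[List[int]]] = None
        self.block_tables_tensor: Optional[torch.Tensor] = None
        self.slots: Optional[torch.Tensor] = None

    def __len__(self) -> int:
        return len(self.needed_blocks_slots)


class SimpleCacheManager:
    """Minimal paged-attention block allocator in the spirit of the diff."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        # 1 = free, 0 = allocated.
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32)
        # Flat slot ids, one row of `block_size` slots per block.
        self.slots = torch.arange(0, num_blocks * block_size, dtype=torch.int32).view(
            num_blocks, block_size
        )

    def allocate(self, batch: SimpleBatch):
        free_block_indices = self.free_block_mask.nonzero()
        assert len(free_block_indices) >= batch.blocks, (
            f"Out of available cache blocks: asked {batch.blocks}, "
            f"only {len(free_block_indices)} free blocks"
        )

        # Slice the free blocks once for the whole batch.
        block_indices = free_block_indices[: batch.blocks].flatten()

        # Padded per-sequence block tables.
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        cumulative_blocks = 0
        slots, block_tables = [], []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Blocks for this sequence, then the slots they contain.
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            # Assumed: advance the offset so the next sequence gets fresh blocks.
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor
        batch.slots = torch.concat(slots)

        # Mark the blocks as taken.
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
        # Same guard as the diff: ignore None and empty lists.
        if block_indices is not None and block_indices:
            self.free_block_mask[block_indices] = 1


if __name__ == "__main__":
    manager = SimpleCacheManager(num_blocks=8, block_size=4)
    batch = SimpleBatch([(2, 7), (1, 3)])  # two sequences
    manager.allocate(batch)
    print(batch.block_tables)    # e.g. [[0, 1], [2]]
    print(batch.slots.tolist())  # 7 slots for seq 0, then 3 for seq 1
    manager.free([b for table in batch.block_tables for b in table])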
@@ -448,12 +472,14 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))

         global CACHE_MANAGER
+        block_indices_to_free = []
         # Iterate on all requests
         for i, r in enumerate(self.requests):
             # Filter requests that are not part of the new batch
             if r.id not in requests_idx_mapping.keys():
-                # Free blocks
-                CACHE_MANAGER.free(self.block_tables[i])
+                block_indices_to_free.extend(self.block_tables[i])
+        # Free blocks
+        CACHE_MANAGER.free(block_indices_to_free)

         # Index into tensors
         input_ids = self.input_ids[indices]
@@ -643,7 +669,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
         )

-    def free(self):
+    def __del__(self):
         if self.block_tables is not None:
             global CACHE_MANAGER
             # Free blocks
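Several hunks (Cache.delete, generate_token, and this one) replace explicit batch.free() calls with `del batch` and rename free to __del__, so cache blocks are returned whenever a batch object is garbage-collected rather than only on the code paths that remembered to call free(). Below is a minimal sketch of that pattern using hypothetical ToyCacheManager/ToyBatch names, not the project's classes.

# Minimal sketch (assumed names) of tying block release to object destruction,
# in the spirit of turning free() into __del__.
from typing import List, Optional


class ToyCacheManager:
    def __init__(self, num_blocks: int):
        self.free_blocks = set(range(num_blocks))

    def allocate(self, n: int) -> List[int]:
        return [self.free_blocks.pop() for _ in range(n)]

    def free(self, block_indices: Optional[List[int]]):
        # Mirrors the guard added in the diff: ignore None and empty lists.
        if block_indices is not None and block_indices:
            self.free_blocks.update(block_indices)


CACHE_MANAGER = ToyCacheManager(num_blocks=8)


class ToyBatch:
    def __init__(self, n_blocks: int):
        self.block_tables: Optional[List[int]] = CACHE_MANAGER.allocate(n_blocks)

    def __del__(self):
        # Blocks are returned when the batch is garbage-collected, so callers
        # can simply drop their reference (`del batch`) instead of calling an
        # explicit free() on every exit path.
        if self.block_tables is not None:
            CACHE_MANAGER.free(self.block_tables)
            self.block_tables = None


if __name__ == "__main__":
    batch = ToyBatch(n_blocks=3)
    print(len(CACHE_MANAGER.free_blocks))  # 5 blocks left
    del batch                              # __del__ releases the 3 blocks
    print(len(CACHE_MANAGER.free_blocks))  # back to 8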
@@ -703,10 +729,10 @@ class FlashCausalLM(Model):
             logger.exception(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
-                f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
+                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
             )
             raise e
-        batch.free()
+        del batch

     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -749,33 +775,8 @@ class FlashCausalLM(Model):
         prefill_logprobs = batch.prefill_next_token_indices is not None

         if batch.needed_blocks_slots:
-            # Padded block tables
-            block_tables_tensor = torch.zeros(
-                (len(batch), batch.max_blocks), dtype=torch.int32
-            )
-
-            # Allocate paged attention blocks
-            slots = []
-            block_tables = []
-            try:
-                for i, (needed_blocks, needed_slots) in enumerate(
-                    batch.needed_blocks_slots
-                ):
-                    allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
-                        needed_blocks
-                    )
-                    slots.append(allocated_slots[:needed_slots])
-                    block_tables.append(allocated_blocks.tolist())
-                    block_tables_tensor[i, :needed_blocks] = allocated_blocks
-            except Exception as e:
-                for blocks in block_tables:
-                    CACHE_MANAGER.free(blocks)
-                raise e
-
-            batch.needed_blocks_slots = None
-            batch.block_tables = block_tables
-            batch.block_tables_tensor = block_tables_tensor.to(self.device)
-            batch.slots = torch.concat(slots).to(self.device)
+            # Allocate blocks to this batch
+            CACHE_MANAGER.allocate(batch)

         out = self.forward(
             batch.input_ids,
@@ -990,7 +991,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids

         if stopped:
-            batch.free()
+            del batch
             # No need to return a batch if we know that all requests stopped
             return generations, None
@@ -35,9 +35,6 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError

-    def free(self):
-        pass
-
     @abstractmethod
     def __len__(self):
         raise NotImplementedError
@@ -65,12 +65,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

-        try:
-            generations, next_batch = self.model.generate_token(batch)
-        except Exception as e:
-            batch.free()
-            raise e
-
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
@@ -93,20 +88,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
-            try:
-                batch = self.model.batch_type.concatenate(batches)
-            except Exception as e:
-                [batch.free() for batch in batches]
-                raise e
+            batch = self.model.batch_type.concatenate(batches)
         else:
             batch = batches[0]

-        try:
-            generations, next_batch = self.model.generate_token(batch)
-        except Exception as e:
-            batch.free()
-            raise e
-
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
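With cleanup moved into __del__, the Prefill and Decode handlers no longer need try/except blocks whose only job was to call batch.free() before re-raising: the exception propagates to the gRPC layer and the batch's blocks are reclaimed once the last reference to it is dropped. A small, hedged illustration of that behaviour follows; the names below are invented, only the idea is taken from the diff.

# Hedged sketch (assumed names) of why the explicit try/except + free() around
# generate_token could be removed: once cleanup lives in __del__, a failed call
# releases its blocks as soon as the batch object is no longer referenced.
class Resource:
    def __init__(self, log):
        self.log = log

    def __del__(self):
        self.log.append("blocks released")


def generate_token(batch):
    raise RuntimeError("boom")


def prefill(log):
    batch = Resource(log)
    # No try/except needed just for cleanup: the exception propagates, and
    # `batch` is reclaimed once the frame (and the traceback holding it) dies.
    return generate_token(batch)


if __name__ == "__main__":
    log = []
    try:
        prefill(log)
    except RuntimeError:
        pass
    print(log)  # ['blocks released'] -- cleanup happened without an explicit free()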