Trying to fix non-chunking targets.

This commit is contained in:
Nicolas Patry 2024-10-23 15:02:52 +08:00
parent a31db04709
commit 0a01dde986
No known key found for this signature in database
GPG Key ID: 788A1EA699458B2F

View File

@ -1398,22 +1398,32 @@ class FlashCausalLM(Model):
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
if max_total_tokens is None: if max_total_tokens is None:
model_max_length = self.tokenizer.model_max_length if get_support_chunking():
free_memory = get_free_memory(self.device, MEMORY_FRACTION) model_max_length = self.tokenizer.model_max_length
spare_blocks = ( free_memory = get_free_memory(self.device, MEMORY_FRACTION)
# Leave 5% for some wiggle room spare_blocks = (
int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) # Leave 5% for some wiggle room
+ batch.num_blocks int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
) + batch.num_blocks
spare_blocks = small_power_of_2(spare_blocks) )
spare_blocks = small_power_of_2(spare_blocks)
available_blocks = min(model_max_length, spare_blocks) available_blocks = min(model_max_length, spare_blocks)
batch.num_blocks = available_blocks batch.num_blocks = available_blocks
batch.max_blocks = available_blocks batch.max_blocks = available_blocks
max_input_tokens = ( max_input_tokens = (
available_blocks - 1 if max_input_tokens is None else max_input_tokens available_blocks - 1
) if max_input_tokens is None
max_total_tokens = available_blocks else max_input_tokens
)
max_total_tokens = available_blocks
else:
max_total_tokens = batch.num_blocks
max_input_tokens = (
batch.num_blocks - 1
if max_input_tokens is None
else max_input_tokens
)
try: try:
self.init_kv_cache( self.init_kv_cache(