Add warmup for shift operation (#59) (#86)

Author: Karol Damaszke, 2024-02-29 09:19:28 +01:00 (committed by GitHub)
parent 022ce1eaaf
commit 3831f1bed5

@@ -341,7 +341,6 @@ class CausalLMBatch(Batch):
         # move only input_ids, attention_mask and position_ids
         dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
-
 
     @classmethod
     def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
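
For orientation, a minimal sketch of the semantics the `grouped_move` call appears to have. The real helper is defined elsewhere in this file, so the body below is an assumption (gather along the source dim, scatter along the destination dim), not the repo's implementation:

    import torch
    from typing import List

    # Hedged sketch: assumed semantics of a grouped move. For each tensor pair,
    # rows selected by src_indices (along src_dim) are copied into the slots
    # selected by dst_indices (along dst_dim). The real grouped_move may differ.
    def grouped_move_sketch(dst_tensors: List[torch.Tensor], dst_dims: List[int],
                            dst_indices: torch.Tensor,
                            src_tensors: List[torch.Tensor], src_dims: List[int],
                            src_indices: torch.Tensor) -> List[torch.Tensor]:
        for dst, dst_dim, src, src_dim in zip(dst_tensors, dst_dims, src_tensors, src_dims):
            gathered = src.index_select(src_dim, src_indices)  # pull source rows
            dst.index_copy_(dst_dim, dst_indices, gathered)    # write into destination
        return dst_tensors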
@@ -783,7 +782,6 @@ class CausalLM(Model):
         generations: List[Generation] = []
         prev_batches = []
         requests_to_generate = []
-
         # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
         # Stage 1. Collect next token ids of any previously started generations
         for batch_id, batch in enumerate(batches):
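
The staging comment in this hunk is about overlapping CPU bookkeeping with device execution. A toy, self-contained illustration of that general pattern (not this repo's code, and not the same three stages as the comment; the thread pool merely stands in for asynchronously executing device work):

    import concurrent.futures

    # Toy illustration of pipelining: enqueue the slow "device" work first,
    # do CPU-side bookkeeping while it runs, and synchronize only at the end.
    def staged_step(device_jobs, cpu_jobs):
        with concurrent.futures.ThreadPoolExecutor() as pool:
            futures = [pool.submit(job) for job in device_jobs]  # launch async work
            cpu_results = [job() for job in cpu_jobs]            # overlap CPU work
            return [f.result() for f in futures], cpu_results    # collect last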
@@ -1051,23 +1049,28 @@ class CausalLM(Model):
 
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        self.shifting_warmup()
         if len(batches) < 2:
             return
 
         # prefill
-        _, prefill_batch = self.generate_token([batches[0]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
+        # shifts
+        self.shifting_warmup(decode_batch)
         # prefill
-        _, prefill_batch = self.generate_token([batches[1]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode
         _, decode_batch = self.generate_token([decode_batch, prefill_batch])
         # decodes
         while decode_batch is not None:
             _, decode_batch = self.generate_token([decode_batch])
 
-    def shifting_warmup(self) -> None:
-        # TODO: add warmup for all possible shift variants
-        pass
+    def shifting_warmup(self, batch: CausalLMBatch) -> None:
+        chunk_sizes = CHUNK_SIZES.copy()
+        chunk_sizes.extend([-chunk for chunk in chunk_sizes])
+        for chunk in chunk_sizes:
+            batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
+            batch.realign(batch.batch_size, chunk, 0)
+            batch.split_kv_cache_if_needed()
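
The new warmup path runs a prefill/decode pair, exercises every shift variant on the live decode batch, then warms the concatenate-and-decode path, presumably so graph compilation happens before serving rather than on the first real request. For the shift list itself, a small self-contained sketch of the construction used above (CHUNK_SIZES lives in the server module; the values below are assumed examples, not the repo's actual constants):

    # Assumed example values; the real CHUNK_SIZES comes from the server module.
    CHUNK_SIZES = [1, 2, 4, 8]

    # Mirror the diff's construction: every chunk size in both shift directions,
    # so both left and right shifts of the KV cache get warmed up.
    chunk_sizes = CHUNK_SIZES.copy()
    chunk_sizes.extend([-chunk for chunk in chunk_sizes])
    print(chunk_sizes)  # [1, 2, 4, 8, -1, -2, -4, -8]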