mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
parent 022ce1eaaf
commit 3831f1bed5
@@ -341,7 +341,6 @@ class CausalLMBatch(Batch):
         # move only input_ids, attention_mask and position_ids
         dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
 
-
     @classmethod
     def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
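
For readers skimming the hunk above: grouped_move appears to relocate a group of tensors between batches in one call. The sketch below is only a guess at its shape, written to make the call site readable; the name grouped_move_sketch, its body, and its shape assumptions are illustrative, not the repository's actual implementation.

import torch
from typing import List

def grouped_move_sketch(
    dst_tensors: List[torch.Tensor],
    dst_dims: List[int],
    dst_indices: torch.Tensor,
    src_tensors: List[torch.Tensor],
    src_dims: List[int],
    src_indices: torch.Tensor,
) -> List[torch.Tensor]:
    # Hypothetical stand-in for grouped_move: for each (dst, src) pair,
    # gather the requested slices from src and scatter them into dst along
    # the given dims. Assumes the non-indexed dims of each pair match.
    for dst, dim_d, src, dim_s in zip(dst_tensors, dst_dims, src_tensors, src_dims):
        dst.index_copy_(dim_d, dst_indices, src.index_select(dim_s, src_indices))
    return dst_tensors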
@@ -783,7 +782,6 @@ class CausalLM(Model):
         generations: List[Generation] = []
         prev_batches = []
         requests_to_generate = []
-
         # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
         # Stage 1. Collect next token ids of any previously started generations
         for batch_id, batch in enumerate(batches):
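
The staging comment in this hunk describes a software pipeline: rather than blocking on each step's results immediately, the loop first harvests the token ids of the step launched on the previous iteration, so CPU-side bookkeeping overlaps the device's work on the current step. A minimal, self-contained sketch of that pattern follows; every name in it is illustrative, and only Stage 1's description comes from the code above.

from typing import Callable, List, Optional

Future = Callable[[], int]  # a deferred result; calling it blocks until ready

def pipelined_decode(steps: int, launch_forward: Callable[[], Future]) -> List[int]:
    tokens: List[int] = []
    pending: Optional[Future] = None
    for _ in range(steps):
        # Stage 1: collect the next token id of the previously started step,
        # which has had a whole iteration to finish on the device.
        if pending is not None:
            tokens.append(pending())
        # Stage 2 (assumed): launch the next forward pass without waiting on it.
        pending = launch_forward()
        # Stage 3 (assumed): per-request CPU postprocessing would go here,
        # overlapping the forward pass launched above.
    if pending is not None:
        tokens.append(pending())  # drain the last in-flight step
    return tokens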
@@ -1051,23 +1049,28 @@ class CausalLM(Model):
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        self.shifting_warmup()
-
         if len(batches) < 2:
             return
 
         # prefill
-        _, prefill_batch = self.generate_token([batches[0]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
+        # shifts
+        self.shifting_warmup(decode_batch)
         # prefill
-        _, prefill_batch = self.generate_token([batches[1]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode
         _, decode_batch = self.generate_token([decode_batch, prefill_batch])
         # decodes
         while decode_batch is not None:
             _, decode_batch = self.generate_token([decode_batch])
 
-    def shifting_warmup(self) -> None:
-        # TODO: add warmup for all possible shift variants
-        pass
+    def shifting_warmup(self, batch: CausalLMBatch) -> None:
+        chunk_sizes = CHUNK_SIZES.copy()
+        chunk_sizes.extend([-chunk for chunk in chunk_sizes])
+
+        for chunk in chunk_sizes:
+            batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
+            batch.realign(batch.batch_size, chunk, 0)
+            batch.split_kv_cache_if_needed()