Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-28 05:22:07 +00:00
Parent: 022ce1eaaf
Commit: 3831f1bed5
@@ -341,7 +341,6 @@ class CausalLMBatch(Batch):
         # move only input_ids, attention_mask and position_ids
         dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
-
 
     @classmethod
     def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
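The context above comes from CausalLMBatch's data-movement path: only the first three tensor groups (input_ids, attention_mask, position_ids) are moved via grouped_move before the groups are written back with set_tensor_groups. As a rough illustration of the row-scatter idea behind such a move (not the repository's grouped_move, whose signature also takes per-group dims), here is a minimal sketch with hypothetical names:

# Illustrative sketch only: toy_grouped_move and its arguments are hypothetical
# and merely show the idea of copying selected per-request rows from source
# batch tensors into destination batch tensors, group by group.
from typing import List

import torch


def toy_grouped_move(
    dst_tensors: List[torch.Tensor],
    dst_indices: torch.Tensor,
    src_tensors: List[torch.Tensor],
    src_indices: torch.Tensor,
) -> List[torch.Tensor]:
    for dst, src in zip(dst_tensors, src_tensors):
        # Scatter the chosen source rows into the chosen destination rows.
        dst[dst_indices] = src[src_indices]
    return dst_tensors


# Example: move 2 requests (rows 0 and 1 of src) into rows 1 and 3 of dst
# for three tensor groups (think input_ids, attention_mask, position_ids).
dst = [torch.zeros(4, 8, dtype=torch.long) for _ in range(3)]
src = [torch.ones(2, 8, dtype=torch.long) for _ in range(3)]
dst = toy_grouped_move(dst, torch.tensor([1, 3]), src, torch.tensor([0, 1]))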
@@ -783,7 +782,6 @@ class CausalLM(Model):
         generations: List[Generation] = []
         prev_batches = []
         requests_to_generate = []
-
         # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
         # Stage 1. Collect next token ids of any previously started generations
         for batch_id, batch in enumerate(batches):
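The comments in this hunk describe how generate_token pipelines its CPU-side work: stage 1 gathers the next token ids of generations that were started on the previous step, so host bookkeeping can overlap with device work queued in the meantime. A minimal sketch of that first stage, assuming a hypothetical PrevBatch holder for device-side token ids rather than TGI's actual data structures:

# Hedged sketch of "Stage 1. Collect next token ids of any previously started
# generations". PrevBatch and its field are assumptions for illustration only.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class PrevBatch:
    # Token ids produced by the previous forward pass, possibly still on the device.
    next_token_ids: torch.Tensor


def collect_previous_tokens(prev_batches: List[PrevBatch]) -> List[List[int]]:
    # Copying to CPU only now lets the transfer overlap with any device work
    # queued in between, instead of stalling right after each forward pass.
    return [pb.next_token_ids.cpu().tolist() for pb in prev_batches]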
@@ -1051,23 +1049,28 @@ class CausalLM(Model):
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        self.shifting_warmup()
-
         if len(batches) < 2:
             return
 
         # prefill
-        _, prefill_batch = self.generate_token([batches[0]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
+        # shifts
+        self.shifting_warmup(decode_batch)
         # prefill
-        _, prefill_batch = self.generate_token([batches[1]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode
         _, decode_batch = self.generate_token([decode_batch, prefill_batch])
         # decodes
         while decode_batch is not None:
             _, decode_batch = self.generate_token([decode_batch])
 
-    def shifting_warmup(self) -> None:
-        # TODO: add warmup for all possible shift variants
-        pass
+    def shifting_warmup(self, batch: CausalLMBatch) -> None:
+        chunk_sizes = CHUNK_SIZES.copy()
+        chunk_sizes.extend([-chunk for chunk in chunk_sizes])
+
+        for chunk in chunk_sizes:
+            batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
+            batch.realign(batch.batch_size, chunk, 0)
+            batch.split_kv_cache_if_needed()
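The previously stubbed shifting_warmup now takes the live decode batch (it runs after the first prefill and decode, so the batch already has a populated KV cache to realign, and each warmup batch is released with batches.pop(0) once consumed). It exercises the KV-cache shift path for every configured bucket size in both directions: CHUNK_SIZES is mirrored with its negatives, and for each shift the cache is merged, realigned by that amount, and split again, presumably so both growing and shrinking shifts are compiled before real traffic arrives. A small illustration of how the shift list is built (the CHUNK_SIZES values below are placeholders; the real list is defined elsewhere in the server):

# Placeholder values for illustration only; the actual CHUNK_SIZES is defined
# in the text-generation-inference server code, not here.
CHUNK_SIZES = [1, 2, 4, 8, 16]

chunk_sizes = CHUNK_SIZES.copy()
chunk_sizes.extend([-chunk for chunk in chunk_sizes])
print(chunk_sizes)  # -> [1, 2, 4, 8, 16, -1, -2, -4, -8, -16]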