diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 41129da0..b19170cb 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -341,7 +341,6 @@ class CausalLMBatch(Batch):
         # move only input_ids, attention_mask and position_ids
         dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
-
 
     @classmethod
     def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
@@ -783,7 +782,6 @@ class CausalLM(Model):
         generations: List[Generation] = []
         prev_batches = []
         requests_to_generate = []
-
         # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
         # Stage 1. Collect next token ids of any previously started generations
         for batch_id, batch in enumerate(batches):
@@ -1051,23 +1049,28 @@ class CausalLM(Model):
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        self.shifting_warmup()
-
         if len(batches) < 2:
             return
 
         # prefill
-        _, prefill_batch = self.generate_token([batches[0]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
+        # shifts
+        self.shifting_warmup(decode_batch)
         # prefill
-        _, prefill_batch = self.generate_token([batches[1]])
+        _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode
         _, decode_batch = self.generate_token([decode_batch, prefill_batch])
         # decodes
         while decode_batch is not None:
             _, decode_batch = self.generate_token([decode_batch])
 
-    def shifting_warmup(self) -> None:
-        # TODO: add warmup for all possible shift variants
-        pass
+    def shifting_warmup(self, batch: CausalLMBatch) -> None:
+        chunk_sizes = CHUNK_SIZES.copy()
+        chunk_sizes.extend([-chunk for chunk in chunk_sizes])
+
+        for chunk in chunk_sizes:
+            batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
+            batch.realign(batch.batch_size, chunk, 0)
+            batch.split_kv_cache_if_needed()
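
For context on the new `shifting_warmup`: it derives the set of shifts to pre-warm from `CHUNK_SIZES`, mirroring each configured chunk with its negative so that both shift directions are exercised against the already warmed-up decode batch. A minimal standalone sketch of that list construction follows; the `CHUNK_SIZES` values here are made up for illustration, as the real constant is defined elsewhere in `causal_lm.py`.

```python
# Hypothetical example values; the real CHUNK_SIZES lives in the module config.
CHUNK_SIZES = [64, 256, 1024]

# Mirror each chunk with its negative, as the new shifting_warmup does.
chunk_sizes = CHUNK_SIZES.copy()
chunk_sizes.extend([-chunk for chunk in chunk_sizes])

# The warmup loop would then run one merge/realign/split pass per entry,
# presumably covering one shift direction per sign.
print(chunk_sizes)  # [64, 256, 1024, -64, -256, -1024]
```

Running the warmup on `decode_batch` between the first decode and the second prefill means every shift variant is exercised before the remaining warmup steps and before serving traffic.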