diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 55162186..68086b8c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -449,7 +449,7 @@ class FlashCausalLM(Model):
             pre_allocate_past_size=pre_allocate_past_size,
         )
 
-    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]):
+    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]):
         diff = max_input_length - max(batch.input_lengths)
         for i in range(len(batch)):
             batch.input_lengths[i] += diff
@@ -459,7 +459,19 @@ class FlashCausalLM(Model):
             # TODO: Bug!?!
             batch.stopping_criterias[i].current_tokens += diff
 
-        if use_cache:
+        if cache_dtype is None:
+            batch.max_seqlen = max(batch.input_lengths)
+            batch.all_input_ids_tensor=[]
+
+            batch.input_ids = torch.tensor(
+                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
+            )
+            batch.position_ids = torch.tensor(
+                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
+            )
+            batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
+            batch.past_key_values=None
+        else:
             assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
             batch.input_ids.fill_(self.tokenizer.pad_token_id)
             batch.position_ids += diff
@@ -475,20 +487,8 @@ class FlashCausalLM(Model):
                 batch.past_key_values.shape[0],
                 batch.past_key_values.shape[1] + len(batch.requests),
                 *batch.past_key_values.shape[2:],
-            ), device=batch.past_key_values.device, dtype= batch.past_key_values.dtype
+            ), device=batch.past_key_values.device, dtype= cache_dtype
             )
-        else:
-            batch.max_seqlen = max(batch.input_lengths)
-            batch.all_input_ids_tensor=[]
-
-            batch.input_ids = torch.tensor(
-                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
-            )
-            batch.position_ids = torch.tensor(
-                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
-            )
-            batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
-            batch.past_key_values=None
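
For context, a hypothetical call-site sketch of the renamed parameter (the `model` and `batch` objects, the length `512`, and the `torch.float16` choice below are illustrative assumptions, not code from this diff): `cache_dtype=None` takes the branch that rewinds the batch to a pre-prefill state, while a concrete dtype takes the branch that pads the existing inputs and re-allocates the `past_key_values` buffer in that dtype.

```python
import torch

# Sketch only: `model` is assumed to be a FlashCausalLM and `batch` a
# FlashCausalLMBatch; neither is constructed here.

# Mode 1, no cache (cache_dtype=None): input_ids, position_ids and
# cu_seqlens are rebuilt as if the longer sequences were about to be
# prefilled, and past_key_values is cleared.
model.fast_forward(batch, max_input_length=512, cache_dtype=None)

# Mode 2, with a cache: valid only after prefill has populated
# all_input_ids_tensor (otherwise the assert fires); input_ids are padded
# in place, position_ids shifted, and past_key_values is re-allocated
# with dtype=cache_dtype.
model.fast_forward(batch, max_input_length=512, cache_dtype=torch.float16)
```

Both calls mutate `batch` in place; per the assert in the diff, the dtype-passing mode is only valid once prefill has run.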