Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
fix: replace the mistyped use_cache flag with an explicit cache_dtype in FlashCausalLM.fast_forward

parent a515fbde4c
commit 72eefa3612
@@ -449,7 +449,7 @@ class FlashCausalLM(Model):
             pre_allocate_past_size=pre_allocate_past_size,
         )

-    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]):
+    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]):
         diff = max_input_length - max(batch.input_lengths)
         for i in range(len(batch)):
             batch.input_lengths[i] += diff
@@ -459,7 +459,19 @@ class FlashCausalLM(Model):
             # TODO: Bug!?!
             batch.stopping_criterias[i].current_tokens += diff

-        if use_cache:
+        if cache_dtype is None:
+            batch.max_seqlen = max(batch.input_lengths)
+            batch.all_input_ids_tensor = []
+
+            batch.input_ids = torch.tensor(
+                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
+            )
+            batch.position_ids = torch.tensor(
+                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
+            )
+            batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths), (1, 0)), device=batch.input_ids.device, dtype=torch.int32)
+            batch.past_key_values = None
+        else:
+            assert len(batch.all_input_ids_tensor) > 0, "Must run prefill first"
             batch.input_ids.fill_(self.tokenizer.pad_token_id)
             batch.position_ids += diff
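For readers skimming the hunk above: the cache_dtype is None branch rebuilds the flattened prefill tensors purely from batch.input_lengths. Below is a minimal standalone sketch of that bookkeeping, using NumPy and PyTorch only; the input_lengths value is made up for illustration and no FlashCausalLMBatch is involved.

import numpy as np
import torch

# Hypothetical per-request lengths after fast-forwarding.
input_lengths = [3, 5]

# One entry per token, restarting at 0 for each request; the diff builds
# position_ids (int32) and input_ids (int64) from this same concatenation.
position_ids = torch.tensor(
    np.concatenate([np.arange(0, n) for n in input_lengths]),
    dtype=torch.int32,
)
# -> tensor([0, 1, 2, 0, 1, 2, 3, 4], dtype=torch.int32)

# Cumulative sequence boundaries with a leading zero, the layout
# flash-attention varlen kernels expect: request i occupies the flat
# range [cu_seqlens[i], cu_seqlens[i + 1]).
cu_seqlens = torch.tensor(
    np.pad(np.cumsum(input_lengths), (1, 0)),
    dtype=torch.int32,
)
# -> tensor([0, 3, 8], dtype=torch.int32)

Note that input_ids is filled with the same 0..n ranges, so the token ids themselves are placeholders here; this branch only prepares shapes and offsets for a fresh prefill, and past_key_values is reset to None.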
@@ -475,20 +487,8 @@ class FlashCausalLM(Model):
                     batch.past_key_values.shape[0],
                     batch.past_key_values.shape[1] + len(batch.requests),
                     *batch.past_key_values.shape[2:],
-                ), device=batch.past_key_values.device, dtype=batch.past_key_values.dtype
+                ), device=batch.past_key_values.device, dtype=cache_dtype
             )
-        else:
-            batch.max_seqlen = max(batch.input_lengths)
-            batch.all_input_ids_tensor = []
-
-            batch.input_ids = torch.tensor(
-                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
-            )
-            batch.position_ids = torch.tensor(
-                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
-            )
-            batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths), (1, 0)), device=batch.input_ids.device, dtype=torch.int32)
-            batch.past_key_values = None
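Taken together, the rename makes the caller's intent explicit: passing no dtype means "reset for a fresh prefill", while passing a dtype means "keep fast-forwarding and grow the cache in that dtype". Here is a hedged sketch of the allocation pattern in the last hunk; the tensor constructor and the (layers, tokens, heads, head_dim)-style layout are assumptions, since the hunk only shows the shape arguments.

import torch

# Illustrative stand-in for batch.past_key_values; the real layout
# comes from FlashCausalLM and these dimensions are invented.
past_key_values = torch.zeros(4, 10, 8, 64, dtype=torch.float16)
num_requests = 2  # stand-in for len(batch.requests)
cache_dtype = torch.float16  # now supplied by the caller, per this commit

# The hunk grows the second dimension by one slot per request and takes
# its dtype from cache_dtype instead of the old tensor's dtype. The
# constructor itself (torch.zeros here) sits outside the hunk, so it is
# assumed; copying the old entries into the new tensor is likewise
# assumed to happen in the surrounding, unshown code.
grown = torch.zeros(
    (
        past_key_values.shape[0],
        past_key_values.shape[1] + num_requests,
        *past_key_values.shape[2:],
    ),
    device=past_key_values.device,
    dtype=cache_dtype,
)

Hypothetical call sites after the rename:

model.fast_forward(batch, max_input_length, cache_dtype=None)           # rebuild prefill tensors
model.fast_forward(batch, max_input_length, cache_dtype=torch.float16)  # extend an existing cache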