fix(server): fix flash causal

This commit is contained in:
OlivierDehaene 2023-04-21 19:48:41 +02:00
parent 86bca365df
commit c0df99e704

View File

@ -225,6 +225,7 @@ class FlashCausalLMBatch(Batch):
return FlashCausalLMBatch( return FlashCausalLMBatch(
batch_id=self.batch_id, batch_id=self.batch_id,
past_pad=self.past_pad,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
@ -311,6 +312,7 @@ class FlashCausalLMBatch(Batch):
return FlashCausalLMBatch( return FlashCausalLMBatch(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
past_pad=batches[0].past_pad,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,