diff --git a/server/deepsparse/deepsparse_causal_lm.py b/server/deepsparse/deepsparse_causal_lm.py
index 6a32a10d..2afeb028 100644
--- a/server/deepsparse/deepsparse_causal_lm.py
+++ b/server/deepsparse/deepsparse_causal_lm.py
@@ -90,7 +90,7 @@ class DeepSparseCausalLMBatch:
             old_idx = self.requests_idx_mapping[request_id]
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
-            past_key_values_list.append(self.past_key_values[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])
 
         # update batch state
         self.requests = requests
@@ -112,7 +112,7 @@ class DeepSparseCausalLMBatch:
 
         start_index = 0
         for i, batch in enumerate(batches):
-            assert batch.past_key_values_list is None, "only concatenate prefilled batches"
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
 
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
@@ -129,7 +129,7 @@
             start_index += len(batch)
 
         return cls(
-            batch_id= batches[0].id,
+            batch_id=batches[0].batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
@@ -210,8 +210,11 @@ class DeepSparseCausalLM:
 
         # check stopping criteria
         # simple for now --- should use StoppingCriteria
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
+
         stop = self.should_stop(
-            num_tokens_processed=len(input_ids) + 1,
+            num_tokens_processed=input_ids.shape[1] + 1,
             generated_token_id = generated_token_id
         )
 
diff --git a/server/deepsparse/deepsparse_requests.py b/server/deepsparse/deepsparse_requests.py
index 106b87b6..524793b0 100644
--- a/server/deepsparse/deepsparse_requests.py
+++ b/server/deepsparse/deepsparse_requests.py
@@ -31,4 +31,5 @@ class DecodeRequest:
 
 @dataclass
 class FilterBatchRequest:
-    batch_id: int
\ No newline at end of file
+    batch_id: int
+    request_ids: List[int]
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_service.py b/server/deepsparse/deepsparse_service.py
index f16c4d6c..9458a8a6 100644
--- a/server/deepsparse/deepsparse_service.py
+++ b/server/deepsparse/deepsparse_service.py
@@ -56,16 +56,17 @@ class DeepSparseService:
     def Prefill(
         self,
         request: PrefillRequest
-    ) -> [List[Generation], CachedBatch]:
+    ) -> [Generation, CachedBatch]:
         ds_batch = DeepSparseCausalLMBatch.from_batch(
             batch=request.batch,
             tokenizer=self.model.tokenizer
        )
 
         generations, next_ds_batch = self.model.generate_token(ds_batch)
+        assert len(generations) == 1
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
+        return generations[0], next_ds_batch.to_batch()
 
     def Decode(
         self,
@@ -75,16 +76,16 @@
 
         ds_batches = []
         for batch in request.batches:
-            ds_batch = self.cache.pop(batch.id)
+            ds_batch = self.cache.pop(batch.batch_id)
             assert batch is not None, "Batch ID {batch.id} not found in cache."
             ds_batches.append(ds_batch)
 
         if len(ds_batches) > 1:
             ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
         else:
-            batch = ds_batches[0]
+            ds_batch = ds_batches[0]
 
-        generations, next_ds_batch = self.model.generate_token(ds_batches)
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
\ No newline at end of file
+        return generations, next_ds_batch.to_batch() if next_ds_batch else None
\ No newline at end of file
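
A minimal sketch of the stopping-criteria fix in generate_token, assuming input_ids is a 2-D NumPy array of shape [1, sequence_length] (the invariant the new asserts encode; the token values below are made up for illustration). len() on a 2-D array returns its first dimension, so the old len(input_ids) + 1 always evaluated to 2 no matter how many tokens had been processed, while input_ids.shape[1] + 1 counts the sequence length:

import numpy as np

# Hypothetical prompt: a batch of one request with five tokens -> shape [1, 5].
input_ids = np.array([[101, 2023, 2003, 1037, 3231]])

# The shape invariants the patch asserts before computing the stop condition.
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1

print(len(input_ids))       # 1 -> the batch dimension, what the old code counted
print(input_ids.shape[1])   # 5 -> tokens actually processed, what the fix counts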