updated ipynb for simulating interaction with the server

rsnm2 2023-08-22 04:06:52 +00:00
parent a972a7870f
commit 010739bba1
3 changed files with 16 additions and 11 deletions

View File

@@ -90,7 +90,7 @@ class DeepSparseCausalLMBatch:
             old_idx = self.requests_idx_mapping[request_id]
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
-            past_key_values_list.append(self.past_key_values[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])
 
         # update batch state
         self.requests = requests
@@ -112,7 +112,7 @@ class DeepSparseCausalLMBatch:
         start_index = 0
         for i, batch in enumerate(batches):
-            assert batch.past_key_values_list is None, "only concatenate prefilled batches"
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
 
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
@@ -129,7 +129,7 @@ class DeepSparseCausalLMBatch:
             start_index += len(batch)
 
         return cls(
-            batch_id= batches[0].id,
+            batch_id=batches[0].batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
@@ -210,8 +210,11 @@ class DeepSparseCausalLM:
             # check stopping criteria
             # simple for now --- should use StoppingCriteria
+            assert len(input_ids.shape) == 2
+            assert input_ids.shape[0] == 1
             stop = self.should_stop(
-                num_tokens_processed=len(input_ids) + 1,
+                num_tokens_processed=input_ids.shape[1] + 1,
                 generated_token_id = generated_token_id
             )
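The new asserts in the last hunk make the assumption explicit: input_ids is a 2D array of shape [batch, sequence_length] with a batch size of 1, so len(input_ids) returns the batch dimension (always 1) while input_ids.shape[1] is the number of tokens processed. A minimal sketch of the difference, using numpy and illustrative values that are not from the repo:

    # Illustrative only: why the token count must come from shape[1], not len().
    import numpy as np

    input_ids = np.array([[101, 2023, 2003, 1037, 3231]])  # shape [1, 5]: batch of 1, 5 tokens

    assert len(input_ids.shape) == 2
    assert input_ids.shape[0] == 1

    print(len(input_ids))       # 1 -> batch dimension, not the number of tokens
    print(input_ids.shape[1])   # 5 -> tokens processed so far, as passed to should_stop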

View File

@@ -31,4 +31,5 @@ class DecodeRequest:
 @dataclass
 class FilterBatchRequest:
     batch_id: int
+    request_ids: List[int]
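With the new field, the request dataclass would look roughly as follows; only the two fields appear in the diff, and the surrounding imports are assumed:

    # Sketch of FilterBatchRequest after this hunk (imports assumed, not shown in the diff).
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class FilterBatchRequest:
        batch_id: int
        request_ids: List[int]  # new: ids of the requests to keep in the batch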

View File

@@ -56,16 +56,17 @@ class DeepSparseService:
     def Prefill(
         self,
         request: PrefillRequest
-    ) -> [List[Generation], CachedBatch]:
+    ) -> [Generation, CachedBatch]:
 
         ds_batch = DeepSparseCausalLMBatch.from_batch(
             batch=request.batch,
             tokenizer=self.model.tokenizer
         )
 
         generations, next_ds_batch = self.model.generate_token(ds_batch)
+        assert len(generations) == 1
         self.cache.set(next_ds_batch)
-        return generations, next_ds_batch.to_batch()
+        return generations[0], next_ds_batch.to_batch()
 
     def Decode(
         self,
@@ -75,16 +76,16 @@ class DeepSparseService:
         ds_batches = []
         for batch in request.batches:
-            ds_batch = self.cache.pop(batch.id)
+            ds_batch = self.cache.pop(batch.batch_id)
             assert batch is not None, "Batch ID {batch.id} not found in cache."
             ds_batches.append(ds_batch)
 
         if len(ds_batches) > 1:
             ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
         else:
-            batch = ds_batches[0]
+            ds_batch = ds_batches[0]
 
-        generations, next_ds_batch = self.model.generate_token(ds_batches)
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
+        return generations, next_ds_batch.to_batch() if next_ds_batch else None
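Per the commit message, the notebook drives this service directly to simulate interaction with the server: one Prefill call to process the prompt, then repeated Decode calls until generation stops. A rough sketch of such a driver, where the module paths, constructor arguments, and the make_batch helper are assumptions; only DeepSparseService.Prefill/Decode and the request dataclasses appear in this diff:

    # Hypothetical notebook cell; paths, constructors, and make_batch are assumed.
    from server.deepsparse.service import DeepSparseService               # assumed path
    from server.deepsparse.requests import PrefillRequest, DecodeRequest  # assumed path

    service = DeepSparseService(model)                  # assumed constructor
    batch = make_batch(["The capital of France is"])    # hypothetical helper building a Batch

    # Prefill runs the prompt, caches the resulting batch, and returns a single Generation.
    generation, cached_batch = service.Prefill(PrefillRequest(batch=batch))

    # Decode generates one token per call; the returned batch is None once
    # every request in it has finished.
    while cached_batch is not None:
        generations, cached_batch = service.Decode(DecodeRequest(batches=[cached_batch]))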