updated ipynb to simulate interacting with the server

rsnm2 2023-08-22 04:06:52 +00:00
parent a972a7870f
commit 010739bba1
3 changed files with 16 additions and 11 deletions

View File

@@ -90,7 +90,7 @@ class DeepSparseCausalLMBatch:
             old_idx = self.requests_idx_mapping[request_id]
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
-            past_key_values_list.append(self.past_key_values[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])

         # update batch state
         self.requests = requests
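The fix above matters because the dataclass field is `past_key_values_list`; `self.past_key_values` does not exist on this class, so the first filter call would raise AttributeError. A minimal runnable sketch of the pattern, with the surrounding class reduced to stand-ins for the fields visible in this diff:

```python
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class Batch:
    batch_id: int
    requests_idx_mapping: Dict[int, int]  # request_id -> position in the lists
    requests: List[Any]
    input_ids_list: List[Any]
    past_key_values_list: List[Any]

    def __len__(self) -> int:
        return len(self.requests)

    def filter(self, request_ids: List[int]) -> "Batch":
        requests, input_ids_list, past_key_values_list = [], [], []
        for request_id in request_ids:
            old_idx = self.requests_idx_mapping[request_id]
            requests.append(self.requests[old_idx])
            input_ids_list.append(self.input_ids_list[old_idx])
            # the fix: index the attribute that actually exists
            past_key_values_list.append(self.past_key_values_list[old_idx])

        # update batch state and rebuild the id -> index mapping
        self.requests_idx_mapping = {rid: i for i, rid in enumerate(request_ids)}
        self.requests = requests
        self.input_ids_list = input_ids_list
        self.past_key_values_list = past_key_values_list
        return self
```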
@@ -112,7 +112,7 @@ class DeepSparseCausalLMBatch:
         start_index = 0
         for i, batch in enumerate(batches):
-            assert batch.past_key_values_list is None, "only concatenate prefilled batches"
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"

             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
@@ -129,7 +129,7 @@ class DeepSparseCausalLMBatch:
             start_index += len(batch)

         return cls(
-            batch_id= batches[0].id,
+            batch_id=batches[0].batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
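The two hunks above correct `concatenate`: batches are only concatenated at decode time, after prefill has populated their KV caches, so the invariant is that `past_key_values_list` is not None; and the batch identifier field is `batch_id`, not `id`. Continuing the hypothetical `Batch` sketch above (this classmethod would live on it), where everything beyond the diff's own lines is an assumption:

```python
@classmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
    requests, requests_idx_mapping = [], {}
    input_ids_list, past_key_values_list = [], []

    start_index = 0
    for i, batch in enumerate(batches):
        # decode-time batches have been prefilled, so their KV cache exists
        assert batch.past_key_values_list is not None, "only concatenate prefilled batches"

        # concatenate request, input_ids, and past_key_values lists
        requests.extend(batch.requests)
        input_ids_list.extend(batch.input_ids_list)
        past_key_values_list.extend(batch.past_key_values_list)

        # shift each batch's id -> index mapping past the items merged so far
        for request_id, idx in batch.requests_idx_mapping.items():
            requests_idx_mapping[request_id] = idx + start_index
        start_index += len(batch)

    return cls(
        batch_id=batches[0].batch_id,  # the fix: the field is batch_id, not id
        requests=requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids_list=input_ids_list,
        past_key_values_list=past_key_values_list,
    )
```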
@@ -210,8 +210,11 @@ class DeepSparseCausalLM:
         # check stopping criteria
         # simple for now --- should use StoppingCriteria
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
         stop = self.should_stop(
-            num_tokens_processed=len(input_ids) + 1,
+            num_tokens_processed=input_ids.shape[1] + 1,
             generated_token_id = generated_token_id
         )
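`input_ids` here is a 2-D array of shape `[batch, seq_len]`, and the new asserts pin the batch dimension to 1. `len()` on a 2-D array returns that batch dimension, so the old code reported `1 + 1 = 2` tokens processed on every step and a max-token stop could effectively never fire; `shape[1]` counts the actual tokens. A quick numpy check of the distinction:

```python
import numpy as np

# input_ids holds one sequence per row: shape [batch, seq_len]
input_ids = np.zeros((1, 17), dtype=np.int64)

assert len(input_ids) == 1        # len() gives the batch dimension...
assert input_ids.shape[1] == 17   # ...shape[1] is the real token count

# old code: len(input_ids) + 1 == 2 on every step; the fix counts tokens
num_tokens_processed = input_ids.shape[1] + 1
```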

View File

@@ -31,4 +31,5 @@ class DecodeRequest:

 @dataclass
 class FilterBatchRequest:
-    batch_id: int
+    batch_id: int
+    request_ids: List[int]
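The filter endpoint's request now names both the batch to filter and the request ids to keep. A self-contained sketch with a made-up example value:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FilterBatchRequest:
    batch_id: int
    request_ids: List[int]

# e.g. after request 1 of batch 7 finishes, keep only requests 0 and 2
req = FilterBatchRequest(batch_id=7, request_ids=[0, 2])
assert req.request_ids == [0, 2]
```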

View File

@@ -56,16 +56,17 @@ class DeepSparseService:
     def Prefill(
         self,
         request: PrefillRequest
-    ) -> [List[Generation], CachedBatch]:
+    ) -> [Generation, CachedBatch]:
         ds_batch = DeepSparseCausalLMBatch.from_batch(
             batch=request.batch,
             tokenizer=self.model.tokenizer
         )

         generations, next_ds_batch = self.model.generate_token(ds_batch)
+        assert len(generations) == 1
         self.cache.set(next_ds_batch)

-        return generations, next_ds_batch.to_batch()
+        return generations[0], next_ds_batch.to_batch()

     def Decode(
         self,
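Since a prefill batch carries exactly one new request, `Prefill` now asserts that and hands back the single `Generation` rather than a one-element list, matching the updated return annotation. A toy, self-contained illustration of the contract change (all names below are stand-ins, not the repo's API):

```python
from typing import List, Tuple

# old contract: the caller receives a list and must index into it
def prefill_old(generations: List[str]) -> Tuple[List[str], str]:
    return generations, "cached-batch"

# new contract: exactly one request per prefill, so return the item itself
def prefill_new(generations: List[str]) -> Tuple[str, str]:
    assert len(generations) == 1
    return generations[0], "cached-batch"

gens = ["hello"]
gen_list, _ = prefill_old(gens)
gen, _ = prefill_new(gens)
assert gen_list[0] == gen == "hello"
```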
@@ -75,16 +76,16 @@ class DeepSparseService:
         ds_batches = []
         for batch in request.batches:
-            ds_batch = self.cache.pop(batch.id)
+            ds_batch = self.cache.pop(batch.batch_id)
             assert batch is not None, "Batch ID {batch.id} not found in cache."
             ds_batches.append(ds_batch)

         if len(ds_batches) > 1:
             ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
         else:
-            batch = ds_batches[0]
+            ds_batch = ds_batches[0]

-        generations, next_ds_batch = self.model.generate_token(ds_batches)
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)

-        return generations, next_ds_batch.to_batch()
+        return generations, next_ds_batch.to_batch() if next_ds_batch else None
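This hunk fixes four things in `Decode`: the cache is keyed by `batch_id` rather than `id`; the single-batch branch must bind `ds_batch` (the old `batch = ds_batches[0]` was dead and `generate_token` received the whole list); `generate_token` takes the one merged batch; and `next_ds_batch` is None once every request has finished, so `to_batch()` must be guarded. A hedged sketch of the resulting method (cache and model internals are assumptions); it also tightens the untouched assert, which checks the loop variable `batch` instead of the popped `ds_batch` and is missing its f-prefix:

```python
def Decode(self, request):
    ds_batches = []
    for batch in request.batches:
        # fix: the cached batch exposes batch_id, not id
        ds_batch = self.cache.pop(batch.batch_id)
        assert ds_batch is not None, f"Batch ID {batch.batch_id} not found in cache."
        ds_batches.append(ds_batch)

    if len(ds_batches) > 1:
        ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
    else:
        ds_batch = ds_batches[0]  # fix: bind ds_batch, the name used below

    # fix: pass the single merged batch, not the list of batches
    generations, next_ds_batch = self.model.generate_token(ds_batch)
    self.cache.set(next_ds_batch)

    # fix: when every request has finished, next_ds_batch is None
    return generations, next_ds_batch.to_batch() if next_ds_batch else None
```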