Fix for continuous batching (#1)

Karol Damaszke 2023-12-11 09:24:09 +01:00 committed by GitHub
parent e5f124b077
commit 6436ae86a1
2 changed files with 9 additions and 5 deletions


@@ -287,7 +287,7 @@ class CausalLMBatch(Batch):
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
         if kv_tuple:
-            self.past_key_values = [tuple(layer) for layer in self.past_key_values]
+            self.past_key_values = tuple([tuple(layer) for layer in self.past_key_values])
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
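
The change above wraps the rebuilt cache in `tuple(...)` so that `past_key_values` stays a tuple of per-layer tuples after filtering, the layout HF-style decoder models generally expect, instead of degrading to a plain list. A minimal sketch of the conversion, with made-up tensor shapes:

import torch

# Made-up cache: 2 layers, each a [key, value] pair of
# [batch, num_heads, seq_len, head_dim] tensors.
layers = [[torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16)] for _ in range(2)]

# List of lists -> tuple of tuples, mirroring the fixed line above.
past_key_values = tuple(tuple(layer) for layer in layers)

assert isinstance(past_key_values, tuple)
assert all(isinstance(layer, tuple) for layer in past_key_values)
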
@@ -374,7 +374,7 @@ class CausalLMBatch(Batch):
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+                input_ids = batch.input_ids.new_empty((total_batch_size, max_total_tokens))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids
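
This hunk pre-allocates the merged `input_ids` with `max_total_tokens` columns instead of a single column, so later decode steps can write tokens into fixed positions of one statically shaped buffer. A simplified sketch; the sizes and the sub-batch tensor are invented:

import torch

total_batch_size, max_total_tokens = 4, 128

# Invented sub-batch whose input_ids already use the padded layout.
sub_batch_input_ids = torch.zeros(2, max_total_tokens, dtype=torch.int64)

# Allocate the concatenated buffer once, then copy each sub-batch
# into its row range, as the concatenate code does per batch.
input_ids = sub_batch_input_ids.new_empty((total_batch_size, max_total_tokens))
start_index, end_index = 0, sub_batch_input_ids.shape[0]
input_ids[start_index:end_index] = sub_batch_input_ids
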
@@ -522,6 +522,9 @@ class CausalLMBatch(Batch):
             else:
                 past_key_values.append([padded_past_keys, padded_past_values])
+        if kv_tuple:
+            past_key_values = tuple(past_key_values)
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
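
The same conversion is applied in `concatenate`: the per-layer caches are accumulated in a Python list while being padded and copied, and the finished list is turned back into a tuple when `kv_tuple` is set. A rough sketch with invented shapes, reduced here to batch-dimension concatenation:

import torch

def make_layer(batch_size):
    # Invented [batch, num_heads, seq_len, head_dim] key/value pair.
    return torch.zeros(batch_size, 4, 8, 16), torch.zeros(batch_size, 4, 8, 16)

def merge_caches(cache_a, cache_b, kv_tuple=True):
    past_key_values = []
    for (key_a, value_a), (key_b, value_b) in zip(cache_a, cache_b):
        padded_past_keys = torch.cat([key_a, key_b], dim=0)
        padded_past_values = torch.cat([value_a, value_b], dim=0)
        past_key_values.append([padded_past_keys, padded_past_values])
    if kv_tuple:
        past_key_values = tuple(past_key_values)
    return past_key_values

merged = merge_caches([make_layer(2)], [make_layer(3)])
assert isinstance(merged, tuple) and merged[0][0].shape[0] == 5
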
@@ -893,11 +896,10 @@ class CausalLM(Model):
             batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
-        next_tokens = torch.tensor(next_token_ids, dtype=torch.int64).to(self.device)
         if token_idx is None:
-            batch.input_ids[:, 0] = next_tokens[:, 0]
+            batch.input_ids[:, 0] = next_token_ids[:, 0]
         else:
-            batch.input_ids[:, token_idx] = next_tokens
+            batch.input_ids[:, token_idx] = next_token_ids
         # We finished all generations in the batch; there is no next batch
         if stopped:
             if self.hb_profer_started == True:
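
The last hunk in this file drops the redundant `torch.tensor(next_token_ids, dtype=torch.int64).to(self.device)` round trip: `next_token_ids` is already a tensor, so the sampled tokens are written directly into the pre-allocated `input_ids` at the current `token_idx` column. A simplified sketch with invented sizes; the real shapes and broadcasting may differ:

import torch

batch_size, max_total_tokens = 4, 16
input_ids = torch.zeros(batch_size, max_total_tokens, dtype=torch.int64)

# Already a tensor in practice; no torch.tensor(...).to(device) needed.
next_token_ids = torch.randint(0, 32000, (batch_size, 1))
token_idx = 5  # column of the current decode step

if token_idx is None:
    input_ids[:, 0] = next_token_ids[:, 0]
else:
    input_ids[:, token_idx] = next_token_ids[:, 0]
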


@@ -153,6 +153,8 @@ def serve(
             data_type = torch.bfloat16
         else:
             data_type = torch.float
+        if revision == "None":
+            revision = None
         try:
             model = get_model(model_id, revision=revision, dtype=data_type)
         except Exception:
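
The `serve` fix normalizes a revision that arrives as the literal string "None" (as can happen when the value is forwarded through a CLI) back to Python `None` before `get_model` is called, so no git revision named "None" is ever requested. A minimal sketch; `normalize_revision` is a hypothetical helper, not part of the repository:

from typing import Optional

def normalize_revision(revision: Optional[str]) -> Optional[str]:
    # Hypothetical helper: treat the literal string "None" as "not set".
    return None if revision == "None" else revision

assert normalize_revision("None") is None
assert normalize_revision("main") == "main"
assert normalize_revision(None) is None
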