mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 10:52:07 +00:00

Fix for continuous batching (#1)

This commit is contained in:
parent e5f124b077
commit 6436ae86a1
@@ -287,7 +287,7 @@ class CausalLMBatch(Batch):
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

         if kv_tuple:
-            self.past_key_values = [tuple(layer) for layer in self.past_key_values]
+            self.past_key_values = tuple([tuple(layer) for layer in self.past_key_values])

         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
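For context, a minimal sketch of what the changed line does, using a made-up layer count and tensor shapes: both the per-layer key/value pairs and the outer container are turned into tuples, the nested-tuple layout that transformers models generally expect for past_key_values when kv_tuple is set.

    import torch

    # Hypothetical two-layer cache built as nested lists during filtering
    # (shapes are illustrative only: [batch, heads, seq, head_dim]).
    past_key_values = [
        [torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)],  # layer 0: keys, values
        [torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)],  # layer 1: keys, values
    ]

    # Same conversion as the changed line above: the outer list becomes a tuple too.
    past_key_values = tuple([tuple(layer) for layer in past_key_values])
    assert isinstance(past_key_values, tuple)
    assert all(isinstance(layer, tuple) for layer in past_key_values)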
@@ -374,7 +374,7 @@ class CausalLMBatch(Batch):
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+                input_ids = batch.input_ids.new_empty((total_batch_size, max_total_tokens))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids

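A minimal sketch of the effect of the changed allocation, with hypothetical sizes and assuming each incoming batch.input_ids is already padded to max_total_tokens: the concatenated buffer now reserves room for the full generation length rather than a single column, so later decode steps can write next tokens into it in place.

    import torch

    # Hypothetical sizes (illustration only).
    total_batch_size, max_total_tokens = 4, 16
    batch_input_ids = torch.ones(2, max_total_tokens, dtype=torch.int64)

    # The merged buffer reserves max_total_tokens columns instead of one,
    # matching the changed allocation above.
    input_ids = batch_input_ids.new_empty((total_batch_size, max_total_tokens))
    start_index, end_index = 0, batch_input_ids.shape[0]
    input_ids[start_index:end_index] = batch_input_ids  # copy to correct indices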
@@ -522,6 +522,9 @@ class CausalLMBatch(Batch):
             else:
                 past_key_values.append([padded_past_keys, padded_past_values])

+        if kv_tuple:
+            past_key_values = tuple(past_key_values)
+
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -893,11 +896,10 @@ class CausalLM(Model):
             batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)

-        next_tokens = torch.tensor(next_token_ids, dtype=torch.int64).to(self.device)
         if token_idx is None:
-            batch.input_ids[:, 0] = next_tokens[:, 0]
+            batch.input_ids[:, 0] = next_token_ids[:, 0]
         else:
-            batch.input_ids[:, token_idx] = next_tokens
+            batch.input_ids[:, token_idx] = next_token_ids
         # We finished all generations in the batch; there is no next batch
         if stopped:
             if self.hb_profer_started == True:
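A small, self-contained illustration of the in-place update this hunk switches to (the buffer size, token values, and token_idx shape are assumptions, not taken from the diff): the intermediate next_tokens tensor is dropped and next_token_ids is written directly into the preallocated input_ids buffer, either at column 0 or at the current token_idx position.

    import torch

    # Hypothetical buffer and shapes (illustration only).
    input_ids = torch.zeros(4, 16, dtype=torch.int64)        # [batch_size, max_total_tokens]
    next_token_ids = torch.tensor([[11], [22], [33], [44]])  # one next token per sequence
    token_idx = torch.tensor([5])                            # assumed current write position

    if token_idx is None:
        input_ids[:, 0] = next_token_ids[:, 0]
    else:
        input_ids[:, token_idx] = next_token_ids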
@@ -153,6 +153,8 @@ def serve(
        data_type = torch.bfloat16
    else:
        data_type = torch.float
+    if revision == "None":
+        revision = None
    try:
        model = get_model(model_id, revision=revision, dtype=data_type)
    except Exception:
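A hedged sketch of the intent behind the added check, wrapped in a hypothetical helper: a revision argument that arrives as the literal string "None" (as a caller may forward it) is mapped back to a real None before get_model is called.

    def normalize_revision(revision):
        # Hypothetical helper mirroring the added lines: treat the literal
        # string "None" as "no revision requested".
        if revision == "None":
            revision = None
        return revision

    assert normalize_revision("None") is None
    assert normalize_revision("main") == "main"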