Fix the generate_stream crash in concurrent query (#105)

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu2017 authored on 2024-03-15 17:54:56 +08:00; committed by GitHub
parent 3d81a80577
commit a4d5c3f40f
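
Note: prefill batches are padded up to a bucketed size (round_up(..., PREFILL_BATCH_BUCKET_SIZE)) by appending dummy requests, but the parameters list handed to HeterogeneousNextTokenChooser.from_pb was still built only from the real requests. This change builds the parameters list explicitly and pads it with copies of the first request's parameters so its length always matches the bucketed batch size, in both the batch-recombination path and from_pb. A minimal, self-contained sketch of the padding pattern follows; pad_parameters is a hypothetical helper written for illustration only, the actual commit pads the list inline as shown in the hunks below.

    from typing import List, TypeVar

    T = TypeVar("T")

    def round_up(value: int, multiple: int) -> int:
        # Round up to the next multiple (the bucketed batch size).
        return ((value + multiple - 1) // multiple) * multiple

    def pad_parameters(parameters: List[T], bucket_size: int) -> List[T]:
        # Reuse the first request's parameters for every dummy slot so the
        # list length always equals the padded (bucketed) batch size.
        new_bs = round_up(len(parameters), bucket_size)
        return parameters + [parameters[0]] * (new_bs - len(parameters))

    # Three real requests with a bucket size of 4 -> one dummy slot.
    assert pad_parameters(["p0", "p1", "p2"], 4) == ["p0", "p1", "p2", "p0"]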


@@ -380,8 +380,15 @@ class CausalLMBatch(Batch):
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        parameters = [r.data.parameters for r in flat_requests]
+        if len(flat_requests) < new_bs:
+            for i in range(new_bs - len(flat_requests)):
+                # append dummy parameters for the dummy (padding) requests
+                parameters.append(parameters[0])
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            [r.data.parameters for r in flat_requests],
+            parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device
         )
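
The same padding is repeated in from_pb in the next hunk. As a design note, the explicit loop could equivalently be written as a single list extension, for example:

    parameters.extend([parameters[0]] * (new_bs - len(parameters)))

which avoids the unused loop counter; the commit keeps the loop form in both places.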
@@ -426,13 +433,19 @@ class CausalLMBatch(Batch):
         # TODO: Add support for sparse batches
         top_n_tokens = [r.top_n_tokens for r in pb.requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
         # TODO: by tokenizing all inputs at once we lose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
         new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
+        parameters = [r.parameters for r in pb.requests]
+        if len(pb.requests) < new_bs:
+            for i in range(new_bs - len(pb.requests)):
+                # append dummy parameters for the dummy (padding) requests
+                parameters.append(parameters[0])
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",