Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
fix naming

parent f08a1a50b7
commit ab4037c640
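This commit renames the per-request `offsets` / `token_offsets` fields to `prefix_offsets` / `read_offsets` across the CausalLM, FlashCausalLM, Galactica and Seq2Seq batch classes, so the names match the two values that `self.decode_token` threads through incremental detokenization. The sketch below shows the kind of helper these offsets drive; it is an illustrative standalone version (the repository's helper is a model method and its exact body is not part of this diff): `prefix_offset..read_offset` is the window of ids whose text has already been emitted, and everything from `prefix_offset` onward is re-decoded each step to find the newly finished text.

from typing import List, Tuple

from transformers import PreTrainedTokenizerBase


def decode_token(
    tokenizer: PreTrainedTokenizerBase,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode the already-emitted window, then the window plus the new ids.
    prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_input_ids[prefix_offset:])
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # The trailing ids decode to finished characters: emit the suffix and
        # slide both offsets forward.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise the tail is still an unfinished byte sequence: emit nothing and
    # keep the window where it is.
    return "", prefix_offset, read_offset

Decoding only a short prefix window instead of the full sequence keeps per-step cost roughly constant while still giving the tokenizer enough neighbouring ids to place spaces and multi-byte characters correctly.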
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[int]
-    token_offsets: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -102,8 +102,8 @@ class CausalLMBatch(Batch):
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
-            offsets.append(0)
-            token_offsets.append(input_len)
+            prefix_offsets.append(0)
+            read_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -132,8 +132,8 @@ class CausalLMBatch(Batch):
             past_key_values=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -153,8 +153,8 @@ class CausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         max_input_length = 0
 
@@ -169,8 +169,8 @@ class CausalLMBatch(Batch):
             requests_idx_mapping[r.id] = i
             keep_indices.append(idx)
 
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])
 
             request_input_length = self.input_lengths[idx]
@@ -227,8 +227,8 @@ class CausalLMBatch(Batch):
         self.position_ids = position_ids
         self.all_input_ids = all_input_ids
         self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
@@ -253,8 +253,8 @@ class CausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -272,8 +272,8 @@ class CausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -430,8 +430,8 @@ class CausalLMBatch(Batch):
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
-           offsets=offsets,
-           token_offsets=token_offsets,
+           prefix_offsets=prefix_offsets,
+           read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
@@ -529,8 +529,8 @@ class CausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -541,8 +541,8 @@ class CausalLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -560,8 +560,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_input_ids[:, 0], offset, token_offset
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_input_ids[:, 0], prefix_offset, read_offset
             )
 
             # Evaluate stopping criteria
@@ -629,8 +629,8 @@ class CausalLM(Model):
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
         # We finished all generations in the batch; there is no next batch
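The CausalLM hunks above show the call sites: generate_token passes each request's current `prefix_offset` / `read_offset` into `self.decode_token` and writes the returned pair back into `batch.prefix_offsets[i]` / `batch.read_offsets[i]`. Below is a self-contained sketch of the same sliding-window loop with a Hugging Face tokenizer; it is hypothetical driver code, not taken from the repository.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer will do for the sketch

all_input_ids = tokenizer("Incremental decoding keeps").input_ids
prefix_offset, read_offset = 0, len(all_input_ids)  # same initialization as from_pb above

# Pretend these ids arrive one at a time from the model's sampler.
for next_id in tokenizer(" the stream readable").input_ids:
    all_input_ids.append(next_id)
    prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_input_ids[prefix_offset:])
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Emit only the newly finished characters, then slide the window.
        print(new_text[len(prefix_text):], end="", flush=True)
        prefix_offset, read_offset = read_offset, len(all_input_ids)
print()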
@@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    prefix_offsets: List[Optional[int]]
+    read_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
 
@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
 
-            offsets.append(0)
-            token_offsets.append(input_length)
+            prefix_offsets.append(0)
+            read_offsets.append(input_length)
 
             all_input_ids.append(tokenized_input)
 
@@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=[],
             next_token_choosers=next_token_choosers,
@@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = []
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
 
         next_token_choosers = []
         stopping_criterias = []
@@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
 
             input_lengths.append(request_input_length)
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
 
             next_token_choosers.append(self.next_token_choosers[idx])
 
@@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = []
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
 
         next_token_choosers = []
         stopping_criterias = []
@@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
 
             input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
 
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -640,8 +640,8 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -654,8 +654,8 @@ class FlashCausalLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -670,10 +670,10 @@ class FlashCausalLM(Model):
             all_input_ids.append(next_token_id)
 
             # Generated token
-            next_token_text, offset, token_offset = self.decode_token(
+            next_token_text, prefix_offset, read_offset = self.decode_token(
                 all_input_ids,
-                offset,
-                token_offset,
+                prefix_offset,
+                read_offset,
             )
 
             # Evaluate stopping criteria
@@ -739,8 +739,8 @@ class FlashCausalLM(Model):
 
             # Update values
             batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
             batch.max_seqlen = batch.max_seqlen + 1
             cumulative_length += input_length
@@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            offsets.append(None)
-            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for _ in pb.requests:
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            prefix_offsets.append(0)
+            read_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             past_key_values=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    offsets: List[int]
-    token_offsets: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch):
         stopping_criterias = []
 
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -122,8 +122,8 @@ class Seq2SeqLMBatch(Batch):
             .view(-1, 1)
         )
         for _ in pb.requests:
-            offsets.append(0)
-            token_offsets.append(1)
+            prefix_offsets.append(0)
+            read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
 
         max_tokens = len(inputs) * max_input_length + max_decode_tokens
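One reading of the initialization just above (inferred, not stated in the diff): unlike the decoder-only batches, Seq2SeqLMBatch tracks the offsets over the decoder sequence, which holds a single start token before generation begins, so every request starts with a one-token read window.

# Hedged illustration of the two initializations in this commit (not repository code).
prompt_ids = [101, 2023, 2003]  # stand-in for a tokenized prompt

causal_prefix_offset, causal_read_offset = 0, len(prompt_ids)  # CausalLM / FlashCausalLM: whole prompt already read
seq2seq_prefix_offset, seq2seq_read_offset = 0, 1              # Seq2SeqLMBatch: only the decoder start token so far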
@@ -141,8 +141,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -166,8 +166,8 @@ class Seq2SeqLMBatch(Batch):
         requests_idx_mapping = {}
         input_lengths = []
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
 
         all_decoder_input_ids = []
 
@@ -185,8 +185,8 @@ class Seq2SeqLMBatch(Batch):
             requests_idx_mapping[r.id] = i
             keep_indices.append(idx)
 
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
 
             all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
 
@@ -249,8 +249,8 @@ class Seq2SeqLMBatch(Batch):
         self.all_decoder_input_ids = all_decoder_input_ids
         self.input_lengths = input_lengths
         self.decoder_input_lengths = decoder_input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
@@ -284,8 +284,8 @@ class Seq2SeqLMBatch(Batch):
         all_decoder_input_ids = []
         input_lengths = []
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
         max_tokens = 0
@@ -307,8 +307,8 @@ class Seq2SeqLMBatch(Batch):
             all_decoder_input_ids.extend(batch.all_decoder_input_ids)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -483,8 +483,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
@@ -608,8 +608,8 @@ class Seq2SeqLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -621,8 +621,8 @@ class Seq2SeqLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -643,8 +643,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_decoder_input_ids, offset, token_offset
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_decoder_input_ids, prefix_offset, read_offset
             )
 
             # Evaluate stopping criteria
@@ -702,8 +702,8 @@ class Seq2SeqLM(Model):
             batch.all_decoder_input_ids[i] = all_decoder_input_ids
             batch.input_lengths[i] = input_length
             batch.decoder_input_lengths[i] = new_decoder_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, input_length)
             batch.max_decoder_input_length = max(
                 batch.max_decoder_input_length, new_decoder_input_length