better decode

commit b5233f9c3c (parent 783bc64f47)
@@ -75,7 +75,7 @@ async fn generate_runs(
         // Warmups on batch size
         for _ in 0..warmups {
             let (_, decode_batch) =
-                prefill(sequence.clone(), b, decode_length, &mut client).await?;
+                prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
             let _ = decode(decode_batch, &mut client).await?;
             // Send warmup message
             run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
@@ -83,7 +83,7 @@ async fn generate_runs(
 
         for _ in 0..n_runs {
             let (prefill, decode_batch) =
-                prefill(sequence.clone(), b, decode_length, &mut client).await?;
+                prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
             // Send prefill message
             run_sender
                 .send(Ok(Message::Prefill(prefill)))
@@ -110,6 +110,7 @@ async fn generate_runs(
 // Run a prefill step
 async fn prefill(
     sequence: String,
+    sequence_length: u32,
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
@@ -119,6 +120,7 @@ async fn prefill(
         .map(|id| Request {
             id: id.into(),
             inputs: sequence.clone(),
+            truncate: sequence_length,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,
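The benchmark's prefill step now forwards sequence_length as the request's truncate field, so the server clips each prompt to the requested input length before prefill. A minimal Python sketch of the effect (illustrative only — the benchmark itself is Rust, and which end of the prompt is kept depends on the tokenizer's truncation side):

# Illustrative sketch: with truncate set, at most `truncate` prompt tokens reach prefill,
# so the benchmarked prefill length matches the requested sequence_length.
def effective_prefill_length(prompt_token_count: int, truncate: int) -> int:
    return min(prompt_token_count, truncate)

assert effective_prefill_length(2048, 512) == 512
assert effective_prefill_length(128, 512) == 128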
@@ -7,9 +7,6 @@ export const options = {
         {duration: '2m', target: 100},
         {duration: '1m', target: 0},
     ],
-    hosts: {
-        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
-    },
 };
 const SLEEP_DURATION = 1;
 
@@ -29,7 +26,7 @@ function greedy_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }
 
 function sample_example(inputs, max_new_tokens, name) {
@@ -50,7 +47,7 @@ function sample_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }
 
 export default function () {
@@ -35,6 +35,7 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -66,6 +67,7 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -73,6 +75,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -117,6 +120,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -140,6 +144,7 @@ class CausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -157,6 +162,7 @@ class CausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -271,6 +277,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -358,6 +365,7 @@ class CausalLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
 
@@ -373,6 +381,7 @@ class CausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -384,6 +393,7 @@ class CausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -401,7 +411,9 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[:, 0], offset, token_offset
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -432,6 +444,7 @@ class CausalLM(Model):
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, new_input_length
                 )
@@ -516,6 +529,7 @@ class CausalLM(Model):
             all_input_ids=next_batch_all_input_ids,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
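Across CausalLMBatch, the new token_offsets list is carried exactly like offsets: one entry per request, initialised to None, concatenated when batches are merged, and updated from decode_token on every generation step. A small hypothetical sketch of that bookkeeping pattern (names are illustrative, not the actual batch class):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class DecodeState:
    # One entry per request; None means "no decoding history yet".
    offsets: List[Optional[int]] = field(default_factory=list)
    token_offsets: List[Optional[int]] = field(default_factory=list)

    def add_request(self) -> None:
        self.offsets.append(None)
        self.token_offsets.append(None)

    def concatenate(self, other: "DecodeState") -> None:
        # Merging batches just concatenates the parallel lists, preserving order.
        self.offsets.extend(other.offsets)
        self.token_offsets.extend(other.token_offsets)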
@@ -45,6 +45,7 @@ class FlashCausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -69,6 +70,7 @@ class FlashCausalLMBatch(Batch):
 
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
 
@@ -87,6 +89,7 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
             offsets.append(None)
+            token_offsets.append(None)
             all_input_ids.append(tokenized_input)
 
             tokenized_input = torch.tensor(tokenized_input, device=device)
@@ -124,6 +127,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -137,6 +141,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
         next_token_choosers = []
@@ -156,6 +161,7 @@ class FlashCausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)
@@ -189,6 +195,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -287,6 +294,7 @@ class FlashCausalLM(Model):
         next_batch_past_key_values = []
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_all_input_ids = []
         next_batch_all_input_ids_tensor = []
 
@@ -301,6 +309,7 @@ class FlashCausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -312,6 +321,7 @@ class FlashCausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -344,8 +354,10 @@ class FlashCausalLM(Model):
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text, offset = self.decode_token(
-                all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[-(stopping_criteria.current_tokens + 1) :],
+                offset,
+                token_offset,
             )
 
             # Evaluate stopping criteria
@@ -387,6 +399,7 @@ class FlashCausalLM(Model):
             )
             next_batch_input_lengths.append(new_input_length)
             next_batch_offsets.append(offset)
+            next_batch_token_offsets.append(token_offset)
             next_batch_all_input_ids.append(all_input_ids)
             next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
             next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
@@ -464,6 +477,7 @@ class FlashCausalLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             all_input_ids=next_batch_all_input_ids,
             all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,
@@ -94,6 +94,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -102,6 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -147,6 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -25,13 +25,18 @@ class Model(ABC):
         raise NotImplementedError
 
     def decode_token(
-        self, all_input_ids: List[int], offset: Optional[int] = None
-    ) -> Tuple[str, Optional[int]]:
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
+    ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        if token_offset is None:
+            token_offset = len(all_input_ids) - 5
 
-        # Decode all token minus last one and all tokens
+        # Decode token_offset token minus last one and token_offset tokens
         results = self.tokenizer.batch_decode(
-            [all_input_ids[:-1], all_input_ids],
+            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
             skip_special_tokens=False,
         )
 
@@ -44,6 +49,6 @@ class Model(ABC):
 
         # if text is utf-8
         if text and text[-1] != "�":
-            return text, None
+            return text, None, None
         else:
-            return "", offset
+            return "", offset, token_offset
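This is the heart of the commit: decode_token no longer re-decodes the entire sequence on every step, only a short window starting at token_offset (the last 5 tokens by default), which keeps per-step decoding cost roughly constant as generations grow. A standalone sketch of the idea follows; the body between the two hunks is paraphrased from context rather than copied, and `tokenizer` stands for any object with a Hugging Face-style batch_decode:

from typing import List, Optional, Tuple

def decode_token(
    tokenizer,
    all_input_ids: List[int],
    offset: Optional[int] = None,
    token_offset: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
    # Only decode a short window instead of the full sequence.
    if token_offset is None:
        token_offset = len(all_input_ids) - 5

    # Decode the window without its last token, and the full window.
    prefix_text, full_text = tokenizer.batch_decode(
        [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
        skip_special_tokens=False,
    )

    # The new text is whatever the last token added on top of the prefix.
    if offset is None:
        offset = len(prefix_text)
    text = full_text[offset:]

    # A trailing "�" means the last token ends mid multi-byte character:
    # emit nothing and carry both offsets over to the next step.
    if text and text[-1] != "�":
        return text, None, None
    return "", offset, token_offset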
@@ -39,6 +39,7 @@ class Seq2SeqLMBatch(Batch):
     input_lengths: List[int]
     decoder_input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -73,6 +74,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -83,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -121,6 +124,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -152,6 +156,7 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -172,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -310,6 +316,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -430,6 +437,7 @@ class Seq2SeqLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
 
@@ -446,6 +454,7 @@ class Seq2SeqLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -458,6 +467,7 @@ class Seq2SeqLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -476,7 +486,9 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(decoder_input_ids, offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                decoder_input_ids, offset, token_offset
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)
@@ -504,6 +516,7 @@ class Seq2SeqLM(Model):
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -590,6 +603,7 @@ class Seq2SeqLM(Model):
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
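All three model classes now drive the same contract: feed offset and token_offset back in on the next step, and emit text only once it ends in complete characters. A hypothetical end-to-end check of that loop, reusing the decode_token sketch above with a toy tokenizer whose ids map one-to-one to characters:

class ToyTokenizer:
    def batch_decode(self, batches, skip_special_tokens=False):
        # Treat each id as a Unicode code point and join into strings.
        return ["".join(chr(i) for i in ids) for ids in batches]

tokenizer = ToyTokenizer()
ids = [ord(c) for c in "Hello"]
offset = token_offset = None
for new_id in [ord(" "), ord("w"), ord("o")]:
    ids.append(new_id)
    text, offset, token_offset = decode_token(tokenizer, ids, offset, token_offset)
    print(repr(text))  # ' ', 'w', 'o' -- one complete character per step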