better decode

OlivierDehaene 2023-04-05 13:47:25 +02:00
parent 783bc64f47
commit b5233f9c3c
7 changed files with 67 additions and 18 deletions

View File

@@ -75,7 +75,7 @@ async fn generate_runs(
     // Warmups on batch size
     for _ in 0..warmups {
         let (_, decode_batch) =
-            prefill(sequence.clone(), b, decode_length, &mut client).await?;
+            prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
         let _ = decode(decode_batch, &mut client).await?;
         // Send warmup message
         run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
@@ -83,7 +83,7 @@ async fn generate_runs(
     for _ in 0..n_runs {
         let (prefill, decode_batch) =
-            prefill(sequence.clone(), b, decode_length, &mut client).await?;
+            prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
         // Send prefill message
         run_sender
             .send(Ok(Message::Prefill(prefill)))
@@ -110,6 +110,7 @@ async fn generate_runs(
 // Run a prefill step
 async fn prefill(
     sequence: String,
+    sequence_length: u32,
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
@@ -119,6 +120,7 @@ async fn prefill(
         .map(|id| Request {
             id: id.into(),
             inputs: sequence.clone(),
+            truncate: sequence_length,
            parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,
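
The benchmark now forwards the requested sequence_length as the `truncate` field on each Request. On the Python side the batches track a `max_truncation` value (visible as context in the model hunks later in this diff), which suggests the inputs are clamped at tokenization time. Below is a minimal sketch of that idea assuming a Hugging Face tokenizer; the helper name and exact call are illustrative, not the server's actual code.

# Sketch only: clamp a batch of inputs to their requested `truncate`
# lengths when tokenizing (the helper name is hypothetical).
from typing import List
from transformers import AutoTokenizer


def tokenize_truncated(inputs: List[str], truncate_lengths: List[int], tokenizer):
    # Truncate the whole batch to the largest requested length; padding
    # then equalizes the shorter sequences.
    max_truncation = max(truncate_lengths)
    return tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_truncation,
    )


tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
batch = tokenize_truncated(["Hello world, this is a test"] * 2, [10, 10], tokenizer)
print(batch["input_ids"].shape)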

View File

@@ -7,9 +7,6 @@ export const options = {
         {duration: '2m', target: 100},
         {duration: '1m', target: 0},
     ],
-    hosts: {
-        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
-    },
 };

 const SLEEP_DURATION = 1;
@@ -29,7 +26,7 @@ function greedy_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }

 function sample_example(inputs, max_new_tokens, name) {
@@ -50,7 +47,7 @@ function sample_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }

 export default function () {
@@ -95,4 +92,4 @@ export default function () {
         'is status 200': (r) => r.status === 200,
     });
     sleep(SLEEP_DURATION);
 }

View File

@@ -35,6 +35,7 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -66,6 +67,7 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []

         # Parse batch
         max_truncation = 0
@@ -73,6 +75,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -117,6 +120,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -140,6 +144,7 @@ class CausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -157,6 +162,7 @@ class CausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -271,6 +277,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -358,6 +365,7 @@ class CausalLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
@@ -373,6 +381,7 @@ class CausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -384,6 +393,7 @@ class CausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -401,7 +411,9 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[:, 0], offset, token_offset
+            )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -432,6 +444,7 @@ class CausalLM(Model):
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, new_input_length
                 )
@@ -516,6 +529,7 @@ class CausalLM(Model):
             all_input_ids=next_batch_all_input_ids,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,

View File

@@ -45,6 +45,7 @@ class FlashCausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -69,6 +70,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
@@ -87,6 +89,7 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
             offsets.append(None)
+            token_offsets.append(None)
             all_input_ids.append(tokenized_input)

             tokenized_input = torch.tensor(tokenized_input, device=device)
@@ -124,6 +127,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -137,6 +141,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
         next_token_choosers = []
@@ -156,6 +161,7 @@ class FlashCausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)
@@ -189,6 +195,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -287,6 +294,7 @@ class FlashCausalLM(Model):
         next_batch_past_key_values = []
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_all_input_ids = []
         next_batch_all_input_ids_tensor = []
@@ -301,6 +309,7 @@ class FlashCausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -312,6 +321,7 @@ class FlashCausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -344,8 +354,10 @@ class FlashCausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text, offset = self.decode_token(
-                all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[-(stopping_criteria.current_tokens + 1) :],
+                offset,
+                token_offset,
             )

             # Evaluate stopping criteria
@@ -387,6 +399,7 @@ class FlashCausalLM(Model):
                 )
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
@@ -464,6 +477,7 @@ class FlashCausalLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             all_input_ids=next_batch_all_input_ids,
             all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,

View File

@@ -94,6 +94,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []

         # Parse batch
         max_truncation = 0
@@ -102,6 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -147,6 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,

View File

@@ -25,13 +25,18 @@ class Model(ABC):
         raise NotImplementedError

     def decode_token(
-        self, all_input_ids: List[int], offset: Optional[int] = None
-    ) -> Tuple[str, Optional[int]]:
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
+    ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        if token_offset is None:
+            token_offset = len(all_input_ids) - 5

-        # Decode all token minus last one and all tokens
+        # Decode token_offset token minus last one and token_offset tokens
         results = self.tokenizer.batch_decode(
-            [all_input_ids[:-1], all_input_ids],
+            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
             skip_special_tokens=False,
         )
@@ -44,6 +49,6 @@ class Model(ABC):
         # if text is utf-8
         if text and text[-1] != "�":
-            return text, None
+            return text, None, None
         else:
-            return "", offset
+            return "", offset, token_offset

View File

@@ -39,6 +39,7 @@ class Seq2SeqLMBatch(Batch):
     input_lengths: List[int]
     decoder_input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -73,6 +74,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []

         # Parse batch
         max_truncation = 0
@@ -83,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -121,6 +124,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -152,6 +156,7 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
@@ -172,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -310,6 +316,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -430,6 +437,7 @@ class Seq2SeqLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
@@ -446,6 +454,7 @@ class Seq2SeqLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -458,6 +467,7 @@ class Seq2SeqLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -476,7 +486,9 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(decoder_input_ids, offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                decoder_input_ids, offset, token_offset
+            )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)
@@ -504,6 +516,7 @@ class Seq2SeqLM(Model):
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -590,6 +603,7 @@ class Seq2SeqLM(Model):
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,