Mirror of https://github.com/huggingface/text-generation-inference.git

Commit: b5233f9c3c ("better decode")
Parent: 783bc64f47
@@ -75,7 +75,7 @@ async fn generate_runs(
         // Warmups on batch size
         for _ in 0..warmups {
             let (_, decode_batch) =
-                prefill(sequence.clone(), b, decode_length, &mut client).await?;
+                prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
             let _ = decode(decode_batch, &mut client).await?;
             // Send warmup message
             run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
@@ -83,7 +83,7 @@ async fn generate_runs(
 
         for _ in 0..n_runs {
             let (prefill, decode_batch) =
-                prefill(sequence.clone(), b, decode_length, &mut client).await?;
+                prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
             // Send prefill message
             run_sender
                 .send(Ok(Message::Prefill(prefill)))
@@ -110,6 +110,7 @@ async fn generate_runs(
 // Run a prefill step
 async fn prefill(
     sequence: String,
+    sequence_length: u32,
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
@@ -119,6 +120,7 @@ async fn prefill(
         .map(|id| Request {
             id: id.into(),
             inputs: sequence.clone(),
+            truncate: sequence_length,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,
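The benchmark change above threads `sequence_length` through to `prefill` so every `Request` also carries `truncate: sequence_length`. As a rough sketch of what a server can do with that field (the `tokenize_batch` helper below is illustrative, not code from this repository; only the `max_truncation` bookkeeping is visible as context in the Python hunks further down), the largest requested truncation caps the tokenizer call for the whole batch:

from typing import List

from transformers import AutoTokenizer


def tokenize_batch(inputs: List[str], truncates: List[int], tokenizer):
    # Keep the largest per-request truncation, as in the `max_truncation`
    # bookkeeping visible as context in the CausalLMBatch hunks below.
    max_truncation = 0
    for truncate in truncates:
        max_truncation = max(max_truncation, truncate)

    # Cap every input at `max_truncation` tokens before batching.
    return tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_truncation,
    )


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
    batch = tokenize_batch(["Hello world, this is a longer test input"], [4], tok)
    print(batch["input_ids"].shape)  # sequence dimension is at most 4

Only the batch-level cap is sketched here; how individual requests are trimmed beyond that is outside this diff.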
@@ -7,9 +7,6 @@ export const options = {
         {duration: '2m', target: 100},
         {duration: '1m', target: 0},
     ],
-    hosts: {
-        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
-    },
 };
 const SLEEP_DURATION = 1;
 
@@ -29,7 +26,7 @@ function greedy_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }
 
 function sample_example(inputs, max_new_tokens, name) {
@@ -50,7 +47,7 @@ function sample_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }
 
 export default function () {
@@ -95,4 +92,4 @@ export default function () {
         'is status 200': (r) => r.status === 200,
     });
     sleep(SLEEP_DURATION);
-}
+}
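The load-test script now posts straight to `https://open-assistant.ngrok.io/generate` instead of rewriting `text-generation-inference.huggingface.co` through the removed `hosts` map. For readers who want to poke the same endpoint outside k6, here is a hedged Python sketch of a greedy request; the body fields are assumptions about the usual `/generate` payload and are not shown in this diff:

import requests

# Hypothetical Python counterpart of greedy_example(); the body fields below
# are assumptions about the usual /generate payload, not taken from this diff.
def greedy_example(inputs, max_new_tokens, url="https://open-assistant.ngrok.io/generate"):
    body = {
        "inputs": inputs,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,  # greedy decoding
        },
    }
    response = requests.post(url, json=body, timeout=60)
    response.raise_for_status()
    return response.json()


if __name__ == "__main__":
    print(greedy_example("My name is Olivier and I", 20))

Swapping `do_sample` to `True` would mirror `sample_example`, again assuming the standard parameter names.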
@@ -35,6 +35,7 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -66,6 +67,7 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -73,6 +75,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -117,6 +120,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -140,6 +144,7 @@ class CausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -157,6 +162,7 @@ class CausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -271,6 +277,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -358,6 +365,7 @@ class CausalLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
 
@@ -373,6 +381,7 @@ class CausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -384,6 +393,7 @@ class CausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -401,7 +411,9 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[:, 0], offset, token_offset
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -432,6 +444,7 @@ class CausalLM(Model):
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, new_input_length
                 )
@@ -516,6 +529,7 @@ class CausalLM(Model):
             all_input_ids=next_batch_all_input_ids,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
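Throughout `CausalLMBatch`, every request now tracks a `token_offset` next to its `offset`: both start as `None` in `from_pb`, are carried through `concatenate`, and are fed back into `decode_token` on every generation step. A minimal sketch of that per-request bookkeeping (the `RequestState` dataclass and `decode_step` helper are illustrative, not part of the codebase):

from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class RequestState:
    # Mirrors the per-request columns kept on CausalLMBatch.
    all_input_ids: List[int]
    offset: Optional[int] = None
    token_offset: Optional[int] = None
    generated_text: str = ""


def decode_step(
    state: RequestState,
    next_token_id: int,
    decode_token: Callable[
        [List[int], Optional[int], Optional[int]],
        Tuple[str, Optional[int], Optional[int]],
    ],
) -> str:
    """One generation step: append the sampled id, decode only a small window.

    `decode_token` is assumed to have the new three-value signature
    (text, offset, token_offset) introduced by this commit.
    """
    state.all_input_ids.append(next_token_id)
    text, state.offset, state.token_offset = decode_token(
        state.all_input_ids, state.offset, state.token_offset
    )
    state.generated_text += text
    return text


if __name__ == "__main__":
    # Toy decode_token that "decodes" ids as ASCII characters and never defers.
    fake_decode = lambda ids, off, tok_off: (chr(ids[-1]), None, None)
    state = RequestState(all_input_ids=[72, 105])  # earlier ids already decoded
    print(decode_step(state, 33, fake_decode))  # "!"

The point of keeping both values on the batch is that the next `generate_token` call can resume decoding exactly where the previous step left off.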
@@ -45,6 +45,7 @@ class FlashCausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -69,6 +70,7 @@ class FlashCausalLMBatch(Batch):
 
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
 
@@ -87,6 +89,7 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
             offsets.append(None)
+            token_offsets.append(None)
             all_input_ids.append(tokenized_input)
 
             tokenized_input = torch.tensor(tokenized_input, device=device)
@@ -124,6 +127,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -137,6 +141,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
         next_token_choosers = []
@@ -156,6 +161,7 @@ class FlashCausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)
@@ -189,6 +195,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -287,6 +294,7 @@ class FlashCausalLM(Model):
         next_batch_past_key_values = []
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_all_input_ids = []
         next_batch_all_input_ids_tensor = []
 
@@ -301,6 +309,7 @@ class FlashCausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -312,6 +321,7 @@ class FlashCausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -344,8 +354,10 @@ class FlashCausalLM(Model):
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text, offset = self.decode_token(
-                all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[-(stopping_criteria.current_tokens + 1) :],
+                offset,
+                token_offset,
             )
 
             # Evaluate stopping criteria
@@ -387,6 +399,7 @@ class FlashCausalLM(Model):
                 )
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
@@ -464,6 +477,7 @@ class FlashCausalLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             all_input_ids=next_batch_all_input_ids,
             all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,
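In `FlashCausalLM`, the ids handed to `decode_token` are already a window: the last `stopping_criteria.current_tokens + 1` ids, i.e. the freshly sampled token plus the tokens generated before it (assuming `current_tokens` has not yet been bumped for the new token). A small worked example of that slice, with made-up token ids:

# 8 prompt tokens followed by 4 generated tokens; 203 is the token just sampled.
prompt_ids = [101, 102, 103, 104, 105, 106, 107, 108]
generated_ids = [200, 201, 202, 203]
all_input_ids = prompt_ids + generated_ids

current_tokens = 3  # tokens generated before the newest one
window = all_input_ids[-(current_tokens + 1):]
print(window)  # [200, 201, 202, 203] -- only generated ids, never the prompt

Because the prompt is never re-decoded, the cost of producing streamed text stays proportional to the generated length, and `token_offset` then narrows even that window further.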
@@ -94,6 +94,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -102,6 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -147,6 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -25,13 +25,18 @@ class Model(ABC):
         raise NotImplementedError
 
     def decode_token(
-        self, all_input_ids: List[int], offset: Optional[int] = None
-    ) -> Tuple[str, Optional[int]]:
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
+    ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        if token_offset is None:
+            token_offset = len(all_input_ids) - 5
 
-        # Decode all token minus last one and all tokens
+        # Decode token_offset token minus last one and token_offset tokens
         results = self.tokenizer.batch_decode(
-            [all_input_ids[:-1], all_input_ids],
+            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
             skip_special_tokens=False,
         )
 
@@ -44,6 +49,6 @@ class Model(ABC):
 
         # if text is utf-8
         if text and text[-1] != "�":
-            return text, None
+            return text, None, None
         else:
-            return "", offset
+            return "", offset, token_offset
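The reworked `decode_token` is the heart of this commit: instead of re-decoding the whole sequence on every step, it decodes only a short window (by default the last five ids) and returns the extra `token_offset` so the caller can reuse the same window next time; when the decoded window ends in the "�" replacement character it returns an empty string and keeps both offsets so the next call retries with one more token. Below is a self-contained sketch of the same logic as a free function against a real tokenizer; the model name is arbitrary, and the `offset` handling in the middle is an assumption, since those lines are unchanged and therefore not shown in this diff:

from typing import List, Optional, Tuple

from transformers import AutoTokenizer


def decode_token(
    tokenizer,
    all_input_ids: List[int],
    offset: Optional[int] = None,
    token_offset: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
    """Window-limited incremental decoding, mirroring the diffed method."""
    if token_offset is None:
        token_offset = len(all_input_ids) - 5

    # Decode the window without its last id, and the full window.
    results = tokenizer.batch_decode(
        [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
        skip_special_tokens=False,
    )

    # Assumed default: the new text starts right after the decoded prefix.
    if offset is None:
        offset = len(results[0])
    text = results[1][offset:]

    # "\ufffd" is U+FFFD, emitted for an incomplete UTF-8 sequence; in that
    # case emit nothing yet and keep both offsets so the next call retries
    # with one more token.
    if text and text[-1] != "\ufffd":
        return text, None, None
    return "", offset, token_offset


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok("Hello world, streaming decode")["input_ids"]
    print(decode_token(tok, ids))  # e.g. (' decode', None, None)

The real method lives on `Model` and uses `self.tokenizer`; only the standalone packaging here differs.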
@@ -39,6 +39,7 @@ class Seq2SeqLMBatch(Batch):
     input_lengths: List[int]
     decoder_input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -73,6 +74,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -83,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -121,6 +124,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -152,6 +156,7 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -172,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -310,6 +316,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -430,6 +437,7 @@ class Seq2SeqLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
        next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
 
@@ -446,6 +454,7 @@ class Seq2SeqLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -458,6 +467,7 @@ class Seq2SeqLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -476,7 +486,9 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(decoder_input_ids, offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                decoder_input_ids, offset, token_offset
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)
@@ -504,6 +516,7 @@ class Seq2SeqLM(Model):
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -590,6 +603,7 @@ class Seq2SeqLM(Model):
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,