OlivierDehaene 2022-12-15 16:50:11 +01:00
parent 8420ee8fa2
commit c90afea3de
5 changed files with 57 additions and 46 deletions

View File

@@ -123,7 +123,7 @@ async fn generate(
                 tokens,
             })
         }
-        false => None
+        false => None,
     };
     // Timings
@@ -164,7 +164,6 @@ async fn generate(
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
     tracing::info!("Output: {}", response.output_text);
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output_text,

View File

@@ -128,7 +128,9 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -170,7 +172,9 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -259,7 +263,9 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -279,7 +285,9 @@ def test_batch_concatenate(
     assert next_batch is None
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
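
All four hunks in this test file are the same mechanical change: black wraps the long output_text equality checks in parentheses to stay under its line-length limit. A minimal sketch showing the two forms are the same assertion (the _Fake stand-in is hypothetical, not part of the test suite):

# Hypothetical stand-in exposing the one attribute the asserts read.
class _Fake:
    output_text = "TestTestTestTestTestTestTestTestTestTestTest"


generated_texts = [_Fake()]

# Original single-line form.
assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"

# Black's wrapped form: the parentheses only continue the expression across
# lines, so the assertion itself is unchanged.
assert (
    generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)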

View File

@@ -355,7 +355,9 @@ class CausalLM(Model):
                 token_ids = all_input_ids[-new_input_length:]
                 tokens = self.tokenizer.batch_decode(token_ids)
                 # Add NaN for the first prompt token
-                logprobs = [float('nan')] + all_logprobs[-new_input_length:].squeeze(1).tolist()
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
+                    1
+                ).tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
@@ -366,7 +368,7 @@ class CausalLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
             # add to the next batch
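
The re-wrapped logprobs line above keeps its behaviour: a NaN placeholder stands in for the first prompt token (which has no logprob of its own), and squeeze(1).tolist() flattens the (n, 1) slice into a plain Python list. A minimal sketch with hypothetical values, not taken from the model code:

import math

import torch

# Hypothetical (n, 1) slice standing in for all_logprobs[-new_input_length:].
sliced_logprobs = torch.tensor([[-0.5], [-1.2], [-0.1]])

# NaN for the first prompt token, then one float per remaining token.
logprobs = [float("nan")] + sliced_logprobs.squeeze(1).tolist()

assert math.isnan(logprobs[0])
assert len(logprobs) == sliced_logprobs.shape[0] + 1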

View File

@@ -450,7 +450,9 @@ class Seq2SeqLM(Model):
                 tokens = self.tokenizer.batch_decode(token_ids)
                 print(tokens)
                 # Add NaN for the bos token
-                logprobs = [float('nan')] + decoder_logprobs[-new_decoder_input_length:].tolist()
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
@@ -460,7 +462,7 @@ class Seq2SeqLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
             # add to the next batch
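
The Seq2SeqLM hunk is the same kind of re-wrap: the subscript is split across lines, while the NaN placeholder for the bos token and the slice-to-list conversion are unchanged (unlike the CausalLM case, no squeeze is applied here). A minimal sketch with hypothetical values, not taken from the model code:

import math

import torch

# Hypothetical 1-D decoder logprobs; the last new_decoder_input_length
# entries stand in for the slice used in this hunk.
new_decoder_input_length = 3
decoder_logprobs = torch.tensor([-0.9, -0.3, -2.0, -0.7])

# NaN for the bos token, then the sliced logprobs as plain floats.
logprobs = [float("nan")] + decoder_logprobs[-new_decoder_input_length:].tolist()

assert math.isnan(logprobs[0])
assert len(logprobs) == new_decoder_input_length + 1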