mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00

commit c90afea3de
parent 8420ee8fa2

    format
@@ -123,7 +123,7 @@ async fn generate(
                 tokens,
             })
         }
-        false => None
+        false => None,
     };

     // Timings
@@ -164,7 +164,6 @@ async fn generate(
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
     tracing::info!("Output: {}", response.output_text);

-
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output_text,
@@ -128,7 +128,9 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
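The change in this hunk (and the identical ones below) only rewraps the long assertion to fit the formatter's line length; what is asserted does not change. A minimal, self-contained sketch of the equivalence, using a hypothetical stand-in dataclass instead of the real test fixtures:

from dataclasses import dataclass


@dataclass
class FakeGeneratedText:
    # Stand-in for the real generated-text object used by the tests.
    output_text: str


generated_texts = [
    FakeGeneratedText(output_text="TestTestTestTestTestTestTestTestTestTestTest")
]

# Original single-line form.
assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"

# Formatter-wrapped form: the parentheses let the expression span several lines
# without changing what is asserted.
assert (
    generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)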
@@ -170,7 +172,9 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -259,7 +263,9 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -279,7 +285,9 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -355,7 +355,9 @@ class CausalLM(Model):
                 token_ids = all_input_ids[-new_input_length:]
                 tokens = self.tokenizer.batch_decode(token_ids)
                 # Add NaN for the first prompt token
-                logprobs = [float('nan')] + all_logprobs[-new_input_length:].squeeze(1).tolist()
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
+                    1
+                ).tolist()

                 # Add to the list of finished generations with the original request
                 generated_texts.append(
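For illustration, a standalone sketch of what the rewrapped logprobs expression computes; the tensor values, the (n, 1) shape, and new_input_length below are assumptions, not taken from the repository:

import torch

# Sketch only: made-up values; shape (n, 1) is assumed from the .squeeze(1).
new_input_length = 4
all_logprobs = torch.tensor([[-0.5], [-0.25], [-0.125], [-0.75]])

# NaN stands in for the first prompt token, whose logprob is undefined; the
# remaining logprobs are flattened into a plain Python list.
logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(1).tolist()
print(logprobs)  # [nan, -0.5, -0.25, -0.125, -0.75]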
@@ -366,7 +368,7 @@ class CausalLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
                 # add to the next batch
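The GeneratedText being built here is defined elsewhere in the repository; as a reading aid only, a hypothetical, simplified stand-in limited to the fields that appear in this diff and in the tests above (the real class may differ):

from dataclasses import dataclass
from typing import Any, List


@dataclass
class GeneratedTextSketch:
    # Hypothetical stand-in, not the repository's actual GeneratedText.
    request: Any            # original request this generation belongs to
    output_text: str        # decoded text checked by the tests above
    generated_tokens: int   # token count checked by the tests above
    tokens: List[str]       # decoded token strings
    token_ids: List[int]    # token ids as a flat list
    logprobs: List[float]   # per-token logprobs, NaN for the first token
    reason: str             # finish reason passed as `reason=reason`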
@@ -450,7 +450,9 @@ class Seq2SeqLM(Model):
                 tokens = self.tokenizer.batch_decode(token_ids)
                 print(tokens)
                 # Add NaN for the bos token
-                logprobs = [float('nan')] + decoder_logprobs[-new_decoder_input_length:].tolist()
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
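Same pattern as in CausalLM above, except decoder_logprobs is sliced directly, which suggests it is already 1-D; an illustrative sketch with assumed values:

import torch

# Sketch only: made-up values; a 1-D tensor is assumed since there is no squeeze.
new_decoder_input_length = 3
decoder_logprobs = torch.tensor([-0.5, -0.25, -0.75])

# NaN stands in for the bos token, whose logprob is undefined.
logprobs = [float("nan")] + decoder_logprobs[-new_decoder_input_length:].tolist()
print(logprobs)  # [nan, -0.5, -0.25, -0.75]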
@@ -460,7 +462,7 @@ class Seq2SeqLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
                 # add to the next batch