Mirror of https://github.com/huggingface/text-generation-inference.git

Commit c90afea3de ("format")
Parent: 8420ee8fa2
@@ -64,14 +64,14 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
 
 /// Generate method
 #[instrument(
     skip(state),
     fields(
         total_time,
         validation_time,
         queue_time,
         inference_time,
         time_per_token
     )
 )]
 async fn generate(
     state: Extension<ServerState>,
@@ -123,7 +123,7 @@ async fn generate(
                 tokens,
             })
         }
-        false => None
+        false => None,
     };
 
     // Timings
@@ -164,7 +164,6 @@ async fn generate(
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
     tracing::info!("Output: {}", response.output_text);
 
-
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output_text,
@@ -219,7 +218,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(unix)]
     let terminate = async {
         signal::unix::signal(signal::unix::SignalKind::terminate())
             .expect("failed to install signal handler")
             .recv()
@@ -227,7 +226,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(not(unix))]
     let terminate = std::future::pending::<()>();
 
     tokio::select! {
         _ = ctrl_c => {},
@@ -128,7 +128,9 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
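
This and the next three test hunks are the same mechanical Black change: with the test body's 4-space indent, the long equality assert runs just over Black's default 88-character limit, so the expression is wrapped in parentheses. The assertion itself is untouched. A self-contained sketch of the before/after shape; the `GeneratedText` stand-in below is hypothetical (the real fixture dataclass carries more fields):

```python
from dataclasses import dataclass


@dataclass
class GeneratedText:
    """Hypothetical stand-in for the test fixture's dataclass (the real one has more fields)."""

    output_text: str


def check(generated_texts):
    # Before formatting: with the 4-space function-body indent this line runs just
    # past Black's default 88-character limit.
    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"

    # After formatting: the same comparison, parenthesized so it can wrap. Black
    # uses parentheses rather than a tuple, so the assert still evaluates the
    # comparison and not an always-truthy tuple.
    assert (
        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
    )


check([GeneratedText("TestTestTestTestTestTestTestTestTestTestTest")])
```
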
@@ -170,7 +172,9 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -259,7 +263,9 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -279,7 +285,9 @@ def test_batch_concatenate(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -47,7 +47,7 @@ class CausalLMBatch:
 
     @classmethod
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -148,8 +148,8 @@ class CausalLMBatch:
 
             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]
 
             for j, past in enumerate(batch.past_key_values):
                 past_keys, past_values = past
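
Both edits in this hunk are whitespace only: Black follows PEP 8 in treating the slice colon like a binary operator and adds a space on each side once a bound is an expression rather than a simple name, so `-batch.max_sequence_length :` selects exactly the same columns as `-batch.max_sequence_length:`. The sketch below mirrors the attention-mask copy on toy sizes; it assumes torch is installed, and every value in it is invented for illustration:

```python
import torch

# The destination batch is padded to a new maximum width of 4, and we keep only
# the last `max_sequence_length` columns of the source mask, dropping the extra
# padding it accumulated in earlier steps.
max_sequence_length = 3                                      # real tokens in the source row
src_attention_mask = torch.tensor([[0, 0, 1, 1, 1]])         # left-padded to width 5
dst_attention_mask = torch.zeros((1, 4), dtype=torch.int64)  # new padded width 4

start_index, end_index = 0, 1
dst_attention_mask[
    start_index:end_index, -max_sequence_length :
] = src_attention_mask[:, -max_sequence_length :]

# Both spellings address the same columns, so the formatter's change is a no-op.
assert torch.equal(
    src_attention_mask[:, -max_sequence_length :],
    src_attention_mask[:, -max_sequence_length:],
)
print(dst_attention_mask)  # tensor([[0, 1, 1, 1]])
```
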
@@ -197,22 +197,22 @@ class CausalLMBatch:
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1):,
+                        -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
 
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
                 start_index += batch.size
 
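
The edits above are the same slice-spacing rule again. The branch on `batch.keys_head_dim_last` is the interesting context: it appears to distinguish key caches stored as [batch, heads, seq, head_dim] from caches that keep the sequence on the last axis, which is why the trim of old padding moves from the third dimension to the fourth. A minimal sketch of the two layouts, assuming torch; the shapes are invented, and this illustrates only the slicing, not any particular model's cache layout:

```python
import torch

max_sequence_length = 4
window = max_sequence_length - 1  # same trim window as the slices above

# Layout A: keys stored as [batch, heads, seq, head_dim] -> trim the 3rd axis.
keys_head_dim_last = torch.randn(2, 8, 6, 16)
# Layout B: keys stored as [batch, heads, head_dim, seq] -> trim the last axis.
keys_seq_dim_last = torch.randn(2, 8, 16, 6)

trimmed_a = keys_head_dim_last[:, :, -window:, :]
trimmed_b = keys_seq_dim_last[:, :, :, -window:]
print(trimmed_a.shape)  # torch.Size([2, 8, 3, 16])
print(trimmed_b.shape)  # torch.Size([2, 8, 16, 3])
```
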
@@ -268,7 +268,7 @@ class CausalLM(Model):
         return CausalLMBatch
 
     def forward(
         self, input_ids, attention_mask, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -280,7 +280,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values
 
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -320,13 +320,13 @@ class CausalLM(Model):
 
         # For each member of the batch
         for i, (
             request,
             input_length,
             logits,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
             all_logprobs,
         ) in enumerate(iterator):
             # Select next token
             tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -355,7 +355,9 @@ class CausalLM(Model):
                 token_ids = all_input_ids[-new_input_length:]
                 tokens = self.tokenizer.batch_decode(token_ids)
                 # Add NaN for the first prompt token
-                logprobs = [float('nan')] + all_logprobs[-new_input_length:].squeeze(1).tolist()
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
+                    1
+                ).tolist()
 
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
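
Black's change here only rewraps the call chain and switches to double quotes; the values are unchanged. Judging from the surrounding comment, `all_logprobs` holds one score per token except the very first prompt token, which is never predicted; the prepended NaN stands in for it, and the over-long negative slice is simply clamped. A self-contained sketch with invented numbers, assuming torch:

```python
import torch

# Four tokens are kept for this request, but only three logprobs were accumulated:
# the first prompt token has no preceding context, so it was never scored.
new_input_length = 4
token_ids = torch.tensor([[101], [2023], [2003], [102]])  # invented ids, shape [4, 1]
all_logprobs = torch.tensor([[-1.25], [-0.5], [-2.0]])    # shape [3, 1]

# A negative slice larger than the tensor is clamped to its start, so this keeps
# all three scores; the prepended NaN stands in for the unscored first token.
logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(1).tolist()

assert len(logprobs) == len(token_ids)
print(logprobs)  # [nan, -1.25, -0.5, -2.0]
```
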
@@ -366,7 +368,7 @@ class CausalLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
             # add to the next batch
@@ -450,7 +450,9 @@ class Seq2SeqLM(Model):
                 tokens = self.tokenizer.batch_decode(token_ids)
                 print(tokens)
                 # Add NaN for the bos token
-                logprobs = [float('nan')] + decoder_logprobs[-new_decoder_input_length:].tolist()
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
@@ -460,7 +462,7 @@ class Seq2SeqLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
             # add to the next batch
@@ -99,7 +99,7 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
         self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
     ):
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
@@ -119,7 +119,7 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
         cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
     ) -> "StoppingCriteria":
         stop_sequence_criterias = []
         for stop_sequence in pb.stop_sequences:
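
For orientation on this last hunk: `from_pb` builds one `StopSequenceCriteria` per configured stop sequence and pairs them with a `max_new_tokens` cap. The sketch below illustrates that idea only; it is not the repository's implementation, and the `__call__` protocol, the regex matching, and the reason strings are assumptions made to keep the example self-contained and runnable:

```python
import re
from typing import List, Optional, Tuple


class StopSequenceCriteria:
    """Illustrative stand-in: fires when the decoded output ends with the stop sequence."""

    def __init__(self, stop_sequence: str):
        self.regex = re.compile(rf".*{re.escape(stop_sequence)}$", re.DOTALL)

    def __call__(self, output: str) -> bool:
        return self.regex.match(output) is not None


class StoppingCriteria:
    def __init__(
        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
    ):
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, output: str) -> Tuple[bool, Optional[str]]:
        # Called once per generated token with the decoded text so far.
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"             # reason string is an assumption
        for criteria in self.stop_sequence_criterias:
            if criteria(output):
                return True, "stop_sequence"  # reason string is an assumption
        return False, None


criteria = StoppingCriteria([StopSequenceCriteria("\n\n")], max_new_tokens=5)
print(criteria("Hello"))            # (False, None)
print(criteria("Hello world\n\n"))  # (True, 'stop_sequence')
```
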