Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Minor fixup
commit 29d3601457
parent d57b7091aa
@@ -540,7 +540,6 @@ async fn generate_stream_internal(
         // Inference
         let mut end_reached = false;
         let mut error = false;
-        let mut index = 0;

         let mut add_prompt = None;
         if req.parameters.return_full_text.unwrap_or(false) {
@@ -563,6 +562,7 @@ async fn generate_stream_internal(
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
             Ok((_permit, input_length, response_stream)) => {
+                let mut index = 0;
                 let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
@@ -679,7 +679,6 @@ async fn generate_stream_internal(
         if !end_reached && !error {
             let err = InferError::IncompleteGenerationStream;
             metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
-            tracing::info!("n iterations {index}");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         }
@@ -1087,12 +1087,12 @@ class FlashCausalLM(Model):
         if ATTENTION in {"flashdecoding", "flashinfer"}:
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
@@ -1103,12 +1103,12 @@ class FlashCausalLM(Model):
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
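Not part of the diff: a minimal, self-contained sketch (with made-up sizes and layer count) of what the changed allocation amounts to. torch.zeros allocates and zero-fills the whole cache, whereas torch.empty only reserves the memory and leaves it uninitialized, which is presumably acceptable here because every cache slot is written by the attention kernels before it is ever read.

import torch

# Hypothetical sizes, for illustration only; the real values come from the model config.
num_blocks, BLOCK_SIZE, num_heads, head_size = 128, 16, 8, 64
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.empty skips the memset that torch.zeros would perform on this large buffer.
kv_cache = [
    (
        torch.empty((num_blocks, BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device),
        torch.empty((num_blocks, BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device),
    )
    for _ in range(2)  # one (key, value) pair per attention layer
]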
@@ -1520,24 +1520,6 @@ class FlashCausalLM(Model):
             )
             # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-            page_size = BLOCK_SIZE
-            indptr = torch.zeros(
-                input_lengths.shape[0] + 1,
-                device=input_lengths.device,
-                dtype=torch.int32,
-            )
-            # Round up to page size and then calculate the cumulative sum to get
-            # the indices into the block table.
-            torch.add(input_lengths, page_size - 1, out=indptr[1:])
-            indptr[1:].div_(page_size, rounding_mode="floor")
-            indptr[1:].cumsum_(-1)
-            # Get the lengths of the last page in a block.
-            last_page_len = torch.empty(
-                input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
-            )
-            torch.sub(input_lengths, 1, out=last_page_len)
-            last_page_len.remainder_(page_size)
-            last_page_len += 1
             cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
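Not part of the diff: a standalone sketch, with assumed example values, of what the removed block above used to compute, namely the flashinfer-style paged-KV metadata: indptr, the cumulative number of pages per request (lengths rounded up to the page size), and last_page_len, the number of valid slots in each request's final page.

import torch

# Assumed example values; BLOCK_SIZE plays the role of page_size in the diff.
page_size = 16
input_lengths = torch.tensor([5, 16, 40], dtype=torch.int32)

# Pages per request = ceil(length / page_size); indptr is the running total.
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)             # indptr is now [0, 1, 2, 5]

# Valid slots in each request's last page, in the range 1..page_size.
last_page_len = torch.empty_like(input_lengths)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1                 # last_page_len is now [5, 16, 8]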
@@ -1898,8 +1880,6 @@ class FlashCausalLM(Model):
                     top_tokens,
                 )

-                # assert all(n is not None for n in next_token_texts)
-
                 generations.append(generation)

                 # accept each new token for this specific request since we may