diff --git a/router/src/server.rs b/router/src/server.rs
index 913f2011..6a04ab00 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -540,7 +540,6 @@ async fn generate_stream_internal(
         // Inference
         let mut end_reached = false;
         let mut error = false;
-        let mut index = 0;

         let mut add_prompt = None;
         if req.parameters.return_full_text.unwrap_or(false) {
@@ -563,6 +562,7 @@ async fn generate_stream_internal(
            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                // Keep permit as long as generate_stream lives
                Ok((_permit, input_length, response_stream)) => {
+                    let mut index = 0;
                    let mut response_stream = Box::pin(response_stream);
                    // Server-Sent Event stream
                    while let Some(response) = response_stream.next().await {
@@ -679,7 +679,6 @@ async fn generate_stream_internal(
        if !end_reached && !error {
            let err = InferError::IncompleteGenerationStream;
            metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
-            tracing::info!("n iterations {index}");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d7308041..0bff6ce8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1087,12 +1087,12 @@ class FlashCausalLM(Model):
         if ATTENTION in {"flashdecoding", "flashinfer"}:
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
@@ -1103,12 +1103,12 @@ class FlashCausalLM(Model):
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
@@ -1520,24 +1520,6 @@ class FlashCausalLM(Model):
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-            page_size = BLOCK_SIZE
-            indptr = torch.zeros(
-                input_lengths.shape[0] + 1,
-                device=input_lengths.device,
-                dtype=torch.int32,
-            )
-            # Round up to page size and then calculate the cumulative sum to get
-            # the indices into the block table.
-            torch.add(input_lengths, page_size - 1, out=indptr[1:])
-            indptr[1:].div_(page_size, rounding_mode="floor")
-            indptr[1:].cumsum_(-1)
-            # Get the lengths of the last page in a block.
-            last_page_len = torch.empty(
-                input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
-            )
-            torch.sub(input_lengths, 1, out=last_page_len)
-            last_page_len.remainder_(page_size)
-            last_page_len += 1
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
@@ -1898,8 +1880,6 @@ class FlashCausalLM(Model):
                top_tokens,
            )

-            # assert all(n is not None for n in next_token_texts)
-
            generations.append(generation)

            # accept each new token for this specific request since we may