Minor fixup

Nicolas Patry 2024-09-10 10:52:02 +02:00
parent d57b7091aa
commit 29d3601457
2 changed files with 5 additions and 26 deletions


@@ -540,7 +540,6 @@ async fn generate_stream_internal(
         // Inference
         let mut end_reached = false;
         let mut error = false;
-        let mut index = 0;
         let mut add_prompt = None;
         if req.parameters.return_full_text.unwrap_or(false) {
@@ -563,6 +562,7 @@ async fn generate_stream_internal(
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
             Ok((_permit, input_length, response_stream)) => {
+                let mut index = 0;
                 let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
@@ -679,7 +679,6 @@ async fn generate_stream_internal(
         if !end_reached && !error {
             let err = InferError::IncompleteGenerationStream;
             metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
-            tracing::info!("n iterations {index}");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         }


@@ -1087,12 +1087,12 @@ class FlashCausalLM(Model):
         if ATTENTION in {"flashdecoding", "flashinfer"}:
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
@@ -1103,12 +1103,12 @@ class FlashCausalLM(Model):
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
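
Note: the switch from torch.zeros to torch.empty only changes initialization. torch.empty allocates the KV-cache buffers without zero-filling them, which skips a large memset at startup; this presumes every cache slot is written before it is ever read. A minimal sketch of the allocation, with illustrative sizes standing in for num_blocks, BLOCK_SIZE, num_heads and head_size:

import torch

# Illustrative sizes; in the model these come from num_blocks, BLOCK_SIZE,
# num_heads and head_size.
num_blocks, block_size, num_heads, head_size = 128, 16, 8, 64
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One (key, value) pair; torch.empty leaves the memory uninitialized
# instead of zero-filling it the way torch.zeros does.
kv_cache_entry = (
    torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
    torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
)
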
@@ -1520,24 +1520,6 @@ class FlashCausalLM(Model):
             )
             # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-            page_size = BLOCK_SIZE
-            indptr = torch.zeros(
-                input_lengths.shape[0] + 1,
-                device=input_lengths.device,
-                dtype=torch.int32,
-            )
-            # Round up to page size and then calculate the cumulative sum to get
-            # the indices into the block table.
-            torch.add(input_lengths, page_size - 1, out=indptr[1:])
-            indptr[1:].div_(page_size, rounding_mode="floor")
-            indptr[1:].cumsum_(-1)
-            # Get the lengths of the last page in a block.
-            last_page_len = torch.empty(
-                input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
-            )
-            torch.sub(input_lengths, 1, out=last_page_len)
-            last_page_len.remainder_(page_size)
-            last_page_len += 1
             cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
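
For reference, the block removed above computed paged-KV indexing metadata from input_lengths: indptr holds the cumulative number of pages per sequence (indices into the block table) and last_page_len the number of occupied slots in each sequence's final page. A small standalone sketch of that computation, with an illustrative page size of 16 standing in for BLOCK_SIZE:

import torch

page_size = 16  # stands in for BLOCK_SIZE
input_lengths = torch.tensor([1, 16, 17], dtype=torch.int32)

# Pages needed per sequence, rounded up, then a cumulative sum so that
# indptr[i]:indptr[i + 1] spans sequence i's pages (indptr[0] is 0).
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)

# Occupied slots in each sequence's last page.
last_page_len = (input_lengths - 1) % page_size + 1

print(indptr.tolist())         # [0, 1, 2, 4]
print(last_page_len.tolist())  # [1, 16, 1]
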
@@ -1898,8 +1880,6 @@ class FlashCausalLM(Model):
                 top_tokens,
             )
-            # assert all(n is not None for n in next_token_texts)
             generations.append(generation)
             # accept each new token for this specific request since we may