Are we done yet?

Commit d57b7091aa by Nicolas Patry, 2024-09-10 10:24:56 +02:00 (parent e128bc540b)
11 changed files with 571 additions and 534 deletions

View File

@@ -376,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
     // Send generation responses back to the infer task
     // If the receive an error from the Flume channel, it means that the client dropped the
     // request and we need to stop generating hence why we unwrap_or(true)
-    let stopped = send_responses(generation, entry).map_err(|err| {
+    let stopped = send_responses(generation, entry).inspect_err(|_err| {
         tracing::error!("Entry response channel error.");
         metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
-        err
     }).unwrap_or(true);
     if stopped {
         entries.remove(&id).expect("ID not found in entries. This is a bug.");

View File

@@ -366,7 +366,7 @@ impl State {
                     break;
                 }
                 Some(block_allocation) => {
-                    tracing::debug!("Allocation: {block_allocation:?}");
+                    // tracing::debug!("Allocation: {block_allocation:?}");
                     max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                     Some(block_allocation)
                 }

View File

@@ -123,7 +123,7 @@ impl Allocator for RadixAllocator {
             prefill_tokens: prefill_tokens.clone(),
         };
-        tracing::debug!("Blocks {blocks:?}");
+        // tracing::debug!("Blocks {blocks:?}");
         self.allocation_id += 1;
         self.allocations.insert(self.allocation_id, allocation);

View File

@@ -1,38 +1,38 @@
 {
   "choices": [
     {
-      "finish_reason": "stop",
+      "finish_reason": "length",
       "index": 1,
       "logprobs": null,
-      "text": " PR for more information?"
+      "text": " This is a question that has puzzled many people for"
     },
     {
       "finish_reason": "length",
       "index": 3,
       "logprobs": null,
-      "text": "hd20220811-"
+      "text": "usculas_minusculas(s):\n \"\"\"\n"
     },
     {
       "finish_reason": "length",
       "index": 0,
       "logprobs": null,
-      "text": "le Business Incubator is providing a workspace"
+      "text": " A Beginners Guide\nDeep learning is a subset"
     },
     {
       "finish_reason": "length",
       "index": 2,
       "logprobs": null,
-      "text": " severely flawed and often has a substandard"
+      "text": " Paris\nWhat is the capital of France?\nThe"
     }
   ],
-  "created": 1722014725,
+  "created": 1725877154,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 8,
-    "total_tokens": 44
+    "completion_tokens": 40,
+    "prompt_tokens": 22,
+    "total_tokens": 62
   }
 }

View File

@@ -4,17 +4,17 @@
       "finish_reason": "length",
       "index": 0,
       "logprobs": null,
-      "text": "\n2.2 How"
+      "text": " A Beginners Guide\nDeep learning is a subset"
     }
   ],
-  "created": 1725874238,
+  "created": 1725876621,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native",
   "usage": {
-    "completion_tokens": 5,
+    "completion_tokens": 10,
     "prompt_tokens": 6,
-    "total_tokens": 11
+    "total_tokens": 16
   }
 }

View File

@@ -11,7 +11,7 @@ from text_generation.types import (
 @pytest.fixture(scope="module")
 def flash_llama_completion_handle(launcher):
     with launcher(
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
     ) as handle:
         yield handle
@@ -35,15 +35,18 @@ def test_flash_llama_completion_single_prompt(
         json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
-           "max_tokens": 5,
-           "seed": 0,
+           "max_tokens": 10,
+           "temperature": 0.0,
         },
         headers=flash_llama_completion.headers,
         stream=False,
     )
     response = response.json()
     assert len(response["choices"]) == 1
-    assert response["choices"][0]["text"] == "\n2.2 How"
+    assert (
+        response["choices"][0]["text"]
+        == " A Beginners Guide\nDeep learning is a subset"
+    )
     assert response == response_snapshot
@@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
         f"{flash_llama_completion.base_url}/v1/completions",
         json={
            "model": "tgi",
-           "prompt": ["Say", "this", "is", "a"],
+           "prompt": [
+               "What is Deep Learning?",
+               "Is water wet?",
+               "What is the capital of France?",
+               "def mai",
+           ],
            "max_tokens": 10,
            "seed": 0,
+           "temperature": 0.0,
         },
         headers=flash_llama_completion.headers,
         stream=False,
@@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
     response = response.json()
     assert len(response["choices"]) == 4
-    all_indexes = [choice["index"] for choice in response["choices"]]
+    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
     all_indexes.sort()
-    assert all_indexes == [0, 1, 2, 3]
+    all_indices, all_strings = zip(*all_indexes)
+    assert list(all_indices) == [0, 1, 2, 3]
+    assert list(all_strings) == [
+        " A Beginners Guide\nDeep learning is a subset",
+        " This is a question that has puzzled many people for",
+        " Paris\nWhat is the capital of France?\nThe",
+        'usculas_minusculas(s):\n """\n',
+    ]
     assert response == response_snapshot
@@ -84,6 +100,7 @@ async def test_flash_llama_completion_many_prompts_stream(
             ],
             "max_tokens": 10,
             "seed": 0,
+            "temperature": 0.0,
             "stream": True,
         }
@@ -114,5 +131,10 @@ async def test_flash_llama_completion_many_prompts_stream(
             strings[index] += c["choices"][0]["text"]
     assert response.status == 200
-    # assert strings == ["What Business: And Stock Mohs`('\\", '\nrig Business Process And Stock ,s, And', '\n\n202 Stock Mohs a Service', 'hd\n20207\nR1']
+    assert list(strings) == [
+        " A Beginners Guide\nDeep learning is a subset",
+        " This is a question that has puzzled many people for",
+        " Paris\nWhat is the capital of France?\nThe",
+        'usculas_minusculas(s):\n """\n',
+    ]
     assert chunks == response_snapshot
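
The tests above pin temperature to 0.0 so decoding is greedy and the generated texts can be asserted literally. A minimal sketch of the same request shape outside pytest (the URL is a placeholder; in the integration tests the base URL and headers come from the launcher fixture):

    import requests

    # Hypothetical local endpoint; adjust to wherever the server is listening.
    base_url = "http://localhost:3000"

    response = requests.post(
        f"{base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": ["What is Deep Learning?", "What is the capital of France?"],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,  # greedy decoding, so repeated runs return the same text
        },
        stream=False,
    )
    payload = response.json()
    for choice in sorted(payload["choices"], key=lambda c: c["index"]):
        print(choice["index"], repr(choice["text"]))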

View File

@@ -1843,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
         shutdown.clone(),
         &shutdown_receiver,
     )
-    .map_err(|err| {
+    .inspect_err(|_| {
         shutdown_shards(shutdown.clone(), &shutdown_receiver);
-        err
     })?;
     // Default exit code

View File

@@ -336,6 +336,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Incomplete generation stream")]
+    IncompleteGenerationStream,
     #[error("Template error: {0}")]
     TemplateError(#[from] minijinja::Error),
     #[error("Missing template vatiable: {0}")]
@@ -351,6 +353,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::IncompleteGenerationStream => "incomplete_generation_stream",
             InferError::TemplateError(_) => "template_error",
             InferError::MissingTemplateVariable(_) => "missing_template_variable",
             InferError::ToolError(_) => "tool_error",

View File

@@ -540,6 +540,7 @@ async fn generate_stream_internal(
     // Inference
     let mut end_reached = false;
     let mut error = false;
+    let mut index = 0;
     let mut add_prompt = None;
     if req.parameters.return_full_text.unwrap_or(false) {
@@ -562,7 +563,6 @@ async fn generate_stream_internal(
     match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
         // Keep permit as long as generate_stream lives
         Ok((_permit, input_length, response_stream)) => {
-            let mut index = 0;
             let mut response_stream = Box::pin(response_stream);
             // Server-Sent Event stream
             while let Some(response) = response_stream.next().await {
@@ -677,8 +677,9 @@ async fn generate_stream_internal(
     // Check if generation reached the end
     // Skip if we already sent an error
     if !end_reached && !error {
-        let err = InferError::IncompleteGeneration;
+        let err = InferError::IncompleteGenerationStream;
         metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
+        tracing::info!("n iterations {index}");
         tracing::error!("{err}");
         yield Ok(Event::from(err));
     }
@@ -2558,6 +2559,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,

View File

@@ -515,6 +515,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        assert len(pb.requests) > 0
         batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
         return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@@ -640,6 +641,7 @@ class FlashCausalLMBatch(Batch):
         adapter_segments = torch.tensor(
             adapter_segments, dtype=torch.int32, device=device
         )
+        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
         return type(self)(
             batch_id=self.batch_id,
@@ -834,6 +836,8 @@ class FlashCausalLMBatch(Batch):
         start_slots = torch.concat(start_slots)
+        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@@ -1083,12 +1087,12 @@ class FlashCausalLM(Model):
         if ATTENTION in {"flashdecoding", "flashinfer"}:
             self.kv_cache = [
                 (
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
@@ -1099,12 +1103,12 @@ class FlashCausalLM(Model):
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
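
Replacing torch.empty with torch.zeros above makes untouched KV-cache slots deterministic rather than whatever happened to be in the freshly allocated memory. A minimal sketch of the difference (shape and dtype are illustrative only):

    import torch

    # torch.empty does not initialize memory: untouched slots hold arbitrary values,
    # so two runs can read different garbage from never-written cache blocks.
    uninitialized = torch.empty((2, 4), dtype=torch.float16)

    # torch.zeros guarantees every slot starts at 0.0, which keeps reads of
    # padded or not-yet-written blocks reproducible across runs.
    zeroed = torch.zeros((2, 4), dtype=torch.float16)

    print(zeroed.sum().item())  # always 0.0; uninitialized.sum() is not defined in advance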
@@ -1150,20 +1154,6 @@ class FlashCausalLM(Model):
             input_lengths=input_lengths,
             prefix_lens=prefix_lengths,
         )
-        self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "kv_cache": self.kv_cache,
-            "block_tables": block_tables,
-            "slots": slots,
-            "input_lengths": input_lengths_tensor,
-            "prefix_lengths": prefix_lengths_tensor,
-        }
-        graph = torch.cuda.CUDAGraph()
-        self.cuda_graphs[bs]["graph"] = graph
-        if ATTENTION == "flashinfer":
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
@@ -1180,19 +1170,29 @@ class FlashCausalLM(Model):
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
             )
-            self.cuda_graphs[bs]["state"] = state
         else:
             state = None
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": self.kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths_tensor,
+            "prefix_lengths": prefix_lengths_tensor,
+            "state": state,
+            "graph": graph,
+        }
         torch.cuda.synchronize()
         # Run once outside to warmup
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
             state=state,
-            prefix_lens=prefix_lengths,
             prefix_lens_tensor=prefix_lengths_tensor,
         ):
             seqlen = Seqlen(
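
The reshuffling above creates the CUDA graph and stores it, together with its static input buffers and the optional flashinfer state, in a single cuda_graphs[bs] entry before the warmup pass. For readers unfamiliar with the pattern, here is a minimal, generic sketch of capture-and-replay with static buffers (this is not TGI code; it needs a CUDA device and uses a toy model):

    import torch

    assert torch.cuda.is_available(), "CUDA graphs need a GPU"

    model = torch.nn.Linear(128, 128).cuda().eval()
    static_input = torch.zeros(8, 128, device="cuda")

    # Warm up on a side stream so lazy initialization happens before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side), torch.no_grad():
        static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # Capture one forward pass into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph), torch.no_grad():
        static_output = model(static_input)

    # Later: overwrite the static buffers in place and replay the recorded kernels.
    static_input.copy_(torch.randn(8, 128, device="cuda"))
    graph.replay()
    print(static_output.shape)  # results for the new input, no re-recording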
@@ -1214,6 +1214,7 @@ class FlashCausalLM(Model):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+            del seqlen
         torch.cuda.synchronize()
@@ -1479,9 +1480,7 @@ class FlashCausalLM(Model):
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths=batch.input_lengths,
-                input_lengths_tensor=input_lengths + prefix_lens_tensor,
-                prefix_lens=batch.prefix_lens,
+                input_lengths_tensor=input_lengths,
                 prefix_lens_tensor=prefix_lens_tensor,
             ):
                 max_k = (input_lengths + prefix_lens_tensor).max().item()
@@ -1519,12 +1518,27 @@ class FlashCausalLM(Model):
                 input_lengths=batch.input_lengths,
                 prefix_lens=batch.prefix_lens,
             )
+            # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-        else:
-            cuda_graph["block_tables"][
-                : block_tables.shape[0], : block_tables.shape[1]
-            ] = block_tables
-        cuda_graph["slots"].fill_(-1)
+            page_size = BLOCK_SIZE
+            indptr = torch.zeros(
+                input_lengths.shape[0] + 1,
+                device=input_lengths.device,
+                dtype=torch.int32,
+            )
+            # Round up to page size and then calculate the cumulative sum to get
+            # the indices into the block table.
+            torch.add(input_lengths, page_size - 1, out=indptr[1:])
+            indptr[1:].div_(page_size, rounding_mode="floor")
+            indptr[1:].cumsum_(-1)
+            # Get the lengths of the last page in a block.
+            last_page_len = torch.empty(
+                input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+            )
+            torch.sub(input_lengths, 1, out=last_page_len)
+            last_page_len.remainder_(page_size)
+            last_page_len += 1
+        cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
@@ -1534,11 +1548,9 @@ class FlashCausalLM(Model):
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            input_lengths=batch.input_lengths,
             input_lengths_tensor=cuda_graph["input_lengths"],
-            prefix_lens=batch.prefix_lens,
             prefix_lens_tensor=cuda_graph["prefix_lengths"],
-            state=cuda_graph.get("state"),
+            state=cuda_graph["state"],
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
@@ -1767,7 +1779,7 @@ class FlashCausalLM(Model):
             left = 0
             if n_accepted_ids > 1:
-                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1886,6 +1898,8 @@ class FlashCausalLM(Model):
                     top_tokens,
                 )
+                # assert all(n is not None for n in next_token_texts)
                 generations.append(generation)
                 # accept each new token for this specific request since we may
@@ -1922,9 +1936,7 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths: List[int],
         input_lengths_tensor: torch.Tensor,
-        prefix_lens: List[int],
         prefix_lens_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
@@ -1950,7 +1962,7 @@ class FlashCausalLM(Model):
                 # ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=input_lengths_tensor,
+                input_lengths=input_lengths_tensor + prefix_lens_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
@@ -1960,7 +1972,7 @@ class FlashCausalLM(Model):
             assert input_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths_tensor,
+                input_lengths=input_lengths_tensor + prefix_lens_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
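
The last two hunks hand input_lengths_tensor + prefix_lens_tensor to the flashinfer prefill and decode state, so the attention bookkeeping sees the full cached sequence (prefix already in the KV cache plus the newly submitted tokens) instead of only the new part. A toy illustration of that sum (values are made up):

    import torch

    prefix_lens_tensor = torch.tensor([48, 0], dtype=torch.int32)     # tokens already cached per request
    input_lengths_tensor = torch.tensor([16, 12], dtype=torch.int32)  # tokens sent in this call

    # The attention kernels must cover the whole sequence, so the lengths
    # passed to the decode/prefill state are the element-wise sums.
    total_lengths = input_lengths_tensor + prefix_lens_tensor
    print(total_lengths.tolist())  # [64, 12]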