Commit d57b7091aa ("Are we done yet ?"), parent e128bc540b
@@ -376,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 // Send generation responses back to the infer task
 // If the receive an error from the Flume channel, it means that the client dropped the
 // request and we need to stop generating hence why we unwrap_or(true)
-let stopped = send_responses(generation, entry).map_err(|err| {
+let stopped = send_responses(generation, entry).inspect_err(|_err| {
 tracing::error!("Entry response channel error.");
 metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
-err
 }).unwrap_or(true);
 if stopped {
 entries.remove(&id).expect("ID not found in entries. This is a bug.");
@@ -366,7 +366,7 @@ impl State {
 break;
 }
 Some(block_allocation) => {
-tracing::debug!("Allocation: {block_allocation:?}");
+// tracing::debug!("Allocation: {block_allocation:?}");
 max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
 Some(block_allocation)
 }
@@ -123,7 +123,7 @@ impl Allocator for RadixAllocator {
 prefill_tokens: prefill_tokens.clone(),
 };
 
-tracing::debug!("Blocks {blocks:?}");
+// tracing::debug!("Blocks {blocks:?}");
 
 self.allocation_id += 1;
 self.allocations.insert(self.allocation_id, allocation);
@@ -1,38 +1,38 @@
 {
 "choices": [
 {
-"finish_reason": "stop",
+"finish_reason": "length",
 "index": 1,
 "logprobs": null,
-"text": " PR for more information?"
+"text": " This is a question that has puzzled many people for"
 },
 {
 "finish_reason": "length",
 "index": 3,
 "logprobs": null,
-"text": "hd20220811-"
+"text": "usculas_minusculas(s):\n \"\"\"\n"
 },
 {
 "finish_reason": "length",
 "index": 0,
 "logprobs": null,
-"text": "le Business Incubator is providing a workspace"
+"text": " A Beginner’s Guide\nDeep learning is a subset"
 },
 {
 "finish_reason": "length",
 "index": 2,
 "logprobs": null,
-"text": " severely flawed and often has a substandard"
+"text": " Paris\nWhat is the capital of France?\nThe"
 }
 ],
-"created": 1722014725,
+"created": 1725877154,
 "id": "",
-"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
 "object": "text_completion",
 "system_fingerprint": "2.2.1-dev0-native",
 "usage": {
-"completion_tokens": 36,
-"prompt_tokens": 8,
-"total_tokens": 44
+"completion_tokens": 40,
+"prompt_tokens": 22,
+"total_tokens": 62
 }
 }
File diff suppressed because it is too large.
@@ -4,17 +4,17 @@
 "finish_reason": "length",
 "index": 0,
 "logprobs": null,
-"text": "\n2.2 How"
+"text": " A Beginner’s Guide\nDeep learning is a subset"
 }
 ],
-"created": 1725874238,
+"created": 1725876621,
 "id": "",
-"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
 "object": "text_completion",
 "system_fingerprint": "2.2.1-dev0-native",
 "usage": {
-"completion_tokens": 5,
+"completion_tokens": 10,
 "prompt_tokens": 6,
-"total_tokens": 11
+"total_tokens": 16
 }
 }
@@ -11,7 +11,7 @@ from text_generation.types import (
 @pytest.fixture(scope="module")
 def flash_llama_completion_handle(launcher):
 with launcher(
-"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+"meta-llama/Meta-Llama-3.1-8B-Instruct",
 ) as handle:
 yield handle
 
@@ -35,15 +35,18 @@ def test_flash_llama_completion_single_prompt(
 json={
 "model": "tgi",
 "prompt": "What is Deep Learning?",
-"max_tokens": 5,
-"seed": 0,
+"max_tokens": 10,
+"temperature": 0.0,
 },
 headers=flash_llama_completion.headers,
 stream=False,
 )
 response = response.json()
 assert len(response["choices"]) == 1
-assert response["choices"][0]["text"] == "\n2.2 How"
+assert (
+response["choices"][0]["text"]
+== " A Beginner’s Guide\nDeep learning is a subset"
+)
 assert response == response_snapshot
 
 
@@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
 f"{flash_llama_completion.base_url}/v1/completions",
 json={
 "model": "tgi",
-"prompt": ["Say", "this", "is", "a"],
+"prompt": [
+"What is Deep Learning?",
+"Is water wet?",
+"What is the capital of France?",
+"def mai",
+],
 "max_tokens": 10,
 "seed": 0,
+"temperature": 0.0,
 },
 headers=flash_llama_completion.headers,
 stream=False,
@@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
 response = response.json()
 assert len(response["choices"]) == 4
 
-all_indexes = [choice["index"] for choice in response["choices"]]
+all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
 all_indexes.sort()
-assert all_indexes == [0, 1, 2, 3]
+all_indices, all_strings = zip(*all_indexes)
+assert list(all_indices) == [0, 1, 2, 3]
+assert list(all_strings) == [
+" A Beginner’s Guide\nDeep learning is a subset",
+" This is a question that has puzzled many people for",
+" Paris\nWhat is the capital of France?\nThe",
+'usculas_minusculas(s):\n """\n',
+]
 
 assert response == response_snapshot
 
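Note: the rewritten assertion above pairs each choice's index with its text before sorting, so the per-prompt outputs can be checked independently of response order. A minimal sketch of that sort-and-unzip pattern, with synthetic choice data that is not taken from this commit:

    # Pair each index with its text, sort by index, then split back apart.
    choices = [
        {"index": 2, "text": "gamma"},
        {"index": 0, "text": "alpha"},
        {"index": 1, "text": "beta"},
    ]
    all_indexes = [(choice["index"], choice["text"]) for choice in choices]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)
    assert list(all_indices) == [0, 1, 2]
    assert list(all_strings) == ["alpha", "beta", "gamma"]
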
@@ -84,6 +100,7 @@ async def test_flash_llama_completion_many_prompts_stream(
 ],
 "max_tokens": 10,
 "seed": 0,
+"temperature": 0.0,
 "stream": True,
 }
 
@@ -114,5 +131,10 @@ async def test_flash_llama_completion_many_prompts_stream(
 strings[index] += c["choices"][0]["text"]
 
 assert response.status == 200
-# assert strings == ["What Business: And Stock Mohs`('\\", '\nrig Business Process And Stock ,s, And', '\n\n202 Stock Mohs a Service', 'hd\n20207\nR1']
+assert list(strings) == [
+" A Beginner’s Guide\nDeep learning is a subset",
+" This is a question that has puzzled many people for",
+" Paris\nWhat is the capital of France?\nThe",
+'usculas_minusculas(s):\n """\n',
+]
 assert chunks == response_snapshot
@@ -1843,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
 shutdown.clone(),
 &shutdown_receiver,
 )
-.map_err(|err| {
+.inspect_err(|_| {
 shutdown_shards(shutdown.clone(), &shutdown_receiver);
-err
 })?;
 
 // Default exit code
@@ -336,6 +336,8 @@ pub enum InferError {
 ValidationError(#[from] ValidationError),
 #[error("Incomplete generation")]
 IncompleteGeneration,
+#[error("Incomplete generation stream")]
+IncompleteGenerationStream,
 #[error("Template error: {0}")]
 TemplateError(#[from] minijinja::Error),
 #[error("Missing template vatiable: {0}")]
@@ -351,6 +353,7 @@ impl InferError {
 InferError::Overloaded(_) => "overloaded",
 InferError::ValidationError(_) => "validation",
 InferError::IncompleteGeneration => "incomplete_generation",
+InferError::IncompleteGenerationStream => "incomplete_generation_stream",
 InferError::TemplateError(_) => "template_error",
 InferError::MissingTemplateVariable(_) => "missing_template_variable",
 InferError::ToolError(_) => "tool_error",
@@ -540,6 +540,7 @@ async fn generate_stream_internal(
 // Inference
 let mut end_reached = false;
 let mut error = false;
+let mut index = 0;
 
 let mut add_prompt = None;
 if req.parameters.return_full_text.unwrap_or(false) {
@@ -562,7 +563,6 @@ async fn generate_stream_internal(
 match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
 // Keep permit as long as generate_stream lives
 Ok((_permit, input_length, response_stream)) => {
-let mut index = 0;
 let mut response_stream = Box::pin(response_stream);
 // Server-Sent Event stream
 while let Some(response) = response_stream.next().await {
@@ -677,8 +677,9 @@ async fn generate_stream_internal(
 // Check if generation reached the end
 // Skip if we already sent an error
 if !end_reached && !error {
-let err = InferError::IncompleteGeneration;
+let err = InferError::IncompleteGenerationStream;
 metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
+tracing::info!("n iterations {index}");
 tracing::error!("{err}");
 yield Ok(Event::from(err));
 }
@@ -2558,6 +2559,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
 InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
 InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
 InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
 InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
 InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
 InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
@@ -515,6 +515,7 @@ class FlashCausalLMBatch(Batch):
 dtype: torch.dtype,
 device: torch.device,
 ) -> "FlashCausalLMBatch":
+assert len(pb.requests) > 0
 batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
 return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
 
@@ -640,6 +641,7 @@ class FlashCausalLMBatch(Batch):
 adapter_segments = torch.tensor(
 adapter_segments, dtype=torch.int32, device=device
 )
+# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
 
 return type(self)(
 batch_id=self.batch_id,
@@ -834,6 +836,8 @@ class FlashCausalLMBatch(Batch):
 
 start_slots = torch.concat(start_slots)
 
+# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
+
 next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
 next_token_chooser_parameters,
 dtype=batches[0].next_token_chooser.dtype,
@@ -1083,12 +1087,12 @@ class FlashCausalLM(Model):
 if ATTENTION in {"flashdecoding", "flashinfer"}:
 self.kv_cache = [
 (
-torch.empty(
+torch.zeros(
 (num_blocks, BLOCK_SIZE, num_heads, head_size),
 dtype=dtype,
 device=device,
 ),
-torch.empty(
+torch.zeros(
 (num_blocks, BLOCK_SIZE, num_heads, head_size),
 dtype=dtype,
 device=device,
@@ -1099,12 +1103,12 @@ class FlashCausalLM(Model):
 elif SYSTEM == "ipex" and device == torch.device("cpu"):
 self.kv_cache = [
 (
-torch.empty(
+torch.zeros(
 (num_blocks, num_heads, BLOCK_SIZE, head_size),
 dtype=dtype,
 device=device,
 ),
-torch.empty(
+torch.zeros(
 (num_blocks, num_heads, BLOCK_SIZE, head_size),
 dtype=dtype,
 device=device,
@@ -1150,20 +1154,6 @@ class FlashCausalLM(Model):
 input_lengths=input_lengths,
 prefix_lens=prefix_lengths,
 )
-
-self.cuda_graphs[bs] = {
-"input_ids": input_ids,
-"position_ids": position_ids,
-"kv_cache": self.kv_cache,
-"block_tables": block_tables,
-"slots": slots,
-"input_lengths": input_lengths_tensor,
-"prefix_lengths": prefix_lengths_tensor,
-}
-graph = torch.cuda.CUDAGraph()
-self.cuda_graphs[bs]["graph"] = graph
-
-if ATTENTION == "flashinfer":
 from text_generation_server.layers.attention.flashinfer import (
 create_decode_state_cuda_graphs,
 )
@@ -1180,19 +1170,29 @@ class FlashCausalLM(Model):
 num_heads=self.num_heads,
 num_kv_heads=self.num_kv_heads,
 )
-self.cuda_graphs[bs]["state"] = state
 else:
 state = None
 
+graph = torch.cuda.CUDAGraph()
+self.cuda_graphs[bs] = {
+"input_ids": input_ids,
+"position_ids": position_ids,
+"kv_cache": self.kv_cache,
+"block_tables": block_tables,
+"slots": slots,
+"input_lengths": input_lengths_tensor,
+"prefix_lengths": prefix_lengths_tensor,
+"state": state,
+"graph": graph,
+}
+
 torch.cuda.synchronize()
 # Run once outside to warmup
 with self._forward_context(
 block_tables=block_tables,
 cu_seqlen_prefill=None,
-input_lengths=input_lengths,
 input_lengths_tensor=input_lengths_tensor,
 state=state,
-prefix_lens=prefix_lengths,
 prefix_lens_tensor=prefix_lengths_tensor,
 ):
 seqlen = Seqlen(
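Note: after this change the per-batch-size CUDA graph cache entry is built as a single dict, holding the graph object, the captured buffers, and the optional flashinfer decode state together. A rough sketch of that entry layout, using a hypothetical helper name (not in the codebase) and assuming a CUDA-enabled PyTorch build:

    import torch

    def build_cuda_graph_entry(input_ids, position_ids, kv_cache, block_tables,
                               slots, input_lengths_tensor, prefix_lengths_tensor,
                               state=None):
        # Hypothetical helper: mirrors the keys stored in self.cuda_graphs[bs].
        graph = torch.cuda.CUDAGraph()  # requires a CUDA device
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
            "prefix_lengths": prefix_lengths_tensor,
            "state": state,
            "graph": graph,
        }
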
@@ -1214,6 +1214,7 @@ class FlashCausalLM(Model):
 prefill_cache_indices=None,
 lm_head_indices=None,
 )
+del seqlen
 
 torch.cuda.synchronize()
 
@@ -1479,9 +1480,7 @@ class FlashCausalLM(Model):
 with self._forward_context(
 block_tables=block_tables,
 cu_seqlen_prefill=cu_seqlen_prefill,
-input_lengths=batch.input_lengths,
-input_lengths_tensor=input_lengths + prefix_lens_tensor,
-prefix_lens=batch.prefix_lens,
+input_lengths_tensor=input_lengths,
 prefix_lens_tensor=prefix_lens_tensor,
 ):
 max_k = (input_lengths + prefix_lens_tensor).max().item()
@@ -1519,12 +1518,27 @@ class FlashCausalLM(Model):
 input_lengths=batch.input_lengths,
 prefix_lens=batch.prefix_lens,
 )
+# assert block_tables.shape[0] >= slots.shape[0]
 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-else:
-cuda_graph["block_tables"][
-: block_tables.shape[0], : block_tables.shape[1]
-] = block_tables
-cuda_graph["slots"].fill_(-1)
+page_size = BLOCK_SIZE
+indptr = torch.zeros(
+input_lengths.shape[0] + 1,
+device=input_lengths.device,
+dtype=torch.int32,
+)
+# Round up to page size and then calculate the cumulative sum to get
+# the indices into the block table.
+torch.add(input_lengths, page_size - 1, out=indptr[1:])
+indptr[1:].div_(page_size, rounding_mode="floor")
+indptr[1:].cumsum_(-1)
+# Get the lengths of the last page in a block.
+last_page_len = torch.empty(
+input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+)
+torch.sub(input_lengths, 1, out=last_page_len)
+last_page_len.remainder_(page_size)
+last_page_len += 1
+cuda_graph["slots"].fill_(0)
 cuda_graph["slots"][: slots.shape[0]] = slots
 cuda_graph["input_lengths"].zero_()
 cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
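Note: the tensor ops added above derive flashinfer-style page bookkeeping from the sequence lengths: each length is rounded up to whole pages and cumulatively summed into indptr, and the occupancy of each sequence's last page is recovered with a remainder. A standalone sketch with synthetic lengths and an assumed page size of 16 (BLOCK_SIZE plays that role in the real code):

    import torch

    page_size = 16  # stand-in for BLOCK_SIZE
    input_lengths = torch.tensor([1, 16, 17, 40], dtype=torch.int32)

    indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
    # Round lengths up to whole pages, then cumulative-sum into indptr.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Number of tokens in each sequence's last page, in [1, page_size].
    last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32)
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1

    print(indptr.tolist())         # [0, 1, 2, 4, 7]
    print(last_page_len.tolist())  # [1, 16, 1, 8]
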
@@ -1534,11 +1548,9 @@ class FlashCausalLM(Model):
 with self._forward_context(
 block_tables=cuda_graph["block_tables"],
 cu_seqlen_prefill=None,
-input_lengths=batch.input_lengths,
 input_lengths_tensor=cuda_graph["input_lengths"],
-prefix_lens=batch.prefix_lens,
 prefix_lens_tensor=cuda_graph["prefix_lengths"],
-state=cuda_graph.get("state"),
+state=cuda_graph["state"],
 ):
 # Replay the graph
 cuda_graph["graph"].replay()
@@ -1767,7 +1779,7 @@ class FlashCausalLM(Model):
 left = 0
 
 if n_accepted_ids > 1:
-log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
 
 current_stopped = False
 for j in range(index, index + n_accepted_ids):
@@ -1886,6 +1898,8 @@ class FlashCausalLM(Model):
 top_tokens,
 )
 
+# assert all(n is not None for n in next_token_texts)
+
 generations.append(generation)
 
 # accept each new token for this specific request since we may
@@ -1922,9 +1936,7 @@ class FlashCausalLM(Model):
 *,
 block_tables: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
-input_lengths: List[int],
 input_lengths_tensor: torch.Tensor,
-prefix_lens: List[int],
 prefix_lens_tensor: torch.Tensor,
 state: Optional[Any] = None,
 ) -> ContextManager:
@@ -1950,7 +1962,7 @@ class FlashCausalLM(Model):
 # ),
 block_tables=block_tables,
 cu_seqlens=cu_seqlen_prefill,
-input_lengths=input_lengths_tensor,
+input_lengths=input_lengths_tensor + prefix_lens_tensor,
 num_heads=self.num_heads,
 num_kv_heads=self.num_kv_heads,
 head_size=self.head_size,
@@ -1960,7 +1972,7 @@ class FlashCausalLM(Model):
 assert input_lengths_tensor is not None
 return use_decode_state(
 state=state if state is not None else self.decode_state,
-input_lengths=input_lengths_tensor,
+input_lengths=input_lengths_tensor + prefix_lens_tensor,
 block_tables=block_tables,
 num_heads=self.num_heads,
 num_kv_heads=self.num_kv_heads,