Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
Minor fixup
This commit is contained in:
parent
d57b7091aa
commit
29d3601457
@@ -540,7 +540,6 @@ async fn generate_stream_internal(
         // Inference
         let mut end_reached = false;
         let mut error = false;
-        let mut index = 0;

         let mut add_prompt = None;
         if req.parameters.return_full_text.unwrap_or(false) {
@@ -563,6 +562,7 @@ async fn generate_stream_internal(
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
             Ok((_permit, input_length, response_stream)) => {
+                let mut index = 0;
                 let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
@@ -679,7 +679,6 @@ async fn generate_stream_internal(
         if !end_reached && !error {
             let err = InferError::IncompleteGenerationStream;
             metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
-            tracing::info!("n iterations {index}");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         }
@@ -1087,12 +1087,12 @@ class FlashCausalLM(Model):
         if ATTENTION in {"flashdecoding", "flashinfer"}:
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, BLOCK_SIZE, num_heads, head_size),
                         dtype=dtype,
                         device=device,
@@ -1103,12 +1103,12 @@ class FlashCausalLM(Model):
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.zeros(
+                    torch.empty(
                         (num_blocks, num_heads, BLOCK_SIZE, head_size),
                         dtype=dtype,
                         device=device,
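
Both kv_cache allocations above swap torch.zeros for torch.empty. Below is a minimal sketch of the difference, with made-up shapes; the rationale (skipping an unneeded zero-fill for buffers that are fully written by the attention kernels before they are read) is an inference, not stated in the commit.

import torch

# Illustrative values only; the real ones come from the model and cache config.
num_blocks, BLOCK_SIZE, num_heads, head_size = 8, 16, 4, 64

# torch.zeros allocates the buffer and then fills it with zeros;
# torch.empty only allocates, leaving the contents uninitialized.
key_cache = torch.empty((num_blocks, BLOCK_SIZE, num_heads, head_size), dtype=torch.float16)
value_cache = torch.empty((num_blocks, BLOCK_SIZE, num_heads, head_size), dtype=torch.float16)

# Reading an uninitialized slot returns arbitrary data, so this is only safe
# if every cache slot is written before it is ever read.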
@@ -1520,24 +1520,6 @@ class FlashCausalLM(Model):
             )
             # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-            page_size = BLOCK_SIZE
-            indptr = torch.zeros(
-                input_lengths.shape[0] + 1,
-                device=input_lengths.device,
-                dtype=torch.int32,
-            )
-            # Round up to page size and then calculate the cumulative sum to get
-            # the indices into the block table.
-            torch.add(input_lengths, page_size - 1, out=indptr[1:])
-            indptr[1:].div_(page_size, rounding_mode="floor")
-            indptr[1:].cumsum_(-1)
-            # Get the lengths of the last page in a block.
-            last_page_len = torch.empty(
-                input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
-            )
-            torch.sub(input_lengths, 1, out=last_page_len)
-            last_page_len.remainder_(page_size)
-            last_page_len += 1
             cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
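
For reference, here is a worked example of the block removed above, using the removed lines with illustrative inputs (page_size stands in for BLOCK_SIZE, here 16): each sequence length is rounded up to whole pages, the cumulative sum turns the page counts into offsets into the flattened block table, and last_page_len is the number of used slots in each sequence's final page.

import torch

page_size = 16  # BLOCK_SIZE in the original code
input_lengths = torch.tensor([5, 16, 33], dtype=torch.int32)

indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
# Round up to page size and then take the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
print(indptr)         # -> [0, 1, 2, 5]: 1, 1 and 3 pages per sequence

# Length of the last (possibly partial) page of each sequence.
last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
print(last_page_len)  # -> [5, 16, 1]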
@@ -1898,8 +1880,6 @@ class FlashCausalLM(Model):
                 top_tokens,
             )

-            # assert all(n is not None for n in next_token_texts)
-
             generations.append(generation)

             # accept each new token for this specific request since we may