From e6ee67f301a25114d771f2ab9095c0470030aa9f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 28 Aug 2024 10:53:22 +0200 Subject: [PATCH] Truncating left for radix purposes. --- router/src/validation.rs | 3 ++- server/text_generation_server/models/flash_causal_lm.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/router/src/validation.rs b/router/src/validation.rs index 054276c8..92491d88 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -167,7 +167,8 @@ impl Validation { )); } - let input_ids = encoding.get_ids()[..input_length].to_owned(); + let ids = encoding.get_ids(); + let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned(); metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, Some(input_ids), input_length, max_new_tokens)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 409fe2e3..4ed3f56d 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -272,6 +272,9 @@ class FlashCausalLMBatch(Batch): prefix_len = r.prefix_len assert prefix_len <= orig_input_length + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 prefix_ids.append(tokenized_input[:prefix_len]) tokenized_input = tokenized_input[prefix_len:]