diff --git a/router/src/validation.rs b/router/src/validation.rs
index 054276c8..92491d88 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -167,7 +167,8 @@ impl Validation {
                 ));
             }
 
-            let input_ids = encoding.get_ids()[..input_length].to_owned();
+            let ids = encoding.get_ids();
+            let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
 
             metrics::histogram!("tgi_request_input_length").record(input_length as f64);
             Ok((inputs, Some(input_ids), input_length, max_new_tokens))
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 409fe2e3..4ed3f56d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -272,6 +272,9 @@ class FlashCausalLMBatch(Batch):
 
             prefix_len = r.prefix_len
             assert prefix_len <= orig_input_length
+            if prefix_len == orig_input_length:
+                assert prefix_len > 0
+                prefix_len -= 1
 
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
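
A note on what the router hunk computes. The old slice `encoding.get_ids()[..input_length]` kept the first `input_length` ids; the new one keeps the last `input_length` ids, and `saturating_sub` clamps the start index at zero, so an encoding shorter than `input_length` is returned whole instead of panicking on an out-of-range slice. A minimal, self-contained Rust sketch of that semantics (`truncate_to_suffix` is a hypothetical helper name, not TGI code):

    /// Stand-in for the new slice: keep the *last* `input_length` ids.
    fn truncate_to_suffix(ids: &[u32], input_length: usize) -> Vec<u32> {
        // `saturating_sub` clamps the start index at 0, so an encoding
        // shorter than `input_length` comes back whole rather than panicking.
        ids[ids.len().saturating_sub(input_length)..].to_owned()
    }

    fn main() {
        let ids = [101, 102, 103, 104, 105];
        // Truncation now keeps the end of the prompt, not the beginning.
        assert_eq!(truncate_to_suffix(&ids, 3), vec![103, 104, 105]);
        // `&ids[..8]` would panic here; the suffix slice does not.
        assert_eq!(truncate_to_suffix(&ids, 8), vec![101, 102, 103, 104, 105]);
    }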
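
The server hunk guards the full-cache-hit case: if the cached prefix covers the entire tokenized input, no ids would remain for the model to process, so the prefix length is backed off by one token. A sketch of that guard, transliterated from the Python into Rust purely for illustration (`clamp_prefix_len` is a hypothetical name):

    /// Illustration of the `FlashCausalLMBatch` guard: on a prefix-cache hit
    /// over the whole input, reduce the prefix by one so at least one id
    /// remains live in the batch.
    fn clamp_prefix_len(prefix_len: usize, orig_input_length: usize) -> usize {
        assert!(prefix_len <= orig_input_length);
        if prefix_len == orig_input_length {
            // Mirrors the hunk's `assert prefix_len > 0`.
            assert!(prefix_len > 0);
            prefix_len - 1
        } else {
            prefix_len
        }
    }

    fn main() {
        assert_eq!(clamp_prefix_len(3, 5), 3); // partial hit: unchanged
        assert_eq!(clamp_prefix_len(5, 5), 4); // full hit: leave one live token
    }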