diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 7727f54e..898e2b11 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -8,8 +8,6 @@
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
-
-
 /// Text Generation Inference gRPC client
 #[derive(Debug, Clone)]
 pub struct Client {
@@ -163,7 +161,11 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch, PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }
 
     /// Generate one token for each request in the given cached batches
@@ -177,7 +179,16 @@
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch, DecodeTimings::new(response.concat_ns, response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
     }
 }
 
@@ -213,4 +224,4 @@ impl DecodeTimings {
             total: Duration::from_nanos(total_ns),
         }
     }
-}
\ No newline at end of file
+}
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 827696ad..6c5da3c7 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,10 +1,10 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
-use crate::client::{DecodeTimings, PrefillTimings};
 
 #[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
@@ -131,7 +131,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
 
-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
 
         // Merge generations from different model shards
         for (mut shard_generations, _, shard_timings) in results.into_iter() {
@@ -162,7 +163,8 @@
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
 
-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
 
         // Merge generations from different model shards
         for (mut shard_generations, _, shard_timings) in results.into_iter() {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 1b47fc97..8a732bf8 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -229,6 +229,15 @@
                 stop_sequences.len(),
             ));
         }
+        // Reject individual stop sequences longer than 50 characters
+        for s in &stop_sequences {
+            if s.chars().count() > 50 {
+                return Err(ValidationError::StopSequence(
+                    self.max_stop_sequences,
+                    stop_sequences.len(),
+                ));
+            }
+        }
 
         // If seed is None, assign a random one
         let seed = match seed {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cc0c8a32..930082cd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -943,6 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
+        accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
 
         # Zipped iterator
@@ -980,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
            next_token_texts = []
            left = 0
-            before = stopping_criteria.current_tokens
 
            current_stopped = False
            for j in range(index, index + n_accepted_ids):
@@ -1095,7 +1095,7 @@ class FlashCausalLM(Model):
            generations.append(generation)
 
            # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0d208104..ff0556df 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")
 
     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):