diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index bbccbf1d..d83e3fe3 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -118,6 +118,9 @@ class Client:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                The grammar to use for generation. Grammars constrain the generated text to match a regular
+                expression or a JSON schema.
 
         Returns:
             Response: generated response
@@ -209,6 +212,9 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                The grammar to use for generation. Grammars constrain the generated text to match a regular
+                expression or a JSON schema.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -372,6 +378,9 @@ class AsyncClient:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                The grammar to use for generation. Grammars constrain the generated text to match a regular
+                expression or a JSON schema.
 
         Returns:
             Response: generated response
@@ -462,6 +471,9 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                The grammar to use for generation. Grammars constrain the generated text to match a regular
+                expression or a JSON schema.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
diff --git a/router/src/infer.rs b/router/src/infer.rs
index a208cb14..eaa72a75 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -198,13 +198,12 @@ impl Infer {
     pub(crate) fn apply_completion_template(
         &self,
         prompt: String,
-        prefix: Option<String>,
         suffix: Option<String>,
     ) -> Result<String, InferError> {
         self.completion_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, prefix, suffix)
+            .apply(prompt, suffix)
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");
@@ -386,15 +385,9 @@ impl CompletionTemplate {
         Self { template }
     }
 
-    fn apply(
-        &self,
-        prompt: String,
-        prefix: Option<String>,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
+    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
         self.template
             .render(CompletionTemplateInputs {
-                prefix: prefix.as_deref(),
                 prompt: prompt.as_str(),
                 suffix: suffix.as_deref(),
             })
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 24168e8d..f03cb53e 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -595,7 +595,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
 
 #[derive(Clone, Serialize, Deserialize)]
 pub(crate) struct CompletionTemplateInputs<'a> {
-    prefix: Option<&'a str>,
     prompt: &'a str,
     suffix: Option<&'a str>,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 23190c7b..4dd00745 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -578,7 +578,7 @@ async fn completions(
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
 
-    let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
+    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1263,8 +1263,8 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
-        .route("/vertex", post(vertex_compatibility))
         .route("/v1/completions", post(completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index abab3486..9deef2e1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -34,7 +33,6 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
-    "FlashCausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -45,7 +43,20 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = True
+# FlashCausalLM requires CUDA Graphs to be enabled on the system. Importing it will throw a RuntimeError
+# if CUDA Graphs are not available when `torch.cuda.graph_pool_handle()` is called in FlashCausalLM.
+HAS_CUDA_GRAPH = False
+try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
+
+    HAS_CUDA_GRAPH = True
+except RuntimeError as e:
+    logger.warning(f"Could not import FlashCausalLM: {e}")
+
+if HAS_CUDA_GRAPH:
+    __all__.append("FlashCausalLM")
+
+
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
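
The client-side changes above only document the new `grammar` argument; below is a minimal sketch of how a caller might pass it to `Client.generate`. The `Grammar`/`GrammarType` import path and their fields are assumptions for illustration and are not taken from this diff.

```python
# Sketch only: the Grammar/GrammarType import path and field names are assumptions.
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://127.0.0.1:8080")

# Constrain the output to an IPv4-shaped string with a regular-expression grammar.
response = client.generate(
    "What is the IP address of the Google DNS servers?",
    max_new_tokens=20,
    grammar=Grammar(
        type=GrammarType.Regex,
        value=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    ),
)
print(response.generated_text)
```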
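With `prefix` dropped from `CompletionTemplateInputs`, the `/v1/completions` route now renders the completion template from `prompt` and `suffix` alone. A rough sketch of a request against that route follows; only `prompt`, `suffix`, `stream`, and `seed` are visible in this diff, so the `max_tokens` field and the response shape are assumptions.

```python
# Sketch only: request fields other than prompt/suffix/stream/seed are assumptions.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/v1/completions",
    json={
        "prompt": "def fibonacci(n: int) -> int:",
        "suffix": "    return fib(n - 1) + fib(n - 2)",
        "max_tokens": 64,
        "stream": False,
    },
)
print(resp.json())
```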