feat: update docs refactor and avoid cuda graphs when unavailable

drbh 2024-02-21 11:17:41 -05:00
parent 90d6330819
commit 19c0248985
5 changed files with 30 additions and 15 deletions


@@ -118,6 +118,9 @@ class Client:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.

         Returns:
             Response: generated response
@@ -209,6 +212,9 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.

         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -372,6 +378,9 @@ class AsyncClient:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.

         Returns:
             Response: generated response
@@ -462,6 +471,9 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.

         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
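
The four docstring hunks above all document the same new `grammar` argument on the client's generate and generate_stream methods, for both Client and AsyncClient. Below is a minimal sketch of how a caller might pass it; the `Grammar`/`GrammarType` import path, the constructor shape, and the local server URL are assumptions inferred from the docstring, not something this diff confirms.

# Hedged sketch only: Grammar/GrammarType location and constructor are assumed.
from text_generation import Client
from text_generation.types import Grammar, GrammarType  # assumed module path

client = Client("http://127.0.0.1:8080")  # assumed local TGI endpoint
response = client.generate(
    "What is the IP address of the Google DNS servers?",
    max_new_tokens=30,
    # Constrain the output to an IPv4-shaped string, matching the docstring's
    # "regular expression or JSON schema" description of grammars.
    grammar=Grammar(
        type=GrammarType.Regex,
        value=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    ),
)
print(response.generated_text)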


@@ -198,13 +198,12 @@ impl Infer {
     pub(crate) fn apply_completion_template(
         &self,
         prompt: String,
-        prefix: Option<String>,
         suffix: Option<String>,
     ) -> Result<String, InferError> {
         self.completion_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, prefix, suffix)
+            .apply(prompt, suffix)
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");
@@ -386,15 +385,9 @@ impl CompletionTemplate {
         Self { template }
     }

-    fn apply(
-        &self,
-        prompt: String,
-        prefix: Option<String>,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
+    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
         self.template
             .render(CompletionTemplateInputs {
-                prefix: prefix.as_deref(),
                 prompt: prompt.as_str(),
                 suffix: suffix.as_deref(),
             })


@@ -595,7 +595,6 @@ pub(crate) struct ChatTemplateInputs<'a> {

 #[derive(Clone, Serialize, Deserialize)]
 pub(crate) struct CompletionTemplateInputs<'a> {
-    prefix: Option<&'a str>,
     prompt: &'a str,
     suffix: Option<&'a str>,
 }


@@ -578,7 +578,7 @@ async fn completions(
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;

-    let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
+    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1263,8 +1263,8 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
-        .route("/vertex", post(vertex_compatibility))
         .route("/v1/completions", post(completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
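
For context, the completions handler above reads `req.prompt`, `req.suffix`, `req.stream`, and `req.seed`, and the route table now registers `/v1/completions` right after the chat route. A rough sketch of exercising that route follows; only those four fields are confirmed by this diff, so a real request likely needs more (for example a token limit), and the exact response shape is not pinned down here.

# Hedged sketch: only prompt/suffix/stream/seed are confirmed by the handler above.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/v1/completions",  # route registered in the diff above
    json={
        "prompt": "def fibonacci(n):",
        "suffix": None,    # optional; passed through apply_completion_template(prompt, suffix)
        "stream": False,
        "seed": 42,
    },
    timeout=60,
)
print(resp.status_code, resp.json())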


@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -34,7 +33,6 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
-    "FlashCausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -45,7 +43,20 @@ __all__ = [

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

-FLASH_ATTENTION = True
+# FlashCausalLM reqiures CUDA Graphs to be enabled on the system. This will throw a RuntimeError
+# if CUDA Graphs are not available when calling `torch.cuda.graph_pool_handle()` in the FlashCausalLM
+HAS_CUDA_GRAPH = False
+try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
+
+    HAS_CUDA_GRAPH = True
+except RuntimeError as e:
+    logger.warning(f"Could not import FlashCausalLM: {e}")
+
+if HAS_CUDA_GRAPH:
+    __all__.append(FlashCausalLM)
+
+
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
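
The new comment explains the guard: importing FlashCausalLM ends up calling `torch.cuda.graph_pool_handle()`, which raises a RuntimeError when CUDA Graphs are unavailable, so the import is wrapped in a try/except and `HAS_CUDA_GRAPH` records the outcome. Below is a standalone sketch of the same probe written against the public PyTorch API, independent of text-generation-server; it is illustrative only, not the repository's code.

# Hedged sketch of the availability check the guarded import above relies on.
import torch

def cuda_graphs_available() -> bool:
    """Return True if this build/device can allocate a CUDA Graph memory pool."""
    if not torch.cuda.is_available():
        return False
    try:
        # Same call the new comment cites; raises RuntimeError when CUDA Graphs
        # are not supported on this system.
        torch.cuda.graph_pool_handle()
        return True
    except RuntimeError as err:
        print(f"CUDA Graphs unavailable: {err}")
        return False

if __name__ == "__main__":
    print("CUDA Graphs available:", cuda_graphs_available())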