feat: update docs refactor and avoid cuda graphs when unavailable

drbh 2024-02-21 11:17:41 -05:00
parent 90d6330819
commit 19c0248985
5 changed files with 30 additions and 15 deletions

View File

@@ -118,6 +118,9 @@ class Client:
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
The grammar to use for generation, if any. A grammar constrains the generated text to match
a regular expression or JSON schema.
Returns:
Response: generated response
@@ -209,6 +212,9 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
The grammar to use for generation, if any. A grammar constrains the generated text to match
a regular expression or JSON schema.
Returns:
Iterator[StreamResponse]: stream of generated tokens
@@ -372,6 +378,9 @@ class AsyncClient:
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
The grammar to use for generation, if any. A grammar constrains the generated text to match
a regular expression or JSON schema.
Returns:
Response: generated response
@@ -462,6 +471,9 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
The grammar to use for generation, if any. A grammar constrains the generated text to match
a regular expression or JSON schema.
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
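
For readers of the `grammar` parameter documented above, here is a minimal sketch of how it might be passed from the Python client. The `text_generation.types.Grammar` import path and its `type`/`value` fields are assumptions for illustration, not taken from this diff.

from text_generation import Client
from text_generation.types import Grammar  # assumed import path and constructor shape

client = Client("http://127.0.0.1:8080")
# Constrain the generated text to an ISO date via a regex grammar (field names assumed).
response = client.generate(
    "When was the Eiffel Tower opened? Reply with a date:",
    max_new_tokens=10,
    grammar=Grammar(type="regex", value=r"\d{4}-\d{2}-\d{2}"),
)
print(response.generated_text)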

View File

@@ -198,13 +198,12 @@ impl Infer {
pub(crate) fn apply_completion_template(
&self,
prompt: String,
prefix: Option<String>,
suffix: Option<String>,
) -> Result<String, InferError> {
self.completion_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(prompt, prefix, suffix)
.apply(prompt, suffix)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
@@ -386,15 +385,9 @@ impl CompletionTemplate {
Self { template }
}
fn apply(
&self,
prompt: String,
prefix: Option<String>,
suffix: Option<String>,
) -> Result<String, InferError> {
fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
self.template
.render(CompletionTemplateInputs {
prefix: prefix.as_deref(),
prompt: prompt.as_str(),
suffix: suffix.as_deref(),
})
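
To make the refactor above concrete: the completion template now receives only `prompt` and an optional `suffix`; the `prefix` input is gone. Below is a rough Python/Jinja2 sketch of such a fill-in-the-middle template, purely for illustration; the router renders its own template in Rust, and the template string here is invented.

from jinja2 import Template  # illustrative stand-in for the router's Rust-side template

# Toy fill-in-the-middle completion template: no prefix input, only prompt and suffix.
completion_template = Template(
    "<fim_prefix>{{ prompt }}<fim_suffix>{{ suffix if suffix else '' }}<fim_middle>"
)

print(completion_template.render(prompt="def add(a, b):", suffix="    return total"))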

View File

@@ -595,7 +595,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct CompletionTemplateInputs<'a> {
prefix: Option<&'a str>,
prompt: &'a str,
suffix: Option<&'a str>,
}

View File

@@ -578,7 +578,7 @@ async fn completions(
let stream = req.stream.unwrap_or_default();
let seed = req.seed;
let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
Ok(inputs) => inputs,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1263,8 +1263,8 @@ pub async fn run(
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions))
.route("/vertex", post(vertex_compatibility))
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/tokenize", post(tokenize))
.route("/health", get(health))
.route("/ping", get(health))

View File

@@ -8,7 +8,6 @@ from typing import Optional
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -34,7 +33,6 @@ __all__ = [
"Model",
"BLOOMSharded",
"CausalLM",
"FlashCausalLM",
"GalacticaSharded",
"Seq2SeqLM",
"SantaCoder",
@@ -45,7 +43,20 @@ __all__ = [
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = True
# FlashCausalLM requires CUDA Graphs to be enabled on the system. Importing it calls
# `torch.cuda.graph_pool_handle()`, which raises a RuntimeError if CUDA Graphs are not available.
HAS_CUDA_GRAPH = False
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    HAS_CUDA_GRAPH = True
except RuntimeError as e:
    logger.warning(f"Could not import FlashCausalLM: {e}")
if HAS_CUDA_GRAPH:
    __all__.append("FlashCausalLM")
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
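
Finally, a hypothetical sketch (not code from this commit) of how a model factory could consult the `HAS_CUDA_GRAPH` flag set above and fall back when FlashCausalLM could not be imported:

def load_model(model_id: str):
    # Hypothetical factory; real constructor arguments are elided/assumed for brevity.
    if HAS_CUDA_GRAPH:
        return FlashCausalLM(model_id)  # assumed to accept a model id directly
    logger.warning(f"CUDA Graphs unavailable; falling back to CausalLM for {model_id}")
    return CausalLM(model_id)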