Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
feat: update docs refactor and avoid cuda graphs when unavailable
parent 90d6330819
commit 19c0248985
@@ -118,6 +118,9 @@ class Client:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Response: generated response

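For readers of the docstring above, a minimal sketch of passing the newly documented `grammar` and `top_n_tokens` arguments to `Client.generate`; the `Grammar`/`GrammarType` import path, their field names, and the server URL are assumptions rather than part of this diff.

from text_generation import Client
from text_generation.types import Grammar, GrammarType  # assumed import path

client = Client("http://127.0.0.1:8080")
response = client.generate(
    "What is Google's public DNS IP?",
    max_new_tokens=12,
    # Constrain generation to an IPv4-shaped string via a regex grammar.
    grammar=Grammar(
        type=GrammarType.Regex,
        value=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    ),
    # Also return the 3 most likely tokens at each generation step.
    top_n_tokens=3,
)
print(response.generated_text)
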
@@ -209,6 +212,9 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens

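A companion sketch for the streaming variant, which yields the Iterator[StreamResponse] described above; the StreamResponse/Token attribute names follow the client's documented response types, and the server URL is assumed.

from text_generation import Client

client = Client("http://127.0.0.1:8080")
text = ""
for stream_response in client.generate_stream(
    "What is Deep Learning?", max_new_tokens=20, top_n_tokens=3
):
    # Skip special tokens and accumulate the generated text as it streams in.
    if not stream_response.token.special:
        text += stream_response.token.text
print(text)
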
@@ -372,6 +378,9 @@ class AsyncClient:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Response: generated response

@@ -462,6 +471,9 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens

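The async client differs only in that the stream is an AsyncIterator[StreamResponse]; a rough sketch, again assuming a local server on port 8080.

import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    async for stream_response in client.generate_stream(
        "What is Deep Learning?", max_new_tokens=20, top_n_tokens=3
    ):
        if not stream_response.token.special:
            print(stream_response.token.text, end="", flush=True)


asyncio.run(main())
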
@@ -198,13 +198,12 @@ impl Infer {
     pub(crate) fn apply_completion_template(
         &self,
         prompt: String,
-        prefix: Option<String>,
         suffix: Option<String>,
     ) -> Result<String, InferError> {
         self.completion_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, prefix, suffix)
+            .apply(prompt, suffix)
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");

@@ -386,15 +385,9 @@ impl CompletionTemplate {
         Self { template }
     }
 
-    fn apply(
-        &self,
-        prompt: String,
-        prefix: Option<String>,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
+    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
         self.template
             .render(CompletionTemplateInputs {
-                prefix: prefix.as_deref(),
                 prompt: prompt.as_str(),
                 suffix: suffix.as_deref(),
             })

@@ -595,7 +595,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
 
 #[derive(Clone, Serialize, Deserialize)]
 pub(crate) struct CompletionTemplateInputs<'a> {
-    prefix: Option<&'a str>,
     prompt: &'a str,
     suffix: Option<&'a str>,
 }

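With `prefix` removed, the completion template is rendered from `prompt` and an optional `suffix` only. A rough Python/Jinja2 analogue of what the Rust-side CompletionTemplate::apply above now does; the template string is illustrative, not taken from the repository.

from jinja2 import Template

# Hypothetical fill-in style completion template rendered from prompt + suffix.
completion_template = Template("{{ prompt }}{% if suffix %}{{ suffix }}{% endif %}")

print(completion_template.render(prompt="def add(a, b):", suffix="\n    return a + b"))
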
@@ -578,7 +578,7 @@ async fn completions(
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
 
-    let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
+    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");

@@ -1263,8 +1263,8 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
-        .route("/vertex", post(vertex_compatibility))
         .route("/v1/completions", post(completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))

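The `/v1/completions` route registered above can be exercised with a plain HTTP request. A hedged sketch: the JSON fields mirror the OpenAI-style completions schema, and the exact set of accepted fields (and whether `model` is ignored) is an assumption.

import requests

resp = requests.post(
    "http://127.0.0.1:8080/v1/completions",
    json={
        "model": "tgi",  # assumed to be accepted (and likely ignored) by the router
        "prompt": "def fibonacci(n):",
        "max_tokens": 32,
        "stream": False,
    },
    timeout=60,
)
print(resp.json())
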
@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM

@@ -34,7 +33,6 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
-    "FlashCausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
     "SantaCoder",

@@ -45,7 +43,20 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = True
+# FlashCausalLM requires CUDA Graphs to be enabled on the system. This will throw a RuntimeError
+# if CUDA Graphs are not available when calling `torch.cuda.graph_pool_handle()` in the FlashCausalLM
+HAS_CUDA_GRAPH = False
+try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
+
+    HAS_CUDA_GRAPH = True
+except RuntimeError as e:
+    logger.warning(f"Could not import FlashCausalLM: {e}")
+
+if HAS_CUDA_GRAPH:
+    __all__.append(FlashCausalLM)
+
 
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded

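Downstream code can branch on the new HAS_CUDA_GRAPH flag rather than importing FlashCausalLM unconditionally. A simplified sketch; pick_model_class is a hypothetical helper, not the real dispatch logic in get_model.

from text_generation_server import models


def pick_model_class(flash_requested: bool):
    # Fall back to the eager CausalLM path when CUDA Graphs (and therefore
    # FlashCausalLM) could not be imported on this machine.
    if flash_requested and models.HAS_CUDA_GRAPH:
        return models.FlashCausalLM
    return models.CausalLM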