Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: update docs refactor and avoid cuda graphs when unavailable
Commit 19c0248985 (parent 90d6330819)
@@ -118,6 +118,9 @@ class Client:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Response: generated response
@@ -209,6 +212,9 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -372,6 +378,9 @@ class AsyncClient:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Response: generated response
@@ -462,6 +471,9 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
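The four docstring hunks above document the same new `grammar` argument on the client's generate methods. As a rough illustration of how a caller might pass it, a constrained-generation call could look like the sketch below; the `Grammar`/`GrammarType` import path and field names shown here are assumptions, not something this diff defines.

# Hypothetical sketch of using the documented `grammar` argument; the Grammar
# type and its fields are assumed, not taken from this diff.
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://127.0.0.1:8080")

# Constrain generation to a JSON schema (a regex grammar would use GrammarType.Regex).
response = client.generate(
    "Extract the name and age from: 'Ada is 36 years old.'",
    max_new_tokens=64,
    grammar=Grammar(
        type=GrammarType.Json,
        value={
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        },
    ),
)
print(response.generated_text)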
@@ -198,13 +198,12 @@ impl Infer {
     pub(crate) fn apply_completion_template(
         &self,
         prompt: String,
-        prefix: Option<String>,
         suffix: Option<String>,
     ) -> Result<String, InferError> {
         self.completion_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, prefix, suffix)
+            .apply(prompt, suffix)
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");
@@ -386,15 +385,9 @@ impl CompletionTemplate {
         Self { template }
     }
 
-    fn apply(
-        &self,
-        prompt: String,
-        prefix: Option<String>,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
+    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
         self.template
             .render(CompletionTemplateInputs {
-                prefix: prefix.as_deref(),
                 prompt: prompt.as_str(),
                 suffix: suffix.as_deref(),
             })
@@ -595,7 +595,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
 
 #[derive(Clone, Serialize, Deserialize)]
 pub(crate) struct CompletionTemplateInputs<'a> {
-    prefix: Option<&'a str>,
     prompt: &'a str,
     suffix: Option<&'a str>,
 }
@@ -578,7 +578,7 @@ async fn completions(
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
 
-    let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
+    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
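The handler change above only drops the unused `prefix` argument; the `/v1/completions` route itself still takes a prompt and an optional suffix. A minimal sketch of exercising that route from Python follows; the payload field names, the `model` placeholder, the server address, and the `choices[0]["text"]` response shape are assumptions based on the OpenAI-style handler, not a documented contract.

# Rough sketch of a fill-in-the-middle style request against /v1/completions.
# The prompt/suffix fields mirror the req.prompt / req.suffix used by the handler above.
import requests

payload = {
    "model": "tgi",
    "prompt": "def fibonacci(n):\n    ",
    "suffix": "\n    return result",
    "max_tokens": 64,
    "stream": False,
}
resp = requests.post("http://127.0.0.1:8080/v1/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])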
@@ -1263,8 +1263,8 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
-        .route("/vertex", post(vertex_compatibility))
         .route("/v1/completions", post(completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -34,7 +33,6 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
-    "FlashCausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -45,7 +43,20 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = True
+# FlashCausalLM requires CUDA Graphs to be enabled on the system. This will throw a RuntimeError
+# if CUDA Graphs are not available when calling `torch.cuda.graph_pool_handle()` in the FlashCausalLM
+HAS_CUDA_GRAPH = False
+try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
+
+    HAS_CUDA_GRAPH = True
+except RuntimeError as e:
+    logger.warning(f"Could not import FlashCausalLM: {e}")
+
+if HAS_CUDA_GRAPH:
+    __all__.append(FlashCausalLM)
+
+
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
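The guarded import above relies on FlashCausalLM raising a RuntimeError at import time when CUDA Graphs are unavailable. A standalone way to probe the same capability directly would be something like the sketch below, which assumes only the `torch.cuda.graph_pool_handle()` call named in the comment; the helper itself is illustrative, not part of this commit.

# Minimal probe for CUDA Graph support, mirroring the check described in the
# comment above; torch.cuda.graph_pool_handle() raises when graphs are unusable.
import torch

def cuda_graphs_available() -> bool:
    if not torch.cuda.is_available():
        return False
    try:
        torch.cuda.graph_pool_handle()
        return True
    except RuntimeError:
        return False

if __name__ == "__main__":
    print(f"CUDA Graphs available: {cuda_graphs_available()}")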