Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 16:02:10 +00:00)

Commit 3d4c50f028: Merge branch 'main' into moe
@@ -101,6 +101,47 @@
         }
       }
     },
+    "/chat_tokenize": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Template and tokenize ChatRequest",
+        "operationId": "get_chat_tokenize",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/ChatRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Templated and tokenized ChatRequest",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ChatTokenizeResponse"
+                }
+              }
+            }
+          },
+          "404": {
+            "description": "Failed to tokenize ChatRequest",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/generate": {
       "post": {
         "tags": [

@@ -1092,6 +1133,21 @@
           }
         }
       },
+      "ChatTokenizeResponse": {
+        "type": "object",
+        "required": [
+          "tokenize_response",
+          "templated_text"
+        ],
+        "properties": {
+          "templated_text": {
+            "type": "string"
+          },
+          "tokenize_response": {
+            "$ref": "#/components/schemas/TokenizeResponse"
+          }
+        }
+      },
       "Chunk": {
         "type": "object",
         "required": [

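The two hunks above add a /chat_tokenize route and its ChatTokenizeResponse schema to the OpenAPI spec. Below is a minimal client sketch; the host, port, and the exact ChatRequest payload fields are assumptions for illustration, while the route, the 200/404 status codes, and the response field names come from the spec above.

# Hypothetical client for the new /chat_tokenize route. The base URL and the
# ChatRequest fields used in the payload are assumed, not taken from this diff.
import requests

payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is deep learning?"}],
}

resp = requests.post("http://localhost:3000/chat_tokenize", json=payload, timeout=30)
if resp.status_code == 200:
    body = resp.json()
    print(body["templated_text"])     # prompt text after chat templating
    print(body["tokenize_response"])  # TokenizeResponse for that prompt
else:
    # Per the spec, a failed tokenization returns 404 with an ErrorResponse body.
    print(resp.status_code, resp.text)
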
@@ -978,16 +978,15 @@
        "nixpkgs": "nixpkgs_6"
      },
      "locked": {
-        "lastModified": 1729761651,
-        "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
+        "lastModified": 1730724647,
+        "narHash": "sha256-SVv+50CGaCoU4zZwsg6ZAaOi/D5QJBL1P2SIB+3CEf4=",
        "owner": "huggingface",
        "repo": "text-generation-inference-nix",
-        "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
+        "rev": "1512898a1e5ad9eff025205fa9c4d33a44506cf3",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
-        "ref": "marlin-kernels-0.3.1",
        "repo": "text-generation-inference-nix",
        "type": "github"
      }

@@ -5,7 +5,7 @@
      inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
    };
    nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
    rust-overlay = {

@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
    let max_position_embeddings = if let Some(config) = &config {
        if let Some(max_position_embeddings) = config.max_position_embeddings {
            if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                max_default
            } else {
                max_position_embeddings

@@ -181,12 +181,16 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
    })
}

+/// Template and tokenize ChatRequest
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/chat_tokenize",
    request_body = ChatRequest,
-    responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
+    responses(
+        (status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse),
+        (status = 404, description = "Failed to tokenize ChatRequest", body = ErrorResponse),
+    )
)]
async fn get_chat_tokenize(
    Extension(infer): Extension<Infer>,

@@ -1501,6 +1505,7 @@ tokenize,
        metrics,
        openai_get_model_info,
        sagemaker_compatibility,
+        get_chat_tokenize,
    ),
    components(
        schemas(

@@ -1558,6 +1563,7 @@ Function,
            FunctionDefinition,
            ToolChoice,
            ModelInfo,
+            ChatTokenizeResponse,
        )
    ),
    tags(

@@ -44,5 +44,4 @@ class WQLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
         out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)

@@ -122,5 +122,4 @@ class QuantLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
         out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)

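The two forward hunks above drop the explicit bias addition and keep only the flatten-matmul-reshape pattern. The sketch below illustrates that shape handling with a plain nn.Linear standing in for the weight-only-quantized kernel; that the IPEX op now handles the bias itself is an assumption, not something stated in this diff.

# Shape-handling sketch for the forward passes above. nn.Linear is a stand-in
# for self.woq_linear; all leading dimensions are flattened before the matmul
# and restored afterwards.
import torch
import torch.nn as nn

linear = nn.Linear(8, 4)                   # in_features=8, out_features=4
x = torch.randn(2, 3, 8)                   # (batch, seq, in_features)

out_shape = x.shape[:-1] + (linear.out_features,)
out = linear(x.reshape(-1, x.shape[-1]))   # (batch * seq, out_features)
out = out.reshape(out_shape)               # (batch, seq, out_features)
assert out.shape == (2, 3, 4)
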
@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states=fsm_grammar_states,
         )

-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None

         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

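The hunk above replaces the check on batches[0] alone with a guard over every batch being concatenated. A standalone sketch of that merge logic follows, using an invented FakeBatch type and an explicit speculate argument in place of get_speculate().

# Standalone sketch of the concatenation guard above: speculative_ids are only
# kept when speculation is enabled and *every* batch in the merge still carries
# them; otherwise the merged batch discards them.
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class FakeBatch:  # stand-in for FlashCausalLMBatch
    speculative_ids: Optional[torch.Tensor]


def merge_speculative_ids(batches: List[FakeBatch], speculate: int) -> Optional[torch.Tensor]:
    if speculate > 0 and all(b.speculative_ids is not None for b in batches):
        return torch.cat([b.speculative_ids for b in batches], dim=0)
    return None


a = FakeBatch(torch.tensor([[1, 2]]))
b = FakeBatch(None)  # e.g. a batch that skipped speculation because it was too large
assert merge_speculative_ids([a, b], speculate=2) is None
assert merge_speculative_ids([a, FakeBatch(torch.tensor([[3, 4]]))], speculate=2).shape == (2, 2)
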
@@ -1532,8 +1533,6 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
-        max_bt = batch.max_blocks
-        max_s = max_bt * BLOCK_SIZE
         batch_num_blocks = batch.num_blocks

         if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):

@@ -1651,7 +1650,7 @@ class FlashCausalLM(Model):
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
             else:

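For the warmup loop above, the guard on self.speculate simply skips graph batch sizes smaller than speculate + 1. A tiny illustration of which sizes survive that filter; the CUDA_GRAPHS values and speculate settings below are invented for the example.

# Which CUDA graph batch sizes get warmed up, per the guard in the loop above.
def graph_sizes_to_warm(cuda_graphs, speculate):
    return [bs for bs in cuda_graphs if speculate is None or speculate + 1 <= bs]

CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]
print(graph_sizes_to_warm(CUDA_GRAPHS, None))  # [1, 2, 4, 8, 16, 32]
print(graph_sizes_to_warm(CUDA_GRAPHS, 3))     # [4, 8, 16, 32]
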
@@ -1726,7 +1725,15 @@ class FlashCausalLM(Model):
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (
+                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)

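The hunk above stops offsetting the raw slots tensor directly and instead expands slot_indices and gathers from batch.slots, since slots can be discontiguous under prefix caching. A tiny numeric illustration of that gathering step; all shapes and values here are invented.

# Tiny illustration of the slot gathering above: each sequence's positions are
# expanded as indices into batch.slots rather than assumed to be contiguous.
import torch

B, new_length = 2, 3
slots = torch.tensor([10, 11, 12, 40, 41, 42, 43])  # discontiguous slot table
slot_indices = torch.tensor([0, 3])                  # current index per sequence
arange_int = torch.arange(new_length)

expanded = (slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
gathered = slots[expanded]
print(expanded)  # tensor([0, 1, 2, 3, 4, 5])
print(gathered)  # tensor([10, 11, 12, 40, 41, 42])
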
@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
     cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)