Merge branch 'main' into moe

commit 3d4c50f028
Author: Wang, Yi A
Date:   2024-11-04 17:53:13 -08:00
9 changed files with 84 additions and 25 deletions

View File

@@ -101,6 +101,47 @@
         }
       }
     },
+    "/chat_tokenize": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Template and tokenize ChatRequest",
+        "operationId": "get_chat_tokenize",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/ChatRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Templated and tokenized ChatRequest",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ChatTokenizeResponse"
+                }
+              }
+            }
+          },
+          "404": {
+            "description": "Failed to tokenize ChatRequest",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/generate": {
       "post": {
         "tags": [
@@ -1092,6 +1133,21 @@
         }
       }
     },
+    "ChatTokenizeResponse": {
+      "type": "object",
+      "required": [
+        "tokenize_response",
+        "templated_text"
+      ],
+      "properties": {
+        "templated_text": {
+          "type": "string"
+        },
+        "tokenize_response": {
+          "$ref": "#/components/schemas/TokenizeResponse"
+        }
+      }
+    },
     "Chunk": {
       "type": "object",
       "required": [

View File

@@ -978,16 +978,15 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1729761651,
-        "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
+        "lastModified": 1730724647,
+        "narHash": "sha256-SVv+50CGaCoU4zZwsg6ZAaOi/D5QJBL1P2SIB+3CEf4=",
         "owner": "huggingface",
         "repo": "text-generation-inference-nix",
-        "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
+        "rev": "1512898a1e5ad9eff025205fa9c4d33a44506cf3",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "marlin-kernels-0.3.1",
         "repo": "text-generation-inference-nix",
         "type": "github"
       }

View File

@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {

View File

@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
     let max_position_embeddings = if let Some(config) = &config {
         if let Some(max_position_embeddings) = config.max_position_embeddings {
             if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                 max_default
             } else {
                 max_position_embeddings

View File

@@ -181,12 +181,16 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
     })
 }
 
+/// Template and tokenize ChatRequest
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/chat_tokenize",
     request_body = ChatRequest,
-    responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
+    responses(
+        (status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse),
+        (status = 404, description = "Failed to tokenize ChatRequest", body = ErrorResponse),
+    )
 )]
 async fn get_chat_tokenize(
     Extension(infer): Extension<Infer>,
@@ -1501,6 +1505,7 @@ tokenize,
         metrics,
         openai_get_model_info,
         sagemaker_compatibility,
+        get_chat_tokenize,
     ),
     components(
         schemas(
@@ -1558,6 +1563,7 @@ Function,
             FunctionDefinition,
             ToolChoice,
             ModelInfo,
+            ChatTokenizeResponse,
         )
     ),
     tags(

View File

@@ -44,5 +44,4 @@ class WQLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
         out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)

View File

@@ -122,5 +122,4 @@ class QuantLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
         out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
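This hunk and the previous one make the same change in the IPEX AWQ and GPTQ linear layers: the explicit bias addition after self.woq_linear is dropped, presumably because the fused weight-only-quant kernel already applies the bias. Below is a minimal sketch of the surrounding flatten/linear/reshape pattern, using torch.nn.functional.linear as a stand-in for that kernel (the stand-in is an assumption for illustration).

# Hedged sketch of the flatten -> linear -> reshape pattern used by these forward
# passes. F.linear stands in for the fused woq_linear op; with this stand-in the
# bias is applied inside the kernel call, so no separate `out = out + bias` is needed.
from typing import Optional

import torch
import torch.nn.functional as F

def forward_like(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    out_shape = x.shape[:-1] + (weight.shape[0],)
    # Collapse all leading dimensions so the kernel sees a 2D [tokens, in_features] input.
    out = F.linear(x.reshape(-1, x.shape[-1]), weight, bias)
    return out.reshape(out_shape)

x = torch.randn(2, 3, 16)   # [batch, seq, in_features]
w = torch.randn(32, 16)     # [out_features, in_features]
b = torch.randn(32)
print(forward_like(x, w, b).shape)  # torch.Size([2, 3, 32])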

View File

@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states=fsm_grammar_states,
         )
 
-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None
 
         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
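The concatenation now only stacks speculative_ids when speculation is enabled and every sub-batch still carries them; otherwise they are discarded. Below is a minimal, self-contained sketch of that guard, with a small dataclass standing in for FlashCausalLMBatch and an explicit speculate argument standing in for get_speculate() (both stand-ins are assumptions for illustration).

# Hedged sketch of the speculative_ids merge guard shown in the hunk above.
from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class MiniBatch:
    speculative_ids: Optional[torch.Tensor]

def merge_speculative_ids(batches: List[MiniBatch], speculate: int) -> Optional[torch.Tensor]:
    # Only concatenate when speculation is enabled and *every* batch kept its ids;
    # large batches may have skipped computing them, in which case all are discarded.
    if speculate > 0 and all(b.speculative_ids is not None for b in batches):
        return torch.cat([b.speculative_ids for b in batches], dim=0)
    return None

a = MiniBatch(speculative_ids=torch.tensor([[1, 2], [3, 4]]))
b = MiniBatch(speculative_ids=None)
print(merge_speculative_ids([a, b], speculate=2))        # None: one batch dropped its ids
print(merge_speculative_ids([a, a], speculate=2).shape)  # torch.Size([4, 2])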
@@ -1532,8 +1533,6 @@
             self.kv_cache_dtype,
             self.device,
         )
-        max_bt = batch.max_blocks
-        max_s = max_bt * BLOCK_SIZE
         batch_num_blocks = batch.num_blocks
 
         if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
@@ -1651,7 +1650,7 @@
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
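The warmup loop above only captures a CUDA graph for batch sizes that can hold the verified token plus all speculative tokens (speculate + 1 <= bs), and it now sizes the warmup with max_total_tokens rather than block-derived values. Below is a minimal sketch of just the bucket filter; the CUDA_GRAPHS values are illustrative.

# Hedged sketch of the batch-size filter used before capturing CUDA graphs.
from typing import List, Optional

CUDA_GRAPHS: List[int] = [1, 2, 4, 8, 16, 32]  # illustrative bucket sizes

def graph_batch_sizes(speculate: Optional[int]) -> List[int]:
    # A graph for batch size `bs` must fit 1 verified + `speculate` speculative tokens.
    return [bs for bs in CUDA_GRAPHS if speculate is None or speculate + 1 <= bs]

print(graph_batch_sizes(None))  # [1, 2, 4, 8, 16, 32]
print(graph_batch_sizes(3))     # [4, 8, 16, 32]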
@@ -1726,7 +1725,15 @@
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (
+                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
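Because slots can be discontiguous when prefix caching is enabled, the new code expands each sequence's slot_indices across the speculative positions and gathers the actual slot numbers from batch.slots instead of assuming consecutive slots. Below is a minimal sketch of that expand-and-gather; all concrete values are made up for illustration.

# Hedged sketch of the expand-and-gather from the hunk above. `slots` holds the
# (possibly discontiguous) allocated slot numbers for all sequences and
# `slot_indices` points at each sequence's current position inside `slots`.
import torch

B, new_length = 2, 3  # two sequences, 1 verified + 2 speculative positions each
# Allocated slots per sequence, deliberately discontiguous (as with prefix caching).
slots = torch.tensor([10, 11, 12, 50, 51, 52])
# Index of each sequence's current slot within `slots`.
slot_indices = torch.tensor([0, 3])

arange_int = torch.arange(new_length, dtype=torch.int64)
expanded_indices = (
    slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Gather the real slot numbers instead of assuming they are consecutive integers.
print(slots[expanded_indices])  # tensor([10, 11, 12, 50, 51, 52])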

View File

@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
     cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)