commit c6a31b9e2b
parent 6ad5aa7180
Author: OlivierDehaene
Date: 2024-04-12 18:38:34 +02:00
Committed by: Karol Damaszke
13 changed files with 32 additions and 18 deletions

Cargo.lock (generated)

@@ -3406,7 +3406,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "1.4.5"
+version = "2.0.0"
 dependencies = [
  "average",
  "clap",
@@ -3427,7 +3427,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "1.4.5"
+version = "2.0.0"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -3444,7 +3444,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "1.4.5"
+version = "2.0.0"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3462,7 +3462,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "1.4.5"
+version = "2.0.0"
 dependencies = [
  "async-stream",
  "axum",
@@ -4657,9 +4657,9 @@ dependencies = [

 [[package]]
 name = "zeroize"
-version = "1.8.0"
+version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "63381fa6624bf92130a6b87c0d07380116f80b565c42cf0d754136f0238359ef"
+checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"

 [[package]]
 name = "zip"


@@ -9,7 +9,7 @@ members = [
 resolver = "2"

 [workspace.package]
-version = "1.4.5"
+version = "2.0.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"


@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "1.4.5"
+    "version": "2.0.0"
   },
   "paths": {
     "/": {


@@ -17,7 +17,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "system_fingerprint": "2.0.0-native",
   "usage": {
     "completion_tokens": 100,
     "prompt_tokens": 60,


@@ -31,7 +31,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "system_fingerprint": "2.0.0-native",
   "usage": {
     "completion_tokens": 29,
     "prompt_tokens": 316,


@@ -31,7 +31,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "system_fingerprint": "2.0.0-native",
   "usage": {
     "completion_tokens": 29,
     "prompt_tokens": 316,


@@ -30,7 +30,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "system_fingerprint": "2.0.0-native",
   "usage": {
     "completion_tokens": 21,
     "prompt_tokens": 187,


@@ -23,5 +23,5 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.5-native"
+  "system_fingerprint": "2.0.0-native"
 }


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.4.5"
+version = "2.0.0"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "1.4.5"
+version = "2.0.0"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]


@@ -28,6 +28,10 @@ class ExceptionInterceptor(AsyncServerInterceptor):
         method_name = method_name.split("/")[-1]
         logger.exception(f"Method {method_name} encountered an error.")

+        # Runtime errors cannot be recovered from
+        if isinstance(err, RuntimeError):
+            exit(1)
+
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
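
The new guard makes any RuntimeError raised while serving a gRPC method fatal: the shard exits with status 1 instead of continuing with potentially corrupt CUDA state, leaving recovery to the launcher or whatever supervises the process. A minimal sketch of that control flow, assuming nothing about TGI's actual interceptor API (guarded_call, failing_handler, and the logger name are illustrative):

```python
import logging
import sys

logger = logging.getLogger("interceptor-sketch")

def guarded_call(handler, request):
    """Mirrors the interceptor's new error path (a sketch, not TGI's API)."""
    try:
        return handler(request)
    except Exception as err:
        logger.exception("method encountered an error")
        if isinstance(err, RuntimeError):
            # Unrecoverable: terminate so a supervisor can restart the process.
            sys.exit(1)
        # Other errors remain recoverable and propagate to the client.
        raise

def failing_handler(request):
    raise ValueError("bad request")

try:
    guarded_call(failing_handler, None)
except ValueError:
    print("recoverable error propagated to caller")
```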


@@ -55,9 +55,10 @@ class CacheManager:
     ):
         # Get free block indices by finding values in the mask that are not set to 0
         free_block_indices = self.free_block_mask.nonzero()
-        assert (
-            len(free_block_indices) >= blocks
-        ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
+        if blocks > len(free_block_indices):
+            raise RuntimeError(
+                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
+            )

         # Slice by the number of required blocks
         block_indices = free_block_indices[:blocks]
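
Swapping the assert for an explicit raise does two things: the check survives `python -O` (which strips asserts), and the failure is now a RuntimeError, so it reaches the interceptor's new exit(1) path rather than surfacing as an AssertionError. A standalone sketch of the allocation check (free_block_mask and the block count are illustrative stand-ins for CacheManager's real state):

```python
import torch

num_blocks = 8
# 1 = free, 0 = in use (stand-in for CacheManager.free_block_mask)
free_block_mask = torch.ones(num_blocks, dtype=torch.int32)

def allocate(blocks: int) -> torch.Tensor:
    # Free block indices are the positions still set to 1 in the mask
    free_block_indices = free_block_mask.nonzero()
    if blocks > len(free_block_indices):
        raise RuntimeError(
            f"Out of available cache blocks: asked {blocks}, "
            f"only {len(free_block_indices)} free blocks"
        )
    # Take the first `blocks` free indices and mark them as used
    block_indices = free_block_indices[:blocks]
    free_block_mask[block_indices] = 0
    return block_indices.flatten()

print(allocate(4))  # tensor([0, 1, 2, 3])
try:
    allocate(10)
except RuntimeError as err:
    print(err)  # Out of available cache blocks: asked 10, only 4 free blocks
```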


@@ -503,6 +503,10 @@ class MedusaHeadV1(nn.Module):
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
+        # If we have too many tokens, we skip speculative logits
+        if input.shape[0] > 128:
+            return logits, None
+
         speculative_logits = self.medusa(input)
         return logits, speculative_logits
@@ -549,6 +553,11 @@ class MedusaHeadV2(nn.Module):
         self.lm_head = TensorParallelHead.load(config, prefix, weights)

     def forward(self, x):
+        # If we have too many tokens, we skip speculative logits
+        if x.shape[0] > 128:
+            logits = self.lm_head(x)
+            return logits, None
+
         size = x.shape[-1]
         block_size = (size + self.world_size - 1) // self.world_size
         start = self.rank * block_size
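
Both Medusa heads now bypass speculation once the flattened batch exceeds 128 tokens: each speculative head is another lm-head-sized matmul per token, which is worthwhile at small batch sizes but mostly adds latency under heavy load. A toy module showing the same cutoff (the threshold mirrors the hard-coded 128 above; the layer shapes and stacked-head structure are simplified and not TGI's actual Medusa model):

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn

SPECULATION_CUTOFF = 128  # mirrors the hard-coded 128 in the diff

class ToyMedusaHead(nn.Module):
    """Simplified Medusa head: skip speculative logits for large batches."""

    def __init__(self, hidden: int, vocab: int, n_heads: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.medusa = nn.ModuleList(
            nn.Linear(hidden, vocab, bias=False) for _ in range(n_heads)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(x)
        if x.shape[0] > SPECULATION_CUTOFF:
            # Too many tokens: speculation costs more than it saves.
            return logits, None
        # One extra vocab projection per speculative head, stacked per token
        speculative_logits = torch.stack([head(x) for head in self.medusa], dim=1)
        return logits, speculative_logits

head = ToyMedusaHead(hidden=16, vocab=32, n_heads=2)
_, spec_small = head(torch.randn(4, 16))    # speculative logits computed
_, spec_large = head(torch.randn(256, 16))  # speculation skipped
assert spec_small is not None and spec_large is None
```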