From bf700e7eef4771f280c19dbc7270c8c7c20efbbc Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Feb 2024 19:49:28 +0100 Subject: [PATCH 1/5] Revamp medusa implementation so that every model can benefit. (#1588) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- integration-tests/conftest.py | 8 ++ integration-tests/models/test_flash_medusa.py | 4 +- server/text_generation_server/cli.py | 14 +-- .../text_generation_server/models/__init__.py | 86 +++++++++-------- server/text_generation_server/models/bloom.py | 6 +- .../models/causal_lm.py | 13 ++- .../models/custom_modeling/bloom_modeling.py | 21 +++-- .../custom_modeling/flash_gemma_modeling.py | 10 +- .../custom_modeling/flash_llama_modeling.py | 10 +- .../custom_modeling/flash_mistral_modeling.py | 4 +- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../custom_modeling/flash_neox_modeling.py | 4 +- .../custom_modeling/flash_phi_modeling.py | 4 +- .../custom_modeling/flash_rw_modeling.py | 6 +- .../flash_santacoder_modeling.py | 4 +- .../custom_modeling/idefics_modeling.py | 29 +++--- .../models/custom_modeling/mamba_modeling.py | 11 +-- .../models/custom_modeling/mpt_modeling.py | 21 +++-- .../models/custom_modeling/neox_modeling.py | 4 +- .../models/custom_modeling/opt_modeling.py | 4 +- .../models/custom_modeling/phi_modeling.py | 4 +- .../models/custom_modeling/t5_modeling.py | 31 ++++--- .../models/flash_causal_lm.py | 24 +++-- .../models/flash_gemma.py | 33 +------ .../models/flash_llama.py | 34 +------ .../models/flash_mistral.py | 24 ++++- .../models/flash_mixtral.py | 2 + .../models/flash_neox.py | 2 + .../models/flash_phi.py | 3 +- .../text_generation_server/models/flash_rw.py | 2 + .../models/flash_santacoder.py | 2 + .../text_generation_server/models/idefics.py | 2 + .../models/idefics_causal_lm.py | 11 ++- server/text_generation_server/models/mamba.py | 18 +++- server/text_generation_server/models/mpt.py | 2 + server/text_generation_server/models/opt.py | 2 + server/text_generation_server/models/phi.py | 2 + .../models/santacoder.py | 1 + .../models/seq2seq_lm.py | 11 ++- server/text_generation_server/models/t5.py | 5 +- server/text_generation_server/utils/hub.py | 2 + server/text_generation_server/utils/layers.py | 92 ++++++++++++++++++- server/text_generation_server/utils/medusa.py | 59 ------------ 43 files changed, 352 insertions(+), 283 deletions(-) delete mode 100644 server/text_generation_server/utils/medusa.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 80457bc2..e11c7cf9 100644 --- a/integration-tests/conftest.py +++ 
b/integration-tests/conftest.py @@ -236,6 +236,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + revision: Optional[str] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -268,6 +269,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) if trust_remote_code: args.append("--trust-remote-code") @@ -302,6 +306,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + revision: Optional[str] = None, ): port = random.randint(8000, 10_000) @@ -317,6 +322,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) if trust_remote_code: args.append("--trust-remote-code") diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index e0cc1039..27db5665 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def flash_medusa_handle(launcher): - with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle: + with launcher( + "FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1" + ) as handle: yield handle diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index b74fbe36..a513f5e6 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -154,12 +154,8 @@ def download_weights( import json medusa_head = hf_hub_download( - model_id, revision=revision, filename="medusa_lm_head.pt" + model_id, revision=revision, filename="medusa_lm_head.safetensors" ) - if auto_convert: - medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors") - if not medusa_sf.exists(): - utils.convert_files([Path(medusa_head)], [medusa_sf], []) medusa_config = hf_hub_download( model_id, revision=revision, filename="config.json" ) @@ -198,16 +194,12 @@ def download_weights( if not extension == ".safetensors" or not auto_convert: raise e - elif (Path(model_id) / "medusa_lm_head.pt").exists(): + elif (Path(model_id) / "medusa_lm_head.safetensors").exists(): # Try to load as a local Medusa model try: import json - medusa_head = Path(model_id) / "medusa_lm_head.pt" - if auto_convert: - medusa_sf = Path(model_id) / "medusa_lm_head.safetensors" - if not medusa_sf.exists(): - utils.convert_files([Path(medusa_head)], [medusa_sf], []) + medusa_head = Path(model_id) / "medusa_lm_head.safetensors" medusa_config = Path(model_id) / "config.json" with open(medusa_config, "r") as f: config = json.load(f) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index abab3486..3208275c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -3,7 +3,9 @@ import torch from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto +from huggingface_hub import hf_hub_download from typing import Optional +from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model 
import Model @@ -115,44 +117,14 @@ def get_model( else: set_speculate(0) - if "facebook/galactica" in model_id: - return GalacticaSharded( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_id.startswith("bigcode/"): - if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") - ) - else: - return SantaCoder( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) use_medusa = None if "medusa_num_heads" in config_dict: - use_medusa = model_id + medusa_model_id = model_id + medusa_revision = revision model_id = config_dict["base_model_name_or_path"] revision = "main" speculate_medusa = config_dict["medusa_num_heads"] @@ -169,6 +141,20 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + is_local = Path(medusa_model_id).exists() + if not is_local: + medusa_config = hf_hub_download( + medusa_model_id, revision=medusa_revision, filename="config.json" + ) + hf_hub_download( + medusa_model_id, + revision=medusa_revision, + filename="medusa_lm_head.safetensors", + ) + use_medusa = Path(medusa_config).parent + else: + use_medusa = Path(medusa_model_id) + method = "medusa" else: method = "n-gram" @@ -193,16 +179,22 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "gpt_bigcode": + if ( + model_type == "gpt_bigcode" + or model_type == "gpt2" + and model_id.startswith("bigcode/") + ): if FLASH_ATTENTION: return FlashSantacoderSharded( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -215,6 +207,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -224,6 +217,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -232,6 +226,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -242,6 +237,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -250,6 +246,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -258,6 +255,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -268,15 +266,16 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa, ) else: return CausalLM( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -291,6 +290,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -301,9 +301,9 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, 
dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) @@ -312,6 +312,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -321,9 +322,9 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa, ) elif sharded: raise NotImplementedError( @@ -334,6 +335,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -347,6 +349,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -357,6 +360,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -365,6 +369,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -378,6 +383,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -391,6 +397,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -400,6 +407,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -409,6 +417,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -418,6 +427,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -441,6 +451,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -449,6 +460,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -460,6 +472,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -468,6 +481,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index fed5e6f3..67129ec3 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM): ) config.pad_token_id = 3 config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ 
-112,4 +114,4 @@ class BLOOMSharded(CausalLM): ) logits = outputs.logits - return logits, outputs.past_key_values + return logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index a0f0c9e8..bbcef210 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -482,6 +482,7 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -550,7 +551,9 @@ class CausalLM(Model): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]] + ]: # Model Forward kwargs = { "input_ids": input_ids, @@ -563,7 +566,11 @@ class CausalLM(Model): kwargs["position_ids"] = position_ids outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values + if isinstance(outputs, tuple): + outputs, speculative_logits = outputs + else: + speculative_logits = None + return outputs.logits, speculative_logits, outputs.past_key_values @tracer.start_as_current_span("generate_token") def generate_token( @@ -573,7 +580,7 @@ class CausalLM(Model): # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - logits, past = self.forward( + logits, speculative_logits, past = self.forward( batch.input_ids, attention_mask, batch.position_ids, diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 5423d75a..10b40483 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -36,7 +36,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) CUSTOM_KERNELS_ENABLED = False @@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel): super().__init__(config) self.transformer = BloomModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="word_embeddings", weights=weights, @@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel): ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) loss = None if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + return ( + CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ), + speculative_logits, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4a08bc2a..e91927df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -37,7 +37,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): super().__init__() self.model = FlashGemmaModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", weights=weights, @@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, position_ids, @@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits = self.lm_head(hidden_states) - return logits + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 1626eb4d..3a269fc0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): super().__init__() self.model = FlashLlamaModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, @@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, position_ids, @@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits = self.lm_head(hidden_states) - return logits + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index fda34e5a..ed9306e0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module): super().__init__() self.model = MistralModel(config, weights) - self.lm_head = TensorParallelHead.load( 
+ self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 3d3caba3..17d4f708 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,7 +37,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, ) @@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): super().__init__() self.model = MixtralModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 780861c2..ee062d3d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -33,7 +33,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLayerNorm, PositionRotaryEmbedding, get_linear, @@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): super().__init__(config) self.gpt_neox = FlashGPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( + self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9a929e9..cfe447a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -12,7 +12,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastLayerNorm, ) @@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module): super().__init__() self.model = FlashPhiModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6a530f3c..a9127d1f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,7 +12,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLayerNorm, PositionRotaryEmbedding, get_linear, @@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): self.transformer = FlashRWModel(config, weights) - self.lm_head = TensorParallelHead.load( - config, prefix="lm_head", weights=weights - ) + self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) def forward( self, diff --git 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d3fe95d0..bbb603a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, - TensorParallelHead, + SpeculativeHead, TensorParallelEmbedding, FastLayerNorm, get_linear, @@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, config, weights): super().__init__() self.transformer = FlashSantacoderModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 4f7dfb95..ee4cdb08 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -51,7 +51,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, PositionRotaryEmbedding, FastLinear, ) @@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): weights, ) -> None: super().__init__() - self.fc = TensorParallelHead.load( - config=config, prefix="lm_head", weights=weights - ) + self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights) self.additional_fc = FastLinear.load( config=config, prefix="lm_head.additional_fc", @@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): ) def forward(self, input: torch.Tensor) -> torch.Tensor: - output = self.fc(input) + output, speculative_logits = self.fc(input) additional_features = self.additional_fc(input) output = torch.cat((output, additional_features), -1) - return output + return output, speculative_logits def extra_repr(self) -> str: """Overwriting `nn.Linear.extra_repr` to include new parameters.""" @@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) loss = None - return CausalLMOutputWithPastImage( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, + return ( + CausalLMOutputWithPastImage( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ), + speculative_logits, ) def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index baf1fb85..c58a617f 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -9,6 +9,7 @@ from 
transformers.configuration_utils import PretrainedConfig import torch.nn.functional as F from text_generation_server.utils.layers import ( + SpeculativeHead, TensorParallelEmbedding, FastRMSNorm, FastLinear, @@ -205,14 +206,12 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = FastLinear.load( - config, f"{prefix}.embedding", weights, bias=False - ) + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) self.config = config def forward( self, input_ids: torch.Tensor, inference_params=None, residual=None - ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.embed_tokens(input_ids) for i, block in enumerate(self.blocks): hidden_states, residual, conv_state, ssm_state = block( @@ -226,8 +225,8 @@ class MambaModel(nn.Module): ) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states = hidden_states.view(residual.shape) - logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) # update the offset for the next inference using these params inference_params.seqlen_offset += input_ids.size(1) - return logits + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 2e2e423e..9b0f8b92 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -21,7 +21,7 @@ from text_generation_server.utils.layers import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, get_linear, ) @@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel): if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = MPTModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) self.logit_scale = None @@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel): output_hidden_states=output_hidden_states, use_cache=use_cache, ) - logits = self.lm_head(outputs.last_hidden_state) + logits, speculative_logits = self.lm_head(outputs.last_hidden_state) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn( @@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel): loss = F.cross_entropy( logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) ) - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return ( + CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index dbcefbae..2550d2d1 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -44,7 +44,7 @@ from text_generation_server.utils.layers import ( 
TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) @@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): def __init__(self, config, weights): super().__init__(config) self.gpt_neox = GPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( + self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index ce3f5e21..de5e95af 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) EPS = 1e-5 @@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel): self.model = OPTModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.decoder.embed_tokens", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index e5c09728..1571f9fd 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -13,7 +13,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLinear, ) @@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module): weights=weights, eps=config.layer_norm_epsilon, ) - self.linear = TensorParallelHead.load( + self.linear = SpeculativeHead.load( config=config, prefix="lm_head.linear", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index d3e4f53a..2773fb15 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -42,7 +42,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) @@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ) try: - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights ) except RuntimeError: # Some models like t5-small were saved with shared weights unlike flan # Since they are declared as the same arch we have no choice but hope # that this is OK instead of using a proper flag. 
- self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="shared", weights=weights ) @@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) - lm_logits = self.lm_head(sequence_output) + logits, speculative_logits = self.lm_head(sequence_output) loss = None if labels is not None: @@ -1140,16 +1140,19 @@ class T5ForConditionalGeneration(T5PreTrainedModel): output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, + return ( + Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b8d0be22..988637d4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -723,7 +723,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - self.cuda_graphs[bs]["logits"] = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, @@ -734,6 +734,8 @@ class FlashCausalLM(Model): max_s=max_s, lm_head_indices=None, ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -805,7 +807,9 @@ class FlashCausalLM(Model): return int(num_blocks * BLOCK_SIZE) - def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor: + def forward( + self, batch: FlashCausalLMBatch + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: input_ids = batch.input_ids @@ -900,9 +904,14 @@ class FlashCausalLM(Model): # Replay the graph cuda_graph["graph"].replay() - # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( @@ -926,16 +935,11 @@ class FlashCausalLM(Model): batch.slots = slots try: - out = self.forward(batch) + out, speculative_logits = self.forward(batch) except Exception as e: del batch raise e - if isinstance(out, tuple): - out, speculative_logits = out - else: - 
speculative_logits = None - if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 220b3992..8cfb6631 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) @@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM): weights._set_gptq_params(model_id, revision) model = FlashGemmaForCausalLM(config, weights) - if use_medusa: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(use_medusa).exists() and Path(use_medusa).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - use_medusa, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - use_medusa, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(use_medusa) / "config.json") - medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 94bd58f4..a2ac759a 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) @@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM): weights._set_gptq_params(model_id, revision) model = FlashLlamaForCausalLM(config, weights) - if use_medusa: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(use_medusa).exists() and Path(use_medusa).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - 
medusa_config = hf_hub_download( - use_medusa, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - use_medusa, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(use_medusa) / "config.json") - medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( model=model, diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 5df4e214..d3c0da9c 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa # Set context windows if config.sliding_window is not None: @@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - self.cuda_graphs[bs]["logits"] = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, @@ -406,9 +408,13 @@ class BaseFlashMistral(FlashCausalLM): prefill_cache_indices=None, lm_head_indices=None, ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() - def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: FlashMistralBatch + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: input_ids = batch.input_ids @@ -479,7 +485,7 @@ class BaseFlashMistral(FlashCausalLM): cuda_graph = self.cuda_graphs.get(padded_bs, None) if cu_seqlen_prefill is not None or cuda_graph is None: - logits = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -493,7 +499,7 @@ class BaseFlashMistral(FlashCausalLM): ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None - return logits + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -511,7 +517,13 @@ class BaseFlashMistral(FlashCausalLM): cuda_graph["graph"].replay() # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits class FlashMistral(BaseFlashMistral): @@ -520,6 +532,7 @@ class FlashMistral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = 
False, ): @@ -529,6 +542,7 @@ class FlashMistral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py index 6f77a658..2ee35e82 100644 --- a/server/text_generation_server/models/flash_mixtral.py +++ b/server/text_generation_server/models/flash_mixtral.py @@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 80f8804d..5a351bd7 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 061b9740..cb55f9e6 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index dfab8888..fc1e26bd 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM): ) config.quantize = quantize + config.use_medusa = use_medusa if config.quantize == "gptq": weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 22171ec0..034949f9 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ 
b/server/text_generation_server/models/flash_santacoder.py @@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM): trust_remote_code=True, ) config.quantize = quantize + config.use_medusa = use_medusa config.transpose = config.architectures[0].startswith("GPT2") torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index fa23d1f9..baa1945b 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa config.vision_config.quantize = quantize tokenizer = LlamaTokenizerFast.from_pretrained( diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index a6df2ebe..c96e8152 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -662,8 +662,13 @@ class IdeficsCausalLM(Model): if self.has_position_ids: kwargs["position_ids"] = position_ids - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values, outputs.image_hidden_states + outputs, speculative_logits = self.model.forward(**kwargs) + return ( + outputs.logits, + speculative_logits, + outputs.past_key_values, + outputs.image_hidden_states, + ) @tracer.start_as_current_span("generate_token") def generate_token( @@ -686,7 +691,7 @@ class IdeficsCausalLM(Model): :, : -batch.padding_right_offset ] - logits, past, image_hidden_states = self.forward( + logits, speculative_logits, past, image_hidden_states = self.forward( input_ids=batch.input_ids, attention_mask=attention_mask, position_ids=batch.position_ids, diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9d59f424..2500d454 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -408,6 +408,7 @@ class Mamba(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -444,6 +445,7 @@ class Mamba(Model): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) @@ -505,7 +507,7 @@ class Mamba(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - logits = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, inference_params=inference_params ) torch.cuda.synchronize() @@ -514,6 +516,7 @@ class Mamba(Model): "inference_params": 
inference_params, "graph": graph, "logits": logits, + "speculative_logits": speculative_logits, } self.cuda_graphs[batch_size] = graph_dict @@ -556,9 +559,14 @@ class Mamba(Model): inference_params.ssm_states.copy_( cuda_graph["inference_params"].ssm_states[:, :bs] ) - # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() @@ -589,7 +597,9 @@ class Mamba(Model): batch.inference_params = inference_params # Forward pass - logits = self.forward(input_ids, inference_params=batch.inference_params) + logits, speculative_logits = self.forward( + input_ids, inference_params=batch.inference_params + ) # batch.inference_params = new_inference_params # Results diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index e419467f..6b3f29a6 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -43,6 +43,7 @@ class MPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -75,6 +76,7 @@ class MPTSharded(CausalLM): config = json.load(f) config = PretrainedConfig(**config) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 58fb212f..703e5b58 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -22,6 +22,7 @@ class OPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -47,6 +48,7 @@ class OPTSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index 79aa3fb9..cc4e2505 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -22,6 +22,7 @@ class Phi(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -52,6 +53,7 @@ class Phi(CausalLM): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 7b269d8e..73c21cce 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -19,6 +19,7 @@ class SantaCoder(CausalLM): model_id: str, revision: Optional[str] = None, 
quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 777a55ba..fae9a2df 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -532,6 +532,7 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -596,6 +597,7 @@ class Seq2SeqLM(Model): past_key_values: Optional = None, ) -> Tuple[ torch.Tensor, + Optional[torch.Tensor], torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: @@ -609,8 +611,15 @@ class Seq2SeqLM(Model): past_key_values=past_key_values, use_cache=True, ) + if isinstance(outputs, tuple): + # Our custom models + outputs, speculative_logits = outputs + else: + # Generic transformers models + speculative_logits = None return ( outputs.logits, + speculative_logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) @@ -635,7 +644,7 @@ class Seq2SeqLM(Model): else: encoder_last_hidden_state = None - logits, encoder_last_hidden_state, past = self.forward( + logits, speculative_logits, encoder_last_hidden_state, past = self.forward( batch.input_ids, batch.attention_mask, batch.decoder_input_ids, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 161e69ba..3f3cb965 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa tokenizer = AutoTokenizer.from_pretrained( model_id, @@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM): List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: # Model Forward - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM): return ( outputs.logits, + speculative_logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index b56484f6..a81e659d 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info( and "arguments" not in s.rfilename and "args" not in s.rfilename and "training" not in s.rfilename + and "medusa_lm_head" not in s.rfilename ] @@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: and "args" not in f and "adapter" not in f and "training" not in f + and "medusa_lm_head" not in f ] return filenames diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index bef2a146..209f1c8a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -4,7 +4,7 @@ import torch.distributed from 
torch import nn from torch.nn import functional as F -from typing import List +from typing import List, Tuple, Optional from loguru import logger from functools import lru_cache @@ -380,6 +380,96 @@ class SuperLayer(nn.Module): return self.linear.forward(x) +class ResBlock(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.linear = FastLinear.load( + config, prefix=f"{prefix}.linear", weights=weights, bias=True + ) + self.act = torch.nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + self.heads = torch.nn.ModuleList( + [ + MedusaHead(config, prefix=f"{i}", weights=weights) + for i in range(config["medusa_num_heads"]) + ] + ) + + def forward(self, x): + speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) + return speculative_logits + + +class MedusaHead(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.blocks = torch.nn.ModuleList( + [ + ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) + for i in range(config["medusa_num_layers"]) + ] + ) + n = len(self.blocks) + self.out = FastLinear.load( + config, prefix=f"{prefix}.{n}", weights=weights, bias=False + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.out(x) + return x + + +class SpeculativeHead(nn.Module): + def __init__(self, lm_head, medusa): + super().__init__() + self.lm_head = lm_head + self.medusa = medusa + + @staticmethod + def load(config, prefix: str, weights): + lm_head = TensorParallelHead.load(config, prefix, weights) + use_medusa = config.use_medusa + if use_medusa: + from pathlib import Path + from safetensors import safe_open + import json + + medusa_config = str(Path(use_medusa) / "config.json") + filename = str(Path(use_medusa) / "medusa_lm_head.safetensors") + + with open(medusa_config, "r") as f: + config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + weights.routing[k] = filename + + medusa = MedusaModel(config, weights) + else: + medusa = None + return SpeculativeHead(lm_head, medusa) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.lm_head(input) + speculative_logits = self.medusa(input) if self.medusa is not None else None + return logits, speculative_logits + + class TensorParallelHead(SuperLayer): def __init__(self, linear, process_group, should_gather: bool): super().__init__(linear) diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py deleted file mode 100644 index 634119cb..00000000 --- a/server/text_generation_server/utils/medusa.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from dataclasses import dataclass -from text_generation_server.utils.layers import TensorParallelHead, FastLinear - - -@dataclass -class Output: - logits: torch.FloatTensor = None - speculative_logits: torch.FloatTensor = None - - -class ResBlock(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.linear = FastLinear.load( - config, prefix=f"{prefix}.linear", weights=weights, bias=True - ) - self.act = torch.nn.SiLU() - - def forward(self, x): - return x + self.act(self.linear(x)) - - -class MedusaModel(torch.nn.Module): - def 
__init__(self, config, weights, lm_head): - super().__init__() - self.heads = torch.nn.ModuleList( - [ - MedusaHead(config, prefix=f"{i}", weights=weights) - for i in range(config["medusa_num_heads"]) - ] - ) - self.lm_head = lm_head - - def forward(self, x): - logits = self.lm_head(x) - speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) - return logits, speculative_logits - - -class MedusaHead(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.blocks = torch.nn.ModuleList( - [ - ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) - for i in range(config["medusa_num_layers"]) - ] - ) - n = len(self.blocks) - self.out = FastLinear.load( - config, prefix=f"{prefix}.{n}", weights=weights, bias=False - ) - - def forward(self, x): - for block in self.blocks: - x = block(x) - x = self.out(x) - return x From 9b6db5f79312466ac698c128c8abd4fb3b7b47d3 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 28 Feb 2024 05:10:27 -0500 Subject: [PATCH 2/5] Support tools (#1587) This work in progress PR begins to add support for tools. Tools relies on grammar support and still has some unsolved challenges. Opening the PR for visibility and feedback --- clients/python/text_generation/client.py | 222 +++++++++- clients/python/text_generation/types.py | 120 ++++- docs/source/_toctree.yml | 2 + docs/source/guidance.md | 419 ++++++++++++++++++ integration-tests/conftest.py | 34 ++ .../test_flash_llama_grammar_no_tools.json | 26 ++ .../test_flash_llama_grammar_tools.json | 38 ++ .../test_flash_llama_grammar_tools_auto.json | 38 ++ ...test_flash_llama_grammar_tools_choice.json | 37 ++ ...test_flash_llama_grammar_tools_stream.json | 27 ++ integration-tests/models/test_tools_llama.py | 240 ++++++++++ router/src/infer.rs | 51 ++- router/src/lib.rs | 168 ++++++- router/src/server.rs | 115 ++++- 14 files changed, 1510 insertions(+), 27 deletions(-) create mode 100644 docs/source/guidance.md create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json create mode 100644 integration-tests/models/test_tools_llama.py diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index bbccbf1d..09660de3 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -3,7 +3,7 @@ import requests from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError -from typing import Dict, Optional, List, AsyncIterator, Iterator +from typing import Dict, Optional, List, AsyncIterator, Iterator, Union from text_generation.types import ( StreamResponse, @@ -11,6 +11,11 @@ from text_generation.types import ( Request, Parameters, Grammar, + ChatRequest, + ChatCompletionChunk, + ChatComplete, + Message, + Tool, ) from text_generation.errors import parse_error @@ -59,6 +64,114 @@ class Client: self.cookies = cookies self.timeout = timeout + def chat( + self, + messages: List[Message], + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = 
None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + tool_choice: Optional[str] = None, + ): + """ + Given a list of messages, generate a response asynchronously + + Args: + messages (`List[Message]`): + List of messages + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + logit_bias (`List[float]`): + Adjust the likelihood of specified tokens + logprobs (`bool`): + Include log probabilities in the response + top_logprobs (`int`): + Include the `n` most likely tokens at each step + max_tokens (`int`): + Maximum number of generated tokens + n (`int`): + Generate `n` completions + presence_penalty (`float`): + The parameter for presence penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + stream (`bool`): + Stream the response + seed (`int`): + Random sampling seed + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + tools (`List[Tool]`): + List of tools to use + tool_choice (`str`): + The tool to use + + """ + request = ChatRequest( + model="tgi", + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + ) + if not stream: + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return ChatComplete(**payload) + else: + return self._chat_stream_response(request) + + def _chat_stream_response(self, request): + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + # iterate and print stream + for byte_payload in resp.iter_lines(): + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = ChatCompletionChunk(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + def generate( self, prompt: str, @@ -313,6 +426,113 @@ class AsyncClient: self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) + async def chat( + self, + messages: List[Message], + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = 
None, + tool_choice: Optional[str] = None, + ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]: + """ + Given a list of messages, generate a response asynchronously + + Args: + messages (`List[Message]`): + List of messages + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + logit_bias (`List[float]`): + Adjust the likelihood of specified tokens + logprobs (`bool`): + Include log probabilities in the response + top_logprobs (`int`): + Include the `n` most likely tokens at each step + max_tokens (`int`): + Maximum number of generated tokens + n (`int`): + Generate `n` completions + presence_penalty (`float`): + The parameter for presence penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + stream (`bool`): + Stream the response + seed (`int`): + Random sampling seed + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + tools (`List[Tool]`): + List of tools to use + tool_choice (`str`): + The tool to use + + """ + request = ChatRequest( + model="tgi", + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + ) + if not stream: + return await self._chat_single_response(request) + else: + return self._chat_stream_response(request) + + async def _chat_single_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return ChatComplete(**payload) + + async def _chat_stream_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = ChatCompletionChunk(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + async def generate( self, prompt: str, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 911114ee..4a308cef 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, validator -from typing import Optional, List, Union +from typing import Optional, List, Union, Any from text_generation.errors import ValidationError @@ -19,6 +19,124 @@ class Grammar(BaseModel): value: Union[str, dict] +class ToolCall(BaseModel): + # Id of the tool call + id: int + # Type of the tool call + type: str + # Function details of the tool call + function: dict + + +class Message(BaseModel): + # Role of the 
message sender + role: str + # Content of the message + content: Optional[str] + # Optional name of the message sender + name: Optional[str] = None + # Tool calls associated with the chat completion + tool_calls: Optional[Any] = None + + +class Tool(BaseModel): + # Type of the tool + type: str + # Function details of the tool + function: dict + + +class ChatCompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + message: Message + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + # Usage details of the chat completion + usage: Any + + +class Function(BaseModel): + name: Optional[str] + arguments: str + + +class ChoiceDeltaToolCall(BaseModel): + index: int + id: str + type: str + function: Function + + +class ChoiceDelta(BaseModel): + role: str + content: Optional[str] + tool_calls: Optional[ChoiceDeltaToolCall] + + +class Choice(BaseModel): + index: int + delta: ChoiceDelta + logprobs: Optional[dict] = None + finish_reason: Optional[str] = None + + +class ChatCompletionChunk(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choice] + + +class ChatComplete(BaseModel): + # Chat completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[ChatCompletionComplete] + usage: Any + + +class ChatRequest(BaseModel): + # Model identifier + model: str + # List of messages in the conversation + messages: List[Message] + # Penalty for frequency of new tokens + frequency_penalty: Optional[float] = None + # Bias values for token selection + logit_bias: Optional[List[float]] = None + # Whether to return log probabilities + logprobs: Optional[bool] = None + # Number of most likely tokens to return at each position + top_logprobs: Optional[int] = None + # Maximum number of tokens to generate + max_tokens: Optional[int] = None + # Number of chat completion choices to generate + n: Optional[int] = None + # Penalty for presence of new tokens + presence_penalty: Optional[float] = None + # Flag to indicate streaming response + stream: bool = False + # Random sampling seed + seed: Optional[int] = None + # Sampling temperature + temperature: Optional[float] = None + # Top-p value for nucleus sampling + top_p: Optional[float] = None + # List of tools to be used + tools: Optional[List[Tool]] = None + # Choice of tool to be used + tool_choice: Optional[str] = None + + class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d57a594d..964a743a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -9,6 +9,8 @@ title: Supported Models and Hardware - local: messages_api title: Messages API + - local: guidance + title: Guidance title: Getting started - sections: - local: basic_tutorials/consuming_tgi diff --git a/docs/source/guidance.md b/docs/source/guidance.md new file mode 100644 index 00000000..8b9ba094 --- /dev/null +++ b/docs/source/guidance.md @@ -0,0 +1,419 @@ +# Guidance + +Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs. + +These feature are available starting from version `1.4.3`. 
They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them! + +## Quick Start + +Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide. + +If you're not up to date, grab the latest version and let's get started! + +## Table of Contents 📚 + +### Grammar and Constraints + +- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision. +- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models. +- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema. +- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses. + +### Tools and Functions + +- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions. +- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions. + +- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +## Grammar and Constraints 🛣️ + +### The Grammar Parameter + +In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output. + +Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API, and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability. + +```json +curl localhost:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park", + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] + } + } + } +}' +// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"} + +``` + +A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar. + +> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster. + +### Constrain with Pydantic + +Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format. + +Using Pydantic models, we can define a grammar similar to the previous example in a shorter and more readable way.
+ +```python +import requests +from pydantic import BaseModel, conint +from typing import List + +class Animals(BaseModel): + location: str + activity: str + animals_seen: conint(ge=1, le=5) # Constrained integer type + animals: List[str] + +prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park" + +data = { + "inputs": prompt, + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": Animals.schema() + } + } +} + +headers = { + "Content-Type": "application/json", +} + +response = requests.post( + 'http://127.0.0.1:3000/generate', + headers=headers, + json=data +) +print(response.json()) +# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'} + +``` + +### JSON Schema Integration + +If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is simliar to the first example but with programmatic control. + +```python +import requests + +json_schema = { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] +} + +data = { + "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]", + "parameters": { + "max_new_tokens": 200, + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": json_schema + } + } +} + +headers = { + "Content-Type": "application/json", +} + +response = requests.post( + 'http://127.0.0.1:3000/generate', + headers=headers, + json=data +) +print(response.json()) +# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'} + +``` + +### Using the client + +TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter. + +```python +from text_generation import AsyncClient +from text_generation.types import GrammarType + +# NOTE: tools defined above and removed for brevity + +# Define an async function to encapsulate the async operation +async def main(): + client = AsyncClient(base_url="http://localhost:3000") + + # Use 'await' to wait for the async method 'chat' to complete + response = await client.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=1, + grammar={ + "type": GrammarType.Regex, + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, + ) + + # Once the response is received, you can process it + print(response.generated_text) + +# Ensure the main async function is run in the event loop +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + +# 118.8.0.84 + +``` + +## Tools and Functions 🛠️ + +### The Tools Parameter + +In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API. + +Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more. 
+ +Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. + +```json +curl localhost:3000/v1/chat/completions \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is the weather like in New York?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + } + } + } + ], + "tool_choice": "get_current_weather" +}' +// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}} +``` + +
+ Tools used in example below + + ```python + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + } + ] + ``` + +
+ +### Text Generation Inference Client + +TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions. + +```python +from text_generation import AsyncClient + +# NOTE: tools defined above and removed for brevity + +# Define an async function to encapsulate the async operation +async def main(): + client = AsyncClient(base_url="http://localhost:3000") + + # Use 'await' to wait for the async method 'chat' to complete + response = await client.chat( + max_tokens=100, + seed=1, + tools=tools, + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + # Once the response is received, you can process it + print(response.choices[0].message.tool_calls) + +# Ensure the main async function is run in the event loop +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + +# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}} + +``` + +### OpenAI integration + +TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. + +```python +from openai import OpenAI + +# Initialize the client, pointing it to one of the available models +client = OpenAI( + base_url="http://localhost:3000/v1", + api_key="_", +) + +# NOTE: tools defined above and removed for brevity + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + { + "role": "system", + "content": "Don't make assumptions about what values to plug into functions. 
Ask for clarification if a user request is ambiguous.", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + tools=tools, + tool_choice="auto", # tool selected by model + max_tokens=500, +) + + +called = chat_completion.choices[0].message.tool_calls +print(called) +# { +# "id": 0, +# "type": "function", +# "function": { +# "description": None, +# "name": "tools", +# "parameters": { +# "format": "celsius", +# "location": "San Francisco, CA", +# "num_days": 3, +# }, +# }, +# } +``` diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index e11c7cf9..96cf43ad 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -23,6 +23,8 @@ from text_generation.types import ( Token, BestOfSequence, Grammar, + ChatComplete, + ChatCompletionChunk, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) @@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension): ) -> bool: def convert_data(data): data = json.loads(data) + if isinstance(data, Dict) and "choices" in data: + choices = data["choices"] + if ( + isinstance(choices, List) + and len(choices) >= 1 + and "delta" in choices[0] + ): + return ChatCompletionChunk(**data) + return ChatComplete(**data) if isinstance(data, Dict): return Response(**data) @@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension): ) ) + def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: + return ( + response.choices[0].message.content == other.choices[0].message.content + ) + + def eq_chat_complete_chunk( + response: ChatCompletionChunk, other: ChatCompletionChunk + ) -> bool: + return response.choices[0].delta.content == other.choices[0].delta.content + def eq_response(response: Response, other: Response) -> bool: return response.generated_text == other.generated_text and eq_details( response.details, other.details @@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension): if not isinstance(snapshot_data, List): snapshot_data = [snapshot_data] + if isinstance(serialized_data[0], ChatComplete): + return len(snapshot_data) == len(serialized_data) and all( + [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + + if isinstance(serialized_data[0], ChatCompletionChunk): + return len(snapshot_data) == len(serialized_data) and all( + [ + eq_chat_complete_chunk(r, o) + for r, o in zip(serialized_data, snapshot_data) + ] + ) + return len(snapshot_data) == len(serialized_data) and all( [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)] ) diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json new file mode 100644 index 00000000..3c4b4aea --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. 
Additionally", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1708957015, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 100, + "prompt_tokens": 60, + "total_tokens": 160 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json new file mode 100644 index 00000000..9b9e33c6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14 + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1709079417, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 29, + "prompt_tokens": 316, + "total_tokens": 345 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json new file mode 100644 index 00000000..de32c970 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14 + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1709079492, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 29, + "prompt_tokens": 316, + "total_tokens": 345 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json new file mode 100644 index 00000000..3551e205 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -0,0 +1,37 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY" + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1709079493, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 21, + "prompt_tokens": 187, + "total_tokens": 208 + } +} diff --git 
a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json new file mode 100644 index 00000000..c367cc6f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -0,0 +1,27 @@ +{ + "choices": [ + { + "delta": { + "content": null, + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "", + "name": null + }, + "id": "", + "index": 20, + "type": "function" + } + }, + "finish_reason": "eos_token", + "index": 20, + "logprobs": null + } + ], + "created": 1709087088, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native" +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py new file mode 100644 index 00000000..38570c38 --- /dev/null +++ b/integration-tests/models/test_tools_llama.py @@ -0,0 +1,240 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_tools_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle): + await flash_llama_grammar_tools_handle.health(300) + return flash_llama_grammar_tools_handle.client + + +# tools to be used in the following tests +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + }, +] + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_no_tools( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + assert ( + response.choices[0].message.content + == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. 
You can find the current weather forecast for the area through your local weather service. Additionally" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14, + }, + }, + "id": 0, + "type": "function", + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_auto( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="auto", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14, + }, + }, + "id": 0, + "type": "function", + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_choice( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "tools", + "parameters": {"format": "celsius", "location": "New York, NY"}, + }, + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! 
Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Paris, France?", + }, + ], + stream=True, + ) + + count = 0 + async for response in responses: + count += 1 + + assert count == 20 + assert response == response_snapshot diff --git a/router/src/infer.rs b/router/src/infer.rs index 472b7d66..42405327 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -812,23 +812,27 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), @@ -877,28 +881,33 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "Hi again!".to_string(), + content: Some("Hi again!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), @@ -952,23 +961,27 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), @@ -1006,23 +1019,27 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), diff --git a/router/src/lib.rs b/router/src/lib.rs index 1c06eb8a..d89bacb5 100644 --- a/router/src/lib.rs +++ 
b/router/src/lib.rs @@ -358,10 +358,11 @@ impl ChatCompletion { pub(crate) fn new( model: String, system_fingerprint: String, - output: String, + output: Option, created: u64, details: Details, return_logprobs: bool, + tool_calls: Option, ) -> Self { Self { id: String::new(), @@ -375,6 +376,7 @@ impl ChatCompletion { role: "assistant".into(), content: output, name: None, + tool_calls, }, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), @@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice { pub(crate) struct ChatCompletionDelta { #[schema(example = "user")] pub role: String, + #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "What is Deep Learning?")] - pub content: String, + pub content: Option, + // default to None + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, } +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +pub(crate) struct DeltaToolCall { + pub index: u32, + pub id: String, + pub r#type: String, + pub function: Function, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +pub(crate) struct Function { + pub name: Option, + pub arguments: String, +} + +#[allow(clippy::too_many_arguments)] impl ChatCompletionChunk { pub(crate) fn new( model: String, system_fingerprint: String, - delta: String, + delta: Option, + tool_calls: Option>, created: u64, index: u32, logprobs: Option, @@ -438,6 +460,15 @@ impl ChatCompletionChunk { delta: ChatCompletionDelta { role: "assistant".to_string(), content: delta, + tool_calls: tool_calls.map(|tc| DeltaToolCall { + index, + id: String::new(), + r#type: "function".to_string(), + function: Function { + name: None, + arguments: tc[0].to_string(), + }, + }), }, logprobs, finish_reason, @@ -520,6 +551,125 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, example = 0.95)] pub top_p: Option, + + /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of + /// functions the model may generate JSON inputs for. + #[serde(default)] + #[schema(nullable = true, example = "null")] + pub tools: Option>, + + /// A prompt to be appended before the tools + #[serde(default = "default_tool_prompt")] + #[schema( + nullable = true, + example = "\"Based on the conversation, please choose the most appropriate tool to use: \"" + )] + pub tool_prompt: Option, + + /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. 
+ #[serde(default)] + #[schema(nullable = true, example = "null")] + #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] + pub tool_choice: Option, +} + +fn default_tool_prompt() -> Option { + Some( + "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(), + ) +} +#[derive(Clone, Deserialize, ToSchema, Serialize)] +enum ToolType { + FunctionName(String), + OneOf, +} + +/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) +mod deserialize_tool_choice { + use super::*; + use serde::de; + use serde::Deserializer; + use serde_json::Value; + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) => match s.as_str() { + "none" => Ok(None), + "auto" => Ok(Some(ToolType::OneOf)), + _ => Ok(Some(ToolType::FunctionName(s))), + }, + Value::Object(map) => { + if let Some(content) = map + .get("function") + .and_then(|v| v.get("name")) + .and_then(|v| v.as_str()) + { + Ok(Some(ToolType::FunctionName(content.to_string()))) + } else { + Err(de::Error::custom("function key not found in tool choice")) + } + } + Value::Null => Ok(Some(ToolType::OneOf)), + _ => Err(de::Error::custom("invalid token format")), + } + } +} + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +pub struct Tools { + #[serde(flatten)] + functions_map: FunctionsMap, + properties: Properties, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FunctionsMap { + #[serde(rename = "$functions")] + functions: std::collections::HashMap, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FunctionRef { + #[serde(rename = "$ref")] + ref_path: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Properties { + #[serde(serialize_with = "serialize_function")] + function: Vec, +} + +fn serialize_function(functions: &Vec, serializer: S) -> Result +where + S: serde::Serializer, +{ + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("Function", 1)?; + state.serialize_field("anyOf", functions)?; + state.end() +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] +pub(crate) struct FunctionDefinition { + #[serde(default)] + pub description: Option, + pub name: String, + pub parameters: serde_json::Value, +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub(crate) struct Tool { + // The type of the tool. Currently, only 'function' is supported. + #[schema(example = "function")] + pub r#type: String, + // Grab the tool as generic JSON for debugging purposes. 
+ pub function: FunctionDefinition, } #[derive(Clone, Serialize, Deserialize)] @@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> { add_generation_prompt: bool, } +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] +pub(crate) struct ToolCall { + pub id: u32, + pub r#type: String, + pub function: FunctionDefinition, +} + #[derive(Clone, Deserialize, ToSchema, Serialize)] pub(crate) struct Message { #[schema(example = "user")] pub role: String, + #[serde(skip_serializing_if = "Option::is_none")] #[schema(example = "My name is David and I")] - pub content: String, + pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] pub name: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, } #[derive(Clone, Debug, Deserialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 9fdd66cc..2efa9284 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -10,6 +10,7 @@ use crate::{ HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, }; +use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -22,6 +23,8 @@ use futures::stream::StreamExt; use futures::Stream; use futures::TryStreamExt; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; +use serde_json::Value; +use std::collections::HashMap; use std::convert::Infallible; use std::net::SocketAddr; use std::sync::atomic::AtomicBool; @@ -581,7 +584,7 @@ async fn chat_completions( let seed = req.seed; // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(req.messages) { + let mut inputs = match infer.apply_chat_template(req.messages) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -596,6 +599,62 @@ async fn chat_completions( } }; + let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) { + let tool_prompt = req.tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .ok_or_else(|| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Tool choice not found in tool names".to_string(), + error_type: "Tool not found".to_string(), + }), + ) + })? 
+ .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + (func.name, func.parameters) + }) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .collect(), + }, + }; + + let tools_str = serde_json::to_string(&tools).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "Input validation error".to_string(), + }), + ) + })?; + inputs = format!("{inputs}{tool_prompt}{tools_str}"); + Some(GrammarType::Json(serde_json::json!(tools))) + } else { + None + }; + // build the request passing some parameters let generate_request = GenerateRequest { inputs: inputs.to_string(), @@ -617,7 +676,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: None, - grammar: None, + grammar: tool_grammar.clone(), }, }; @@ -640,11 +699,19 @@ async fn chat_completions( ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) }); + // replace the content with the tool calls if grammar is present + let (content, tool_calls) = if tool_grammar.is_some() { + (None, Some(vec![stream_token.token.text])) + } else { + (Some(stream_token.token.text), None) + }; + event .json_data(ChatCompletionChunk::new( model_id.clone(), system_fingerprint.clone(), - stream_token.token.text, + content, + tool_calls, current_time, stream_token.index, logprobs, @@ -681,14 +748,54 @@ async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); + let (tool_calls, output) = if tool_grammar.is_some() { + // gen_text should be valid json + let gen_text_value: Value = + serde_json::from_str(&generation.generated_text).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "Input validation error".to_string(), + }), + ) + })?; + + let tool_call = Some(ToolCall { + id: 0, + r#type: "function".to_string(), + function: FunctionDefinition { + description: None, + name: "tools".to_string(), + parameters: gen_text_value.get("function").map_or_else( + || { + serde_json::from_str(&generation.generated_text).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "Input validation error".to_string(), + }), + ) + }) + }, + |f| Ok(f.clone()), + )?, + }, + }); + (tool_call, None) + } else { + (None, Some(generation.generated_text)) + }; // build the complete response object with the full text let response = ChatCompletion::new( model_id, system_fingerprint, - generation.generated_text, + output, current_time, generation.details.unwrap(), logprobs, + tool_calls, ); // wrap generation inside a Vec to match api-inference From 910d0a906288bf116ee8bd57d53dc05fb76cd29c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 28 Feb 2024 11:30:37 +0100 Subject: [PATCH 3/5] Fixing x-compute-time. (#1606) # What does this PR do? It was meant to be in seconds float Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? 
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- README.md | 2 ++ docs/source/_toctree.yml | 4 +++ docs/source/conceptual/guidance.md | 1 + docs/source/conceptual/speculation.md | 48 +++++++++++++++++++++++++++ router/src/server.rs | 2 +- 5 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 docs/source/conceptual/guidance.md create mode 100644 docs/source/conceptual/speculation.md diff --git a/README.md b/README.md index 7589a3a6..60fe83cd 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) - Stop sequences - Log probabilities +- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) (~2x lower latency) +- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify the output format to speed up inference and make sure the output is valid according to some specification. - Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output - Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 964a743a..73c88ccc 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -39,4 +39,8 @@ title: Safetensors - local: conceptual/flash_attention title: Flash Attention + - local: conceptual/speculation + title: Speculation (Medusa, ngram) + - local: conceptual/guidance + title: Guidance, JSON, tools (using outlines) title: Conceptual Guides diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md new file mode 100644 index 00000000..8fb46466 --- /dev/null +++ b/docs/source/conceptual/guidance.md @@ -0,0 +1 @@ +## Guidance diff --git a/docs/source/conceptual/speculation.md b/docs/source/conceptual/speculation.md new file mode 100644 index 00000000..071b7b68 --- /dev/null +++ b/docs/source/conceptual/speculation.md @@ -0,0 +1,48 @@ +## Speculation + +Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea. +The idea is to generate tokens *before* the large model actually runs, and only *check* whether those tokens were valid. + +So you are doing *more* computation on your LLM, but if your guesses are correct you produce 1, 2, 3, etc. tokens in a single LLM pass. Since LLMs are usually memory bound (and not compute bound), provided your guesses are correct often enough, this gives 2-3x faster inference (it can be much more for code-oriented tasks, for instance). + +You can read a more [detailed explanation](https://huggingface.co/blog/assisted-generation).
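The propose-then-verify loop behind all of these methods can be sketched in a few lines. This is a minimal, greedy-decoding illustration of the concept, not TGI's implementation: `draft` stands in for any cheap proposal mechanism (Medusa heads, n-gram lookup, a small draft model), and `model` is assumed to be a Hugging Face-style causal LM whose output exposes `.logits`; both names are placeholders.

```python
import torch


def speculative_step(model, draft, input_ids: torch.Tensor, k: int = 3) -> torch.Tensor:
    """One greedy propose-then-verify step; returns the extended sequence."""
    n = input_ids.shape[0]

    # 1. Cheaply propose k candidate tokens (Medusa heads, n-gram lookup, a small model, ...).
    candidates = draft(input_ids, k)  # 1-D LongTensor of length k

    # 2. Run a single forward pass of the large model over prompt + candidates.
    extended = torch.cat([input_ids, candidates])
    logits = model(input_ids=extended.unsqueeze(0)).logits[0]  # (n + k, vocab_size)

    # 3. Accept candidates only while they match the model's own greedy choice
    #    (logits[j] predicts the token at position j + 1).
    accepted = 0
    for i in range(k):
        if candidates[i].item() != logits[n + i - 1].argmax().item():
            break
        accepted += 1

    # 4. The verification pass always yields one extra token "for free", so each
    #    step produces between 1 and k + 1 new tokens for one large-model pass.
    bonus = logits[n + accepted - 1].argmax().unsqueeze(0)
    return torch.cat([input_ids, candidates[:accepted], bonus])
```

The key property is that a wrong guess only costs the verification pass that would have been run anyway, while a correct guess turns one forward pass into several accepted tokens.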
+ +Text Generation Inference supports two main speculative methods: + +- Medusa +- N-gram + + +### Medusa + + +Medusa is a [simple method](https://arxiv.org/abs/2401.10774) to create many tokens in a single pass using fine-tuned LM heads in addition to your existing models. + + +You can check a few existing fine-tunes for popular models: + +- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa) +- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa) +- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa) + + +In order to create your own Medusa heads for your own fine-tune, you should check out the original Medusa repo: [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) + + +In order to use Medusa models in TGI, simply point to a Medusa-enabled model, and everything will load automatically. + + +### N-gram + + +If you don't have a Medusa model, or don't have the resources to fine-tune one, you can try to use `n-gram`. +N-gram speculation works by searching the previous sequence for existing tokens that match the current context and using them as the speculated tokens. + +This is an extremely simple method that works best for code or highly repetitive text. It might not be beneficial if the speculation misses too often. + + +In order to enable n-gram speculation, simply pass + +`--speculate 2` in your flags. + +[Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate) diff --git a/router/src/server.rs b/router/src/server.rs index 2efa9284..9c7046d9 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -242,7 +242,7 @@ async fn generate( headers.insert("x-compute-type", compute_type.parse().unwrap()); headers.insert( "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), + total_time.as_secs_f64().to_string().parse().unwrap(), ); headers.insert( "x-compute-characters", From 97e22369f46fd0a8085856d9798ef2f61946fa6c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 28 Feb 2024 12:05:15 +0100 Subject: [PATCH 4/5] Fixing guidance docs. (#1607) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
--- docs/source/conceptual/guidance.md | 420 ++++++++++++++++++++++++++++- docs/source/guidance.md | 419 ---------------------------- 2 files changed, 419 insertions(+), 420 deletions(-) delete mode 100644 docs/source/guidance.md diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md index 8fb46466..8b9ba094 100644 --- a/docs/source/conceptual/guidance.md +++ b/docs/source/conceptual/guidance.md @@ -1 +1,419 @@ -## Guidance +# Guidance + +Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs. + +These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them! + +## Quick Start + +Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide. + +If you're not up to date, grab the latest version and let's get started! + +## Table of Contents 📚 + +### Grammar and Constraints + +- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision. +- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models. +- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema. +- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses. + +### Tools and Functions + +- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions. +- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions. +- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +## Grammar and Constraints 🛣️ + +### The Grammar Parameter + +In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output. + +Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability. + +```json +curl localhost:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park", + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] + } + } + } +}' +// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"} + +``` + +A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar. 
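Regular-expression grammars follow the same request shape. The sketch below is illustrative rather than an official recipe: it assumes a TGI server on `localhost:3000` and that a regex grammar is passed as `{"type": "regex", "value": ...}`, mirroring how the client example later in this guide uses `GrammarType.Regex`.

```python
import requests

# Hypothetical example: constrain the model to emit a single IPv4 address.
data = {
    "inputs": "What is Google's public DNS server? Answer with just the IP address:",
    "parameters": {
        "max_new_tokens": 20,
        "grammar": {
            "type": "regex",
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    },
}

response = requests.post(
    "http://127.0.0.1:3000/generate",
    headers={"Content-Type": "application/json"},
    json=data,
)
print(response.json()["generated_text"])
```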
+ +> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compliation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster. + +### Constrain with Pydantic + +Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting the a specific response format. + +Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way. + +```python +import requests +from pydantic import BaseModel, conint +from typing import List + +class Animals(BaseModel): + location: str + activity: str + animals_seen: conint(ge=1, le=5) # Constrained integer type + animals: List[str] + +prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park" + +data = { + "inputs": prompt, + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": Animals.schema() + } + } +} + +headers = { + "Content-Type": "application/json", +} + +response = requests.post( + 'http://127.0.0.1:3000/generate', + headers=headers, + json=data +) +print(response.json()) +# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'} + +``` + +### JSON Schema Integration + +If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is simliar to the first example but with programmatic control. + +```python +import requests + +json_schema = { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] +} + +data = { + "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]", + "parameters": { + "max_new_tokens": 200, + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": json_schema + } + } +} + +headers = { + "Content-Type": "application/json", +} + +response = requests.post( + 'http://127.0.0.1:3000/generate', + headers=headers, + json=data +) +print(response.json()) +# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'} + +``` + +### Using the client + +TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter. 
+ +```python +from text_generation import AsyncClient +from text_generation.types import GrammarType + +# NOTE: tools defined above and removed for brevity + +# Define an async function to encapsulate the async operation +async def main(): + client = AsyncClient(base_url="http://localhost:3000") + + # Use 'await' to wait for the async method 'chat' to complete + response = await client.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=1, + grammar={ + "type": GrammarType.Regex, + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, + ) + + # Once the response is received, you can process it + print(response.generated_text) + +# Ensure the main async function is run in the event loop +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + +# 118.8.0.84 + +``` + +## Tools and Functions 🛠️ + +### The Tools Parameter + +In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API. + +Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more. + +Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. + +```json +curl localhost:3000/v1/chat/completions \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is the weather like in New York?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + } + } + } + ], + "tool_choice": "get_current_weather" +}' +// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}} +``` + +
+ Tools used in example below + + ```python + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + } + ] + ``` + +
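If you prefer not to use a client library, the same tool request can be sent with plain `requests`. This is a minimal sketch that assumes a TGI server on `localhost:3000` and reuses the `tools` list defined in the collapsed block above; the response shape matches the curl example output earlier in this section.

```python
import requests

# Reuses the `tools` list defined in the collapsed block above.
payload = {
    "model": "tgi",
    "messages": [
        {"role": "user", "content": "What is the weather like in New York?"},
    ],
    "tools": tools,
    "tool_choice": "get_current_weather",
}

response = requests.post(
    "http://127.0.0.1:3000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json=payload,
)

# The generated arguments live under the tool call's `function.parameters`
# field, as shown in the curl example output above.
tool_calls = response.json()["choices"][0]["message"]["tool_calls"]
print(tool_calls)
```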
+ +### Text Generation Inference Client + +TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions. + +```python +from text_generation import AsyncClient + +# NOTE: tools defined above and removed for brevity + +# Define an async function to encapsulate the async operation +async def main(): + client = AsyncClient(base_url="http://localhost:3000") + + # Use 'await' to wait for the async method 'chat' to complete + response = await client.chat( + max_tokens=100, + seed=1, + tools=tools, + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + # Once the response is received, you can process it + print(response.choices[0].message.tool_calls) + +# Ensure the main async function is run in the event loop +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + +# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}} + +``` + +### OpenAI integration + +TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. + +```python +from openai import OpenAI + +# Initialize the client, pointing it to one of the available models +client = OpenAI( + base_url="http://localhost:3000/v1", + api_key="_", +) + +# NOTE: tools defined above and removed for brevity + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + { + "role": "system", + "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + tools=tools, + tool_choice="auto", # tool selected by model + max_tokens=500, +) + + +called = chat_completion.choices[0].message.tool_calls +print(called) +# { +# "id": 0, +# "type": "function", +# "function": { +# "description": None, +# "name": "tools", +# "parameters": { +# "format": "celsius", +# "location": "San Francisco, CA", +# "num_days": 3, +# }, +# }, +# } +``` diff --git a/docs/source/guidance.md b/docs/source/guidance.md deleted file mode 100644 index 8b9ba094..00000000 --- a/docs/source/guidance.md +++ /dev/null @@ -1,419 +0,0 @@ -# Guidance - -Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs. - -These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and is compatible with OpenAI's client libraries. 
The following guide will walk you through the new features and how to use them! - -## Quick Start - -Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide. - -If you're not up to date, grab the latest version and let's get started! - -## Table of Contents 📚 - -### Grammar and Constraints - -- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision. -- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models. -- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema. -- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses. - -### Tools and Functions - -- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions. -- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions. -- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. - -## Grammar and Constraints 🛣️ - -### The Grammar Parameter - -In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output. - -Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability. - -```json -curl localhost:3000/generate \ - -X POST \ - -H 'Content-Type: application/json' \ - -d '{ - "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park", - "parameters": { - "repetition_penalty": 1.3, - "grammar": { - "type": "json", - "value": { - "properties": { - "location": { - "type": "string" - }, - "activity": { - "type": "string" - }, - "animals_seen": { - "type": "integer", - "minimum": 1, - "maximum": 5 - }, - "animals": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "required": ["location", "activity", "animals_seen", "animals"] - } - } - } -}' -// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"} - -``` - -A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar. - -> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compliation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster. - -### Constrain with Pydantic - -Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting the a specific response format. - -Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way. 
- -```python -import requests -from pydantic import BaseModel, conint -from typing import List - -class Animals(BaseModel): - location: str - activity: str - animals_seen: conint(ge=1, le=5) # Constrained integer type - animals: List[str] - -prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park" - -data = { - "inputs": prompt, - "parameters": { - "repetition_penalty": 1.3, - "grammar": { - "type": "json", - "value": Animals.schema() - } - } -} - -headers = { - "Content-Type": "application/json", -} - -response = requests.post( - 'http://127.0.0.1:3000/generate', - headers=headers, - json=data -) -print(response.json()) -# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'} - -``` - -### JSON Schema Integration - -If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is simliar to the first example but with programmatic control. - -```python -import requests - -json_schema = { - "properties": { - "location": { - "type": "string" - }, - "activity": { - "type": "string" - }, - "animals_seen": { - "type": "integer", - "minimum": 1, - "maximum": 5 - }, - "animals": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "required": ["location", "activity", "animals_seen", "animals"] -} - -data = { - "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]", - "parameters": { - "max_new_tokens": 200, - "repetition_penalty": 1.3, - "grammar": { - "type": "json", - "value": json_schema - } - } -} - -headers = { - "Content-Type": "application/json", -} - -response = requests.post( - 'http://127.0.0.1:3000/generate', - headers=headers, - json=data -) -print(response.json()) -# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'} - -``` - -### Using the client - -TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter. - -```python -from text_generation import AsyncClient -from text_generation.types import GrammarType - -# NOTE: tools defined above and removed for brevity - -# Define an async function to encapsulate the async operation -async def main(): - client = AsyncClient(base_url="http://localhost:3000") - - # Use 'await' to wait for the async method 'chat' to complete - response = await client.generate( - "Whats Googles DNS", - max_new_tokens=10, - decoder_input_details=True, - seed=1, - grammar={ - "type": GrammarType.Regex, - "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", - }, - ) - - # Once the response is received, you can process it - print(response.generated_text) - -# Ensure the main async function is run in the event loop -if __name__ == "__main__": - import asyncio - asyncio.run(main()) - -# 118.8.0.84 - -``` - -## Tools and Functions 🛠️ - -### The Tools Parameter - -In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API. - -Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more. 
- -Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. - -```json -curl localhost:3000/v1/chat/completions \ - -X POST \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "tgi", - "messages": [ - { - "role": "user", - "content": "What is the weather like in New York?" - } - ], - "tools": [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location." - } - }, - "required": ["location", "format"] - } - } - } - ], - "tool_choice": "get_current_weather" -}' -// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}} -``` - -
- Tools used in example below - - ```python - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location.", - }, - }, - "required": ["location", "format"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_n_day_weather_forecast", - "description": "Get an N-day weather forecast", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location.", - }, - "num_days": { - "type": "integer", - "description": "The number of days to forecast", - }, - }, - "required": ["location", "format", "num_days"], - }, - }, - } - ] - ``` - -
- -### Text Generation Inference Client - -TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions. - -```python -from text_generation import AsyncClient - -# NOTE: tools defined above and removed for brevity - -# Define an async function to encapsulate the async operation -async def main(): - client = AsyncClient(base_url="http://localhost:3000") - - # Use 'await' to wait for the async method 'chat' to complete - response = await client.chat( - max_tokens=100, - seed=1, - tools=tools, - presence_penalty=-1.1, - messages=[ - { - "role": "system", - "content": "You're a helpful assistant! Answer the users question best you can.", - }, - { - "role": "user", - "content": "What is the weather like in Brooklyn, New York?", - }, - ], - ) - - # Once the response is received, you can process it - print(response.choices[0].message.tool_calls) - -# Ensure the main async function is run in the event loop -if __name__ == "__main__": - import asyncio - asyncio.run(main()) - -# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}} - -``` - -### OpenAI integration - -TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. - -However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. - -```python -from openai import OpenAI - -# Initialize the client, pointing it to one of the available models -client = OpenAI( - base_url="http://localhost:3000/v1", - api_key="_", -) - -# NOTE: tools defined above and removed for brevity - -chat_completion = client.chat.completions.create( - model="tgi", - messages=[ - { - "role": "system", - "content": "Don't make assumptions about what values to plug into functions. 
Ask for clarification if a user request is ambiguous.", - }, - { - "role": "user", - "content": "What's the weather like the next 3 days in San Francisco, CA?", - }, - ], - tools=tools, - tool_choice="auto", # tool selected by model - max_tokens=500, -) - - -called = chat_completion.choices[0].message.tool_calls -print(called) -# { -# "id": 0, -# "type": "function", -# "function": { -# "description": None, -# "name": "tools", -# "parameters": { -# "format": "celsius", -# "location": "San Francisco, CA", -# "num_days": 3, -# }, -# }, -# } -``` From b40e833493808ed80b0bd6d8a68252fff01d307a Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 28 Feb 2024 12:07:08 +0100 Subject: [PATCH 5/5] feat: starcoder2 (#1605) --- .../test_flash_starcoder2.json | 94 +++ .../test_flash_starcoder2_default_params.json | 394 +++++++++++++ .../test_flash_starcoder2_load.json | 378 ++++++++++++ .../models/test_flash_starcoder2.py | 55 ++ proto/generate.proto | 1 - .../text_generation_server/models/__init__.py | 24 + .../flash_starcoder2_modeling.py | 545 ++++++++++++++++++ .../models/flash_mistral.py | 40 +- .../models/flash_starcoder2.py | 86 +++ 9 files changed, 1601 insertions(+), 16 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json create mode 100644 integration-tests/models/test_flash_starcoder2.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py create mode 100644 server/text_generation_server/models/flash_starcoder2.py diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json new file mode 100644 index 00000000..36a2ff4d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json @@ -0,0 +1,94 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40844727, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27905273, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6118164, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68652344, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4619141, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7993164, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.63134766, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23278809, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2294922, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" +} diff --git 
a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json new file mode 100644 index 00000000..38117272 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -0,0 +1,394 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2284, + "logprob": -0.296875, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.28125, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.79248047, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.61816406, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.0619812, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -0.4091797, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": -0.21655273, + "special": false, + "text": "name" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": -0.034698486, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": -0.20141602, + "special": false, + "text": " +" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 16013, + "logprob": 0.0, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + 
"id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 49, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11505, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 3021, + "logprob": -0.5761719, + "special": false, + "text": " \"," + }, + { + "id": 863, + "logprob": 0.0, + "special": false, + "text": " you" + }, + { + "id": 904, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 615, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 46, + "logprob": 0.0, + "special": false, + "text": ")" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json new file mode 100644 index 00000000..9e82d4be --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json @@ -0,0 +1,378 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + 
"logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + 
"special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + } +] diff --git a/integration-tests/models/test_flash_starcoder2.py b/integration-tests/models/test_flash_starcoder2.py new file mode 100644 index 00000000..ea665b6c --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2.py @@ -0,0 +1,55 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher("bigcode/starcoder2-3b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return flash_starcoder2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/proto/generate.proto b/proto/generate.proto index 0490029f..6351e37f 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -230,7 +230,6 @@ message WarmupRequest { uint32 max_total_tokens = 4; } -/// Empty response message WarmupResponse { /// Maximum number of tokens supported by the model optional uint32 max_supported_total_tokens = 1; diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3208275c..e2edbfa9 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -64,6 +64,7 @@ try: from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_phi import FlashPhi + from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA except ImportError as e: @@ -80,6 
+81,7 @@ if FLASH_ATTENTION: __all__.append(FlashMistral) __all__.append(FlashMixtral) __all__.append(FlashPhi) + __all__.append(FlashStarcoder2) MAMBA_AVAILABLE = True try: @@ -184,6 +186,16 @@ def get_model( trust_remote_code=trust_remote_code, ) + if model_id.startswith("facebook/galactica"): + return GalacticaSharded( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if ( model_type == "gpt_bigcode" or model_type == "gpt2" @@ -401,6 +413,18 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + if model_type == "starcoder2": + sliding_window = config_dict.get("sliding_window", -1) + if ( + (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION + ) or HAS_FLASH_ATTN_V2_CUDA: + return FlashStarcoder2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "opt": return OPTSharded( diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py new file mode 100644 index 00000000..ed77af78 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -0,0 +1,545 @@ +# coding=utf-8 +# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
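+
+# Flash-attention implementation of the Starcoder2 architecture for TGI:
+# Starcoder2Config, grouped-query attention with an optional sliding window,
+# layer_norm/rms_norm normalization variants, default and gated MLPs, and a
+# SpeculativeHead-based LM head (so Medusa/n-gram speculation is supported).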
+ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + SpeculativeHead, + get_linear, + FastRMSNorm, + FastLayerNorm, +) + + +class Starcoder2Config(PretrainedConfig): + model_type = "starcoder2" + + def __init__( + self, + vocab_size=49152, + hidden_size=3072, + intermediate_size=12288, + num_hidden_layers=30, + num_attention_heads=24, + num_key_value_heads=2, + mlp_type="default", + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=4096, + initializer_range=0.018042, + norm_type="layer_norm", + norm_epsilon=1e-5, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + rope_theta=10000.0, + sliding_window=None, + attention_dropout=0.0, + residual_dropout=0.0, + embedding_dropout=0.0, + use_bias: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.use_bias = use_bias + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_type = mlp_type + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_type = norm_type + self.norm_epsilon = norm_epsilon + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.residual_dropout = residual_dropout + self.embedding_dropout = embedding_dropout + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=config.use_bias, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + if config.use_bias: + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, 
dim=0).to(dtype=weights.dtype).to(device=weights.device) + else: + bias = None + + return TensorParallelColumnLinear( + get_linear(weight, bias=bias, quantize=config.quantize) + ) + + +class Starcoder2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=config.use_bias, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class Starcoder2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.c_fc = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.c_fc", + weights=weights, + bias=config.use_bias, + ) + self.c_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=config.use_bias, + ) + + def forward(self, 
+
+
+class Starcoder2GatedMLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        act = config.hidden_act
+        self.act = (
+            ACT2FN[act]
+            if "gelu" not in act
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        )
+        # Fuse gate and up proj
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=config.use_bias,
+        )
+        self.down_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.down_proj",
+            weights=weights,
+            bias=config.use_bias,
+        )
+        self.intermediate_size = (
+            config.intermediate_size // weights.process_group.size()
+        )
+
+    def forward(self, hidden_states):
+        gate_up_states = self.gate_up_proj(hidden_states)
+        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+STARCODER2_NORMALIZATION_CLASSES = {
+    "layer_norm": FastLayerNorm,
+    "rms_norm": FastRMSNorm,
+}
+
+STARCODER2_MLP_CLASSES = {
+    "default": Starcoder2MLP,
+    "gated": Starcoder2GatedMLP,
+}
+
+
+class Starcoder2Layer(nn.Module):
+    def __init__(self, layer_id, config, weights):
+        super().__init__()
+        prefix = f"model.layers.{layer_id}"
+        self.self_attn = Starcoder2Attention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+
+        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
+            prefix=f"{prefix}.mlp", config=config, weights=weights
+        )
+
+        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
+        )
+        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
+            config.norm_type
+        ].load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.norm_epsilon,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+        prefill_cache_indices,
+    ):
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        attn_output = self.self_attn(
+            normed_hidden_states,
+            cos,
+            sin,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            prefill_cache_indices,
+        )
+
+        # faster post attention rms norm
+        normed_attn_res_output, attn_res = self.post_attention_layernorm(
+            attn_output, res
+        )
+
+        mlp_output = self.mlp(normed_attn_res_output)
+
+        return mlp_output, attn_res
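As a small aside, here is what the fused gate/up projection in Starcoder2GatedMLP above computes, as a minimal sketch with made-up shapes (plain torch.nn.Linear stands in for the sharded TensorParallelColumnLinear, and GELU stands in for whatever config.hidden_act selects):

import torch

tokens, hidden, intermediate = 4, 8, 16
hidden_states = torch.randn(tokens, hidden)

# One fused projection producing the gate and up halves, like gate_up_proj above
# (weights concatenated gate-first along dim 0).
gate_up_proj = torch.nn.Linear(hidden, 2 * intermediate, bias=False)

gate_up_states = gate_up_proj(hidden_states).view(-1, 2, intermediate)
gate, up = gate_up_states[:, 0], gate_up_states[:, 1]
# Starcoder2GatedMLP then feeds act(gate) * up into down_proj.
out = torch.nn.functional.gelu(gate) * up
print(out.shape)  # torch.Size([4, 16])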
+
+
+class Starcoder2Model(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+
+        process_group = weights.process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                Starcoder2Layer(
+                    layer_id,
+                    config,
+                    weights,
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+            prefix="model.norm", weights=weights, eps=config.norm_epsilon
+        )
+
+        self.gradient_checkpointing = False
+
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        true_max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin for this forward
+        # Avoid indexing in each layer
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, true_max_s, hidden_states.dtype
+        )
+
+        residual = None
+        for i, layer in enumerate(self.layers):
+            hidden_states, residual = layer(
+                hidden_states,
+                residual,
+                cos,
+                sin,
+                cu_seqlen_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
+                max_s,
+                prefill_cache_indices,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class FlashStarcoder2ForCausalLM(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+
+        self.model = Starcoder2Model(config, weights)
+        try:
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix="lm_head",
+                weights=weights,
+            )
+        except RuntimeError:
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix="model.embed_tokens",
+                weights=weights,
+            )
+
+        self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        true_max_s = max_s
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as they have the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
+        elif self.max_past is not None:
+            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+            # kernel requires the true values
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.lm_head(hidden_states)
+        return logits
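For context on the lm_head_indices slicing above, a minimal sketch (illustrative only; the sequence lengths and hidden size are made up): during prefill all token positions of all sequences are flattened into one batch, but only the last position of each sequence needs logits.

import torch

# Two prefill sequences of lengths 3 and 2 flattened into 5 positions, hidden size 16.
hidden_states = torch.randn(5, 16)
lm_head_indices = torch.tensor([2, 4])   # index of the last token of each sequence
selected = hidden_states[lm_head_indices]
print(selected.shape)  # torch.Size([2, 16]) -> one row per sequence for the LM head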
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index d3c0da9c..fd5c18e0 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from transformers.models.llama import LlamaTokenizerFast
-from typing import Optional, Tuple, Type, List
+from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM
@@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
 MEM_POOL = torch.cuda.graph_pool_handle()


+def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    SLIDING_WINDOW = sliding_window
+    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
+
+
+def get_sliding_windows() -> Tuple[int, int]:
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
+
+
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
 class FlashMistralBatch(FlashCausalLMBatch):
@@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
+        sliding_window, sliding_window_blocks = get_sliding_windows()

         batch_inputs = []
         max_truncation = 0
@@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):

             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if SLIDING_WINDOW_BLOCKS is not None:
-                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
+            if sliding_window_blocks is not None:
+                needed_blocks = min(needed_blocks, sliding_window_blocks)
             blocks += needed_blocks

             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             slot_indices.append(request_slot_indices)

             # Create tensor to slice into the kv tensor in prefill
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
                     dtype=torch.int64,
                 )
@@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = prefill_cache_indices[0]

         cu_seqlen_prefill = torch.tensor(
@@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         prefill_cache_indices = (
-            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+            prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
@@ -298,9 +310,6 @@ class BaseFlashMistral(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -324,8 +333,9 @@ class BaseFlashMistral(FlashCausalLM):

         # Set context windows
         if config.sliding_window is not None:
-            SLIDING_WINDOW = config.sliding_window
-            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )

         torch.distributed.barrier(group=self.process_group)

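A quick worked example of the block cap that set_sliding_window records for the batching logic above (illustrative only; the BLOCK_SIZE value is a stand-in, the real constant comes from text_generation_server.models.cache_manager):

import math

BLOCK_SIZE = 16
sliding_window = 4096
sliding_window_blocks = math.ceil(sliding_window / BLOCK_SIZE)
# With a 4096-token window and 16-token KV-cache blocks, needed_blocks is capped at 256
# per sequence, no matter how long the total generation gets.
print(sliding_window_blocks)  # 256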
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
new file mode 100644
index 00000000..2f6ae757
--- /dev/null
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -0,0 +1,86 @@
+import math
+
+import torch
+
+from typing import Optional
+
+from transformers.models.gpt2 import GPT2TokenizerFast
+
+from text_generation_server.models.cache_manager import BLOCK_SIZE
+from text_generation_server.models.flash_mistral import (
+    BaseFlashMistral,
+    set_sliding_window,
+)
+from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
+    Starcoder2Config,
+    FlashStarcoder2ForCausalLM,
+)
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
+
+
+# Starcoder2 has the same base as Mistral
+class FlashStarcoder2(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            raise NotImplementedError("FlashStarcoder2 is only available on GPU")
+
+        tokenizer = GPT2TokenizerFast.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = Starcoder2Config.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.use_medusa = use_medusa
+
+        # Set context windows
+        if config.sliding_window is not None:
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )
+
+        torch.distributed.barrier(group=self.process_group)
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize in ["gptq", "awq"]:
+            weights._set_gptq_params(model_id, revision)
+
+        model = FlashStarcoder2ForCausalLM(config, weights)
+
+        self.cuda_graphs = {}
+
+        torch.distributed.barrier(group=self.process_group)
+        super(BaseFlashMistral, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            num_layers=len(model.model.layers),
+            num_kv_heads=model.model.num_key_value_heads,
+            head_size=model.model.head_size,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+            sliding_window=config.sliding_window,
+        )
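A note on the super(BaseFlashMistral, self).__init__(...) call above: it starts the MRO lookup after BaseFlashMistral, so FlashStarcoder2 reuses the Mistral flash path's runtime methods while skipping BaseFlashMistral.__init__ (which would load a Mistral model). A minimal sketch of the pattern with stand-in class names (illustrative only):

class Grandparent:
    def __init__(self):
        print("Grandparent.__init__")

class Parent(Grandparent):
    def __init__(self):
        print("Parent.__init__")

class Child(Parent):
    def __init__(self):
        # Skip Parent.__init__ and call Grandparent.__init__ directly,
        # as FlashStarcoder2 does with FlashCausalLM via BaseFlashMistral.
        super(Parent, self).__init__()

Child()  # prints "Grandparent.__init__"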