From 79183d164728f080e1a571b7ff1f58bd0ed840b0 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 10 Jun 2025 17:56:25 +0200 Subject: [PATCH] Bump neuron SDK version (#3260) * chore(neuron): bump version to 0.2.0 * refactor(neuron): use named parameters in inputs helpers This allows to hide the differences between the two backends in terms of input parameters. * refactor(neuron): remove obsolete code paths * fix(neuron): use neuron_config whenever possible * fix(neuron): use new cache import path * fix(neuron): neuron config is not stored in config anymore * fix(nxd): adapt model retrieval to new APIs * fix(generator): emulate greedy in sampling parameters When on-device sampling is enabled, we need to emulate the greedy behaviour using top-k=1, top-p=1, temperature=1. * test(neuron): update models and expectations * feat(neuron): support on-device sampling * fix(neuron): adapt entrypoint * tests(neuron): remove obsolete models * fix(neuron): adjust test expectations for llama on nxd --- Dockerfile.neuron | 21 ++- .../text_generation_server/generator.py | 140 ++++++++++------- .../server/text_generation_server/model.py | 51 +++--- .../text_generation_server}/tgi_env.py | 145 ++++++++++-------- backends/neuron/tests/fixtures/model.py | 74 ++------- .../neuron/tests/server/test_cached_model.py | 42 +++++ .../tests/server/test_continuous_batching.py | 4 +- backends/neuron/tests/server/test_decode.py | 12 +- backends/neuron/tests/server/test_prefill.py | 26 ++-- backends/neuron/tests/test_entry_point.py | 63 ++++++++ backends/neuron/tgi-entrypoint.sh | 2 +- backends/neuron/tgi_entry_point.py | 53 +++++++ .../fixtures/neuron/export_models.py | 18 --- integration-tests/neuron/test_generate.py | 6 +- 14 files changed, 393 insertions(+), 264 deletions(-) rename backends/neuron/{ => server/text_generation_server}/tgi_env.py (63%) mode change 100755 => 100644 create mode 100644 backends/neuron/tests/server/test_cached_model.py create mode 100644 backends/neuron/tests/test_entry_point.py create mode 100755 backends/neuron/tgi_entry_point.py diff --git a/Dockerfile.neuron b/Dockerfile.neuron index d22ca222..6228dbb7 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -5,7 +5,7 @@ RUN mkdir -p /tgi # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments FROM alpine AS optimum-neuron RUN mkdir -p /optimum-neuron -ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz +ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1 # Build cargo components (adapted from TGI original Dockerfile) @@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU # Install neuronx packages RUN apt-get update -y \ && apt-get install -y --no-install-recommends \ - aws-neuronx-dkms=2.19.64.0 \ - aws-neuronx-collectives=2.23.135.0-3e70920f2 \ - aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \ - aws-neuronx-tools=2.20.204.0 \ + aws-neuronx-dkms=2.20.28.0 \ + aws-neuronx-collectives=2.24.59.0-838c7fc8b \ + aws-neuronx-runtime-lib=2.24.53.0-f239092cc \ + aws-neuronx-tools=2.22.61.0 \ libxml2 \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean @@ -125,11 +125,10 @@ RUN pip3 install \ --index-url https://download.pytorch.org/whl/cpu RUN pip3 install \ - neuronx-cc==2.16.372.0 \ - torch-neuronx==2.5.1.2.4.0 \ - transformers-neuronx==0.13.322 \ - 
neuronx-distributed==0.10.1 \ - libneuronxla==2.1.681.0 \ + neuronx-cc==2.17.194.0 \ + torch-neuronx==2.5.1.2.6.0 \ + neuronx-distributed==0.11.0 \ + libneuronxla==2.2.1630.0 \ --extra-index-url=https://pip.repos.neuron.amazonaws.com # Install HuggingFace packages @@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz # Final image FROM neuron -COPY backends/neuron/tgi_env.py /tgi_env.py +COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py index b3887e14..10a4d7a2 100644 --- a/backends/neuron/server/text_generation_server/generator.py +++ b/backends/neuron/server/text_generation_server/generator.py @@ -7,7 +7,8 @@ from typing import List, Optional, Tuple import torch from loguru import logger -from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from optimum.neuron.configuration_utils import NeuronConfig from transformers.generation import GenerationConfig from optimum.neuron import NeuronModelForCausalLM @@ -175,6 +176,12 @@ class Slot: self._generation_config.top_p = request.parameters.top_p if request.parameters.typical_p != 0: self._generation_config.typical_p = request.parameters.typical_p + else: + # Set the sampling parameters to emulate greedy decoding when using on-device sampling + self._generation_config.temperature = 1.0 + self._generation_config.top_k = 1 + self._generation_config.top_p = 1.0 + self._generation_config.typical_p = 1.0 if request.parameters.repetition_penalty != 0: self._generation_config.repetition_penalty = ( request.parameters.repetition_penalty @@ -211,19 +218,11 @@ class Slot: self._mask = attention_mask.clone() self._selector = selector - def pause(self, reset_on_pause: bool): + def pause(self): """Mark the current slot as paused for generation. Note that the KV cache for this slot will still be filled. """ - if reset_on_pause: - # Drop the last token as it will be added back when resuming the slot - self._generated_tokens -= 1 - # Since generated tokens are now part of the prefill, we need to reevaluate - # max_new_tokens for the next generation - self._generation_config.max_new_tokens = ( - self._max_new_tokens - self._generated_tokens - ) self._state = Slot.State.PAUSE def resume(self): @@ -340,16 +339,27 @@ class NeuronGenerator(Generator): tokenizer: PreTrainedTokenizerBase, ): self.model = model - self.rebuild_cache_on_prefill = not self.model.continuous_batching + if not isinstance(self.model, NeuronModelForCausalLM): + raise ValueError("The model must be a NeuronModelForCausalLM.") + if not model.neuron_config.continuous_batching: + raise ValueError( + "The neuron model must be compiled with continuous_batching=True." 
+ ) # Specify padding and truncation options for decoder-only architecture tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" tokenizer.truncation_side = "left" self.tokenizer = tokenizer self.special_tokens = self.tokenizer.all_special_ids - self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)] + self.slots = [ + Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size) + ] self.batch_id = 0 + @property + def on_device_sampling(self) -> bool: + return getattr(self.model.neuron_config, "on_device_sampling", False) + @property def info(self) -> InfoResponse: """Returns the expected InfoResponse.""" @@ -371,14 +381,22 @@ class NeuronGenerator(Generator): The maximum number of tokens the model supports. """ # Just check that the warmup request parameters match the model capacity - batch_size = self.model.batch_size + batch_size = self.model.neuron_config.batch_size if len(batch.requests) > batch_size: raise ValueError( - f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." + f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." ) self.prefill(batch) self.clear() - return self.model.batch_size * self.model.max_length + return ( + self.model.neuron_config.batch_size + * self.model.neuron_config.sequence_length + ) + + def max_prefill_length(self) -> int: + if hasattr(self.model.neuron_config, "max_context_length"): + return self.model.neuron_config.max_context_length + return self.model.neuron_config.sequence_length def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: """Prefill new requests. @@ -398,7 +416,7 @@ class NeuronGenerator(Generator): if len(empty_slots) < len(batch.requests): raise ValueError( f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots." - f" Please align max_batch_size with the static batch size: {self.model.batch_size}." + f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}." ) # Assign each request to an empty slot logger.debug( @@ -412,14 +430,8 @@ class NeuronGenerator(Generator): logger.debug( f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}" ) - if self.rebuild_cache_on_prefill: - # We will clear pending slots and prefill all slots - prefill_slots = self.slots - seq_ids = None - else: - # We only need to pass inputs for the new requests - prefill_slots = new_slots - seq_ids = torch.tensor([slot.id for slot in prefill_slots]) + prefill_slots = new_slots + seq_ids = torch.tensor([slot.id for slot in prefill_slots]) # Reconstruct the full inputs (without padding) as seen by the model. 
# This comprises: # - the inputs for new requests, @@ -431,8 +443,10 @@ class NeuronGenerator(Generator): inputs.append(slot.cached_text) # Apply truncation, making sure we fit into static dimensions if slot.truncate == 0: - max_length = self.model.max_length - elif slot.truncate > max_length and slot.truncate < self.model.max_length: + max_length = self.max_prefill_length() + elif ( + slot.truncate > max_length and slot.truncate < self.max_prefill_length() + ): max_length = slot.truncate # Tokenize with padding and truncation padded_inputs = self.tokenizer( @@ -444,13 +458,12 @@ class NeuronGenerator(Generator): ) input_ids = padded_inputs.input_ids attention_mask = padded_inputs.attention_mask + sampling_params = ( + torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None + ) # Pause previously active slots during generation - next_tokens = [] for slot in active_slots: - slot.pause(reset_on_pause=self.rebuild_cache_on_prefill) - if self.rebuild_cache_on_prefill: - # The slot will be reset, so we need to store its next token - next_tokens.append(slot.next_token) + slot.pause() # Each slot must be reset with the padded inputs and masks for i, slot in enumerate(prefill_slots): if slot.state != slot.state.EMPTY: @@ -464,29 +477,33 @@ class NeuronGenerator(Generator): slot_input_ids, slot.generation_config, self.model, - self.model.max_length, + self.model.neuron_config.sequence_length, tokenizer=self.tokenizer, seed=slot.seed, ) slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64) slot_attention_mask = attention_mask[i] slot.reset(slot_input_ids, slot_attention_mask, selector) + if sampling_params is not None: + sampling_params[i, 0] = slot.generation_config.top_k + sampling_params[i, 1] = slot.generation_config.top_p + sampling_params[i, 2] = slot.generation_config.temperature # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored, # as they have already been generated and sent back in the last decode. 
model_inputs = self.model.prepare_inputs_for_prefill( - input_ids, attention_mask, seq_ids + input_ids, + attention_mask=attention_mask, + seq_ids=seq_ids, + sampling_params=sampling_params, ) - logits = self.model(**model_inputs)[0] + tokens_or_logits = self.model(**model_inputs)[0] generation, next_batch = self._generate_token( - prefill_slots, self.batch_id, logits, input_ids + prefill_slots, self.batch_id, tokens_or_logits, input_ids ) self.batch_id += 1 # Reactivate previously active slots for the next decode for i, slot in enumerate(active_slots): slot.resume() - if self.rebuild_cache_on_prefill: - # Append back the next token - slot.append(next_tokens[i]) logger.debug("Model ready for decoding") if next_batch is not None: logger.debug( @@ -530,12 +547,8 @@ class NeuronGenerator(Generator): raise ValueError( "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)" ) - if self.model.continuous_batching: - decode_slots = active_slots - seq_ids = torch.tensor([slot.id for slot in decode_slots]) - else: - decode_slots = self.slots - seq_ids = None + decode_slots = active_slots + seq_ids = torch.tensor([slot.id for slot in decode_slots]) # Reconstruct input_ids and attention_mask from decode slots n_slots = len(decode_slots) input_ids = torch.full( @@ -545,22 +558,32 @@ class NeuronGenerator(Generator): for slot in decode_slots: max_length = max(max_length, slot.attention_mask.size(-1)) attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64) + sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None for i, slot in enumerate(decode_slots): if slot.state != Slot.State.EMPTY: # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached) input_ids[i, 0] = slot.next_token attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask + if sampling_params is not None: + sampling_params[i, 0] = slot.generation_config.top_k + sampling_params[i, 1] = slot.generation_config.top_p + sampling_params[i, 2] = slot.generation_config.temperature model_inputs = self.model.prepare_inputs_for_decode( - input_ids, attention_mask, seq_ids + input_ids, + attention_mask=attention_mask, + seq_ids=seq_ids, + sampling_params=sampling_params, + ) + tokens_or_logits = self.model(**model_inputs)[0] + return self._generate_token( + decode_slots, next_batch_id, tokens_or_logits, input_ids ) - logits = self.model(**model_inputs)[0] - return self._generate_token(decode_slots, next_batch_id, logits, input_ids) def _generate_token( self, slots: List[Slot], next_batch_id: int, - logits: torch.Tensor, + tokens_or_logits: torch.Tensor, input_ids: torch.LongTensor, ) -> Tuple[List[Generation], CachedBatch]: generations = [] @@ -569,9 +592,12 @@ class NeuronGenerator(Generator): if slot.state != Slot.State.READY: continue request_id = slot.request_id - next_token_logits = logits[i : i + 1, -1, :] slot_input_ids = input_ids[i : i + 1, :] - next_token = slot.select(slot_input_ids, next_token_logits) + if self.on_device_sampling: + next_token = tokens_or_logits[i] + else: + next_token_logits = tokens_or_logits[i : i + 1, -1, :] + next_token = slot.select(slot_input_ids, next_token_logits) next_token_text = slot.append(next_token) generated_text = None finish_reason = None @@ -622,7 +648,7 @@ class NeuronGenerator(Generator): def _cached_batch(self, batch_id: int, request_ids: List): size = len(request_ids) - max_tokens = size * self.model.max_length + max_tokens = size * 
self.model.neuron_config.sequence_length return CachedBatch( id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens ) @@ -671,8 +697,16 @@ class NeuronGenerator(Generator): Returns: A NeuronGenerator. """ - config = AutoConfig.from_pretrained(model_id) - neuron_config = getattr(config, "neuron", None) + try: + neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_id, + revision, + e, + ) + neuron_config = None start = time.time() if neuron_config is None: export_kwargs = get_export_kwargs_from_env() diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py index 2151a921..d281b175 100644 --- a/backends/neuron/server/text_generation_server/model.py +++ b/backends/neuron/server/text_generation_server/model.py @@ -6,10 +6,12 @@ from typing import Optional from huggingface_hub import snapshot_download from huggingface_hub.constants import HF_HUB_CACHE from loguru import logger -from transformers import AutoConfig -from optimum.neuron import NeuronModelForCausalLM -from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.cache import get_hub_cached_entries +from optimum.neuron.configuration_utils import NeuronConfig + + +from .tgi_env import check_env_and_neuron_config_compatibility def get_export_kwargs_from_env(): @@ -24,7 +26,6 @@ def get_export_kwargs_from_env(): num_cores = int(num_cores) auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None) return { - "task": "text-generation", "batch_size": batch_size, "sequence_length": sequence_length, "num_cores": num_cores, @@ -32,20 +33,15 @@ def get_export_kwargs_from_env(): } -def is_cached(model_id, neuron_config): +def is_cached(model_id): # Look for cached entries for the specified model in_cache = False - entries = get_hub_cached_entries(model_id, "inference") + entries = get_hub_cached_entries(model_id) # Look for compatible entries for entry in entries: - compatible = True - for key, value in neuron_config.items(): - # Only weights can be different - if key in ["checkpoint_id", "checkpoint_revision"]: - continue - if entry[key] != value: - compatible = False - if compatible: + if check_env_and_neuron_config_compatibility( + entry, check_compiler_version=True + ): in_cache = True break return in_cache @@ -87,8 +83,16 @@ def fetch_model( revision = None # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) # Note that the model may already be present in the cache. 
- config = AutoConfig.from_pretrained(model_id, revision=revision) - neuron_config = getattr(config, "neuron", None) + try: + neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_id, + revision, + e, + ) + neuron_config = None if neuron_config is not None: if os.path.isdir(model_id): return model_id @@ -99,16 +103,11 @@ def fetch_model( log_cache_size() return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") # Model needs to be exported: look for compatible cached entries on the hub - export_kwargs = get_export_kwargs_from_env() - export_config = NeuronModelForCausalLM.get_export_config( - model_id, config, revision=revision, **export_kwargs - ) - neuron_config = export_config.neuron - if not is_cached(model_id, neuron_config): + if not is_cached(model_id): hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache" neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi" error_msg = ( - f"No cached version found for {model_id} with {neuron_config}." + f"No cached version found for {model_id} with {get_export_kwargs_from_env()}." f"You can start a discussion to request it on {hub_cache_url}" f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}" ) @@ -121,8 +120,10 @@ def fetch_model( # Prefetch weights, tokenizer and generation config so that they are in cache log_cache_size() start = time.time() - snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") + snapshot_path = snapshot_download( + model_id, revision=revision, ignore_patterns="*.bin" + ) end = time.time() logger.info(f"Model weights fetched in {end - start:.2f} s.") log_cache_size() - return model_id + return snapshot_path diff --git a/backends/neuron/tgi_env.py b/backends/neuron/server/text_generation_server/tgi_env.py old mode 100755 new mode 100644 similarity index 63% rename from backends/neuron/tgi_env.py rename to backends/neuron/server/text_generation_server/tgi_env.py index a7042130..ee97f180 --- a/backends/neuron/tgi_env.py +++ b/backends/neuron/server/text_generation_server/tgi_env.py @@ -6,12 +6,11 @@ import os import sys from typing import Any, Dict, List, Optional -from huggingface_hub import constants -from transformers import AutoConfig - from optimum.neuron.modeling_decoder import get_available_cores -from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.cache import get_hub_cached_entries +from optimum.neuron.configuration_utils import NeuronConfig from optimum.neuron.utils.version_utils import get_neuronxcc_version +from optimum.neuron.utils import map_torch_dtype logger = logging.getLogger(__name__) @@ -24,15 +23,9 @@ tgi_router_env_vars = [ ] tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"] -env_config_peering = [ - ("MAX_BATCH_SIZE", "batch_size"), - ("MAX_TOTAL_TOKENS", "sequence_length"), - ("HF_AUTO_CAST_TYPE", "auto_cast_type"), - ("HF_NUM_CORES", "num_cores"), -] # By the end of this script all env var should be specified properly -env_vars = tgi_server_env_vars + tgi_router_env_vars +tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars available_cores = get_available_cores() neuronxcc_version = get_neuronxcc_version() @@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace: def neuron_config_to_env(neuron_config): + if 
isinstance(neuron_config, NeuronConfig): + neuron_config = neuron_config.to_dict() with open(os.environ["ENV_FILEPATH"], "w") as f: - for env_var, config_key in env_config_peering: - f.write("export {}={}\n".format(env_var, neuron_config[config_key])) + f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"])) + f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"])) + f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"])) + config_key = ( + "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype" + ) + auto_cast_type = neuron_config[config_key] + f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type)) max_input_tokens = os.getenv("MAX_INPUT_TOKENS") if not max_input_tokens: max_input_tokens = int(neuron_config["sequence_length"]) // 2 @@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config): def sort_neuron_configs(dictionary): - return -dictionary["num_cores"], -dictionary["batch_size"] + return -dictionary["tp_degree"], -dictionary["batch_size"] def lookup_compatible_cached_model( @@ -119,7 +120,7 @@ def lookup_compatible_cached_model( ) -> Optional[Dict[str, Any]]: # Reuse the same mechanic as the one in use to configure the tgi server part # The only difference here is that we stay as flexible as possible on the compatibility part - entries = get_hub_cached_entries(model_id, "inference") + entries = get_hub_cached_entries(model_id) logger.debug( "Found %d cached entries for model %s, revision %s", @@ -155,15 +156,15 @@ def lookup_compatible_cached_model( def check_env_and_neuron_config_compatibility( - neuron_config: Dict[str, Any], check_compiler_version: bool + neuron_config_dict: Dict[str, Any], check_compiler_version: bool ) -> bool: logger.debug( "Checking the provided neuron config %s is compatible with the local setup and provided environment", - neuron_config, + neuron_config_dict, ) # Local setup compat checks - if neuron_config["num_cores"] > available_cores: + if neuron_config_dict["tp_degree"] > available_cores: logger.debug( "Not enough neuron cores available to run the provided neuron config" ) @@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility( if ( check_compiler_version - and neuron_config["compiler_version"] != neuronxcc_version + and neuron_config_dict["neuronxcc_version"] != neuronxcc_version ): logger.debug( "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)", neuronxcc_version, - neuron_config["compiler_version"], + neuron_config_dict["neuronxcc_version"], ) return False - for env_var, config_key in env_config_peering: - neuron_config_value = str(neuron_config[config_key]) - env_value = os.getenv(env_var, str(neuron_config_value)) + batch_size = os.getenv("MAX_BATCH_SIZE", None) + if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size): + logger.debug( + "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)", + os.getenv("MAX_BATCH_SIZE"), + neuron_config_dict["batch_size"], + ) + return False + max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None) + if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int( + max_total_tokens + ): + logger.debug( + "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)", + max_total_tokens, + neuron_config_dict["sequence_length"], + ) + return False + num_cores = os.getenv("HF_NUM_CORES", None) + if num_cores is not None and neuron_config_dict["tp_degree"] < 
int(num_cores): + logger.debug( + "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)", + num_cores, + neuron_config_dict["tp_degree"], + ) + return False + auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None) + if auto_cast_type is not None: + config_key = ( + "auto_cast_type" + if "auto_cast_type" in neuron_config_dict + else "torch_dtype" + ) + neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key])) + env_value = map_torch_dtype(auto_cast_type) if env_value != neuron_config_value: logger.debug( - "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)", - env_var, - config_key, + "The provided auto cast type and the neuron config param differ (%s != %s)", env_value, neuron_config_value, ) return False - max_input_tokens = int( os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)) ) if max_input_tokens > 0: - sequence_length = neuron_config["sequence_length"] + if hasattr(neuron_config_dict, "max_context_length"): + sequence_length = neuron_config_dict["max_context_length"] + else: + sequence_length = neuron_config_dict["sequence_length"] if max_input_tokens >= sequence_length: logger.debug( "Specified max input tokens is not compatible with config sequence length ( %s >= %s)", @@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility( def get_env_dict() -> Dict[str, str]: d = {} - for k in env_vars: + for k in tgi_env_vars: d[k] = os.getenv(k) return d -def main(): - """ - This script determines proper default TGI env variables for the neuron precompiled models to - work properly - :return: - """ - args = parse_cmdline_and_set_env() - - for env_var in env_vars: - if not os.getenv(env_var): - break - else: - logger.info( - "All env vars %s already set, skipping, user know what they are doing", - env_vars, +def get_neuron_config_for_model( + model_name_or_path: str, revision: Optional[str] = None +) -> NeuronConfig: + try: + neuron_config = NeuronConfig.from_pretrained( + model_name_or_path, revision=revision ) - sys.exit(0) - - cache_dir = constants.HF_HUB_CACHE - - logger.info("Cache dir %s, model %s", cache_dir, args.model_id) - - config = AutoConfig.from_pretrained(args.model_id, revision=args.revision) - neuron_config = getattr(config, "neuron", None) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_name_or_path, + revision, + e, + ) + neuron_config = None if neuron_config is not None: compatible = check_env_and_neuron_config_compatibility( - neuron_config, check_compiler_version=False + neuron_config.to_dict(), check_compiler_version=False ) if not compatible: env_dict = get_env_dict() @@ -252,17 +276,6 @@ def main(): logger.error(msg) raise Exception(msg) else: - neuron_config = lookup_compatible_cached_model(args.model_id, args.revision) + neuron_config = lookup_compatible_cached_model(model_name_or_path, revision) - if not neuron_config: - msg = ( - "No compatible neuron config found. 
Provided env {}, available cores {}, neuronxcc version {}" - ).format(get_env_dict(), available_cores, neuronxcc_version) - logger.error(msg) - raise Exception(msg) - - neuron_config_to_env(neuron_config) - - -if __name__ == "__main__": - main() + return neuron_config diff --git a/backends/neuron/tests/fixtures/model.py b/backends/neuron/tests/fixtures/model.py index 4b6a1375..ad41fd10 100644 --- a/backends/neuron/tests/fixtures/model.py +++ b/backends/neuron/tests/fixtures/model.py @@ -4,14 +4,12 @@ import subprocess import sys from tempfile import TemporaryDirectory -import huggingface_hub +import os import pytest from transformers import AutoTokenizer -from optimum.neuron import NeuronModelForCausalLM -from optimum.neuron.utils import synchronize_hub_cache -from optimum.neuron.version import __sdk_version__ as sdk_version -from optimum.neuron.version import __version__ as version + +from optimum.neuron.cache import synchronize_hub_cache logging.basicConfig( @@ -21,30 +19,14 @@ logging.basicConfig( ) logger = logging.getLogger(__file__) + OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache" + # All model configurations below will be added to the neuron_model_config fixture MODEL_CONFIGURATIONS = { - "gpt2": { - "model_id": "gpt2", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 1024, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, "llama": { - "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 2048, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, - "mistral": { - "model_id": "optimum/mistral-1.1b-testing", + "model_id": "unsloth/Llama-3.2-1B-Instruct", "export_kwargs": { "batch_size": 4, "sequence_length": 4096, @@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = { "batch_size": 4, "sequence_length": 4096, "num_cores": 2, - "auto_cast_type": "fp16", + "auto_cast_type": "bf16", }, }, "granite": { @@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = { } -def get_hub_neuron_model_id(config_name: str): - return ( - f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}" - ) - - def export_model(model_id, export_kwargs, neuron_model_path): export_command = [ "optimum-cli", @@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path): def neuron_model_config(request): """Expose a pre-trained neuron model - The fixture first makes sure the following model artifacts are present on the hub: - - exported neuron model under optimum-internal-testing/neuron-testing--, - - cached artifacts under optimum-internal-testing/neuron-testing-cache. - If not, it will export the model and push it to the hub. - - It then fetches the model locally and return a dictionary containing: + The fixture exports a model locally and returns a dictionary containing: - a configuration name, - the original model id, - the export parameters, - - the neuron model id, - the neuron model local path. For each exposed model, the local directory is maintained for the duration of the test session and cleaned up afterwards. - The hub model artifacts are never cleaned up and persist accross sessions. - They must be cleaned up manually when the optimum-neuron version changes. 
""" config_name = request.param model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param]) model_id = model_config["model_id"] export_kwargs = model_config["export_kwargs"] - neuron_model_id = get_hub_neuron_model_id(config_name) with TemporaryDirectory() as neuron_model_path: - hub = huggingface_hub.HfApi() - if hub.repo_exists(neuron_model_id): - logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub") - hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path) - else: - export_model(model_id, export_kwargs, neuron_model_path) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.save_pretrained(neuron_model_path) - del tokenizer - # Create the test model on the hub - hub.create_repo(neuron_model_id, private=True) - hub.upload_folder( - folder_path=neuron_model_path, - repo_id=neuron_model_id, - ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"], - ) - # Make sure it is cached - synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID) + export_model(model_id, export_kwargs, neuron_model_path) + synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(neuron_model_path) + del tokenizer # Add dynamic parameters to the model configuration model_config["neuron_model_path"] = neuron_model_path - model_config["neuron_model_id"] = neuron_model_id # Also add model configuration name to allow tests to adapt their expectations model_config["name"] = config_name # Yield instead of returning to keep a reference to the temporary directory. # It will go out of scope and be released only once all tests needing the fixture # have been completed. logger.info(f"{config_name} ready for testing ...") + os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID yield model_config logger.info(f"Done with {config_name}") diff --git a/backends/neuron/tests/server/test_cached_model.py b/backends/neuron/tests/server/test_cached_model.py new file mode 100644 index 00000000..73622578 --- /dev/null +++ b/backends/neuron/tests/server/test_cached_model.py @@ -0,0 +1,42 @@ +import os +import pytest + +from text_generation_server.generator import NeuronGenerator +from text_generation_server.model import fetch_model, is_cached + + +@pytest.fixture(scope="module") +def cached_model_id(neuron_model_config) -> str: + """ + Fixture to provide a cached model ID for testing. + This assumes the model is already cached in the local environment. 
+ """ + export_kwargs = neuron_model_config["export_kwargs"] + os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"]) + os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"]) + os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"] + os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"]) + yield neuron_model_config["model_id"] + os.environ.pop("MAX_BATCH_SIZE", None) + os.environ.pop("MAX_TOTAL_TOKENS", None) + os.environ.pop("HF_AUTO_CAST_TYPE", None) + os.environ.pop("HF_NUM_CORES", None) + + +def test_model_is_cached(cached_model_id): + assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached" + + +def test_fetch_cached_model(cached_model_id: str): + model_path = fetch_model(cached_model_id) + assert os.path.exists( + model_path + ), f"Model {cached_model_id} was not fetched successfully" + assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory" + + +def test_generator_from_cached_model(cached_model_id: str): + generator = NeuronGenerator.from_pretrained(model_id=cached_model_id) + assert generator is not None, "Generator could not be created from cached model" + assert generator.model is not None, "Generator model is not initialized" + assert generator.tokenizer is not None, "Generator tokenizer is not initialized" diff --git a/backends/neuron/tests/server/test_continuous_batching.py b/backends/neuron/tests/server/test_continuous_batching.py index 48bb70cc..3d9ab509 100644 --- a/backends/neuron/tests/server/test_continuous_batching.py +++ b/backends/neuron/tests/server/test_continuous_batching.py @@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config): """ neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) - assert generator.model.batch_size > 1 + assert generator.model.neuron_config.batch_size > 1 input_text = "Once upon a time" max_new_tokens = 20 # Prefill a single request, remembering the generated token tokens = {0: [], 1: []} request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens) - max_length = generator.model.max_length + max_length = generator.model.neuron_config.sequence_length batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length) generations, next_batch = generator.prefill(batch) assert next_batch.size == 1 diff --git a/backends/neuron/tests/server/test_decode.py b/backends/neuron/tests/server/test_decode.py index 9db5e3ab..b864e3ec 100644 --- a/backends/neuron/tests/server/test_decode.py +++ b/backends/neuron/tests/server/test_decode.py @@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample): request = create_request( id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample ) - max_length = generator.model.max_length + max_length = generator.model.neuron_config.sequence_length batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length) generations, next_batch = generator.prefill(batch) # We already generated one token: call decode max_new_tokens - 1 times @@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample): assert output.finish_reason == 0 if do_sample: expected_text = { - "gpt2": " The sun was set", - "llama": "George Orwell, 1984", - "mistral": "The sky was", - "qwen2": " A young woman with", + "llama": " I sat alone in the café", + "qwen2": " The air was so still", "granite": "1984, George Orwell", }[config_name] assert expected_text in output.text else: print(output.text) expected_text = 
{ - "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going', - "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story", - "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.", + "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility", "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a", "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198", }[config_name] diff --git a/backends/neuron/tests/server/test_prefill.py b/backends/neuron/tests/server/test_prefill.py index c0155b1a..c9ecd1c8 100644 --- a/backends/neuron/tests/server/test_prefill.py +++ b/backends/neuron/tests/server/test_prefill.py @@ -9,7 +9,7 @@ def test_prefill(neuron_model_config): neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) max_batch_size = 4 - assert generator.model.batch_size >= max_batch_size + assert generator.model.neuron_config.batch_size >= max_batch_size for num_requests in [1, max_batch_size]: for do_sample in [True, False]: mode = "sample" if do_sample else "greedy" @@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample): ) ) # Let's be pessimistic when estimating max_tokens - max_length = generator.model.max_length + max_length = generator.max_prefill_length() batch = Batch( id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length ) @@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample): assert len(generations) == batch_size if do_sample: expectations = { - "gpt2": [383, " The"], - "llama": [10058, " George"], - "mistral": [450, " The"], - "qwen2": [362, " A"], + "llama": [358, " I"], + "qwen2": [576, " The"], "granite": [308, " ("], }[config_name] else: expectations = { - "gpt2": [198, "\n"], - "llama": [10058, " George"], - "mistral": [13, "\n"], + "llama": [578, " The"], "qwen2": [358, " I"], "granite": [203, "\n"], }[config_name] @@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config): config_name = neuron_model_config["name"] neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) - batch_size = generator.model.batch_size + batch_size = generator.model.neuron_config.batch_size # We apply truncation to all requests but the first one truncate = [ None, @@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config): requests = [] for i in range(batch_size): requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i])) - max_length = generator.model.max_length + max_length = generator.max_prefill_length() batch = Batch( id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length ) @@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config): # Even if the input text is identical for all requests, the first generated token might # be different because of the truncation expectations = { - "gpt2": [" He", " He", "\n", " He"], - "llama": [" —", " The", " He", " He"], - "mistral": [" He", "\n", " He", " He"], + "llama": [" He", "iens", "\x08", " He"], "qwen2": [" He", " The", " He", " He"], "granite": ["\n", "\n", " I", " He"], }[config_name] for i, g in enumerate(generations): tokens = g.tokens - assert tokens.texts[0] == expectations[i] + assert ( + tokens.texts[0] == expectations[i] + ), f"Request {i} expected 
[{expectations[i]}], got [{tokens.texts[0]}]" diff --git a/backends/neuron/tests/test_entry_point.py b/backends/neuron/tests/test_entry_point.py new file mode 100644 index 00000000..d4ddf338 --- /dev/null +++ b/backends/neuron/tests/test_entry_point.py @@ -0,0 +1,63 @@ +import os +import pytest +from tempfile import TemporaryDirectory + +from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig +from optimum.neuron.utils import map_torch_dtype + +from text_generation_server.tgi_env import ( + get_neuron_config_for_model, + lookup_compatible_cached_model, + neuron_config_to_env, +) + + +def test_get_neuron_config_for_model(neuron_model_config): + neuron_model_path = neuron_model_config["neuron_model_path"] + export_kwargs = neuron_model_config["export_kwargs"] + os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"]) + os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"]) + os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"] + os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"]) + neuron_config = get_neuron_config_for_model(neuron_model_path) + assert neuron_config is not None + assert neuron_config.batch_size == export_kwargs["batch_size"] + assert neuron_config.sequence_length == export_kwargs["sequence_length"] + assert neuron_config.tp_degree == export_kwargs["num_cores"] + if isinstance(neuron_config, NxDNeuronConfig): + assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype( + export_kwargs["auto_cast_type"] + ) + else: + assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype( + export_kwargs["auto_cast_type"] + ) + + +@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"]) +def test_lookup_compatible_cached_model(model_id: str): + neuron_config = lookup_compatible_cached_model(model_id, None) + assert neuron_config is not None + + +def test_neuron_config_to_env(neuron_model_config) -> None: + neuron_model_path = neuron_model_config["neuron_model_path"] + neuron_config = get_neuron_config_for_model(neuron_model_path) + with TemporaryDirectory() as temp_dir: + os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh") + neuron_config_to_env(neuron_config) + with open(os.environ["ENV_FILEPATH"], "r") as env_file: + env_content = env_file.read() + assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content + assert ( + f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}" + in env_content + ) + assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content + if hasattr(neuron_config, "torch_dtype"): + auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split( + "." 
+ )[-1] + else: + auto_cast_type = neuron_config.auto_cast_type + assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh index b959a795..7965d1da 100755 --- a/backends/neuron/tgi-entrypoint.sh +++ b/backends/neuron/tgi-entrypoint.sh @@ -9,7 +9,7 @@ touch $ENV_FILEPATH SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -${SCRIPT_DIR}/tgi_env.py $@ +${SCRIPT_DIR}/tgi_entry_point.py $@ source $ENV_FILEPATH diff --git a/backends/neuron/tgi_entry_point.py b/backends/neuron/tgi_entry_point.py new file mode 100755 index 00000000..7e81d0e4 --- /dev/null +++ b/backends/neuron/tgi_entry_point.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import logging +import os +import sys + + +from text_generation_server.tgi_env import ( + available_cores, + get_env_dict, + get_neuron_config_for_model, + neuron_config_to_env, + neuronxcc_version, + parse_cmdline_and_set_env, + tgi_env_vars, +) + + +logger = logging.getLogger(__name__) + + +def main(): + """ + This script determines proper default TGI env variables for the neuron precompiled models to + work properly + :return: + """ + args = parse_cmdline_and_set_env() + + for env_var in tgi_env_vars: + if not os.getenv(env_var): + break + else: + logger.info( + "All env vars %s already set, skipping, user know what they are doing", + tgi_env_vars, + ) + sys.exit(0) + + neuron_config = get_neuron_config_for_model(args.model_id, args.revision) + + if not neuron_config: + msg = ( + "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}" + ).format(get_env_dict(), available_cores, neuronxcc_version) + logger.error(msg) + raise Exception(msg) + + neuron_config_to_env(neuron_config) + + +if __name__ == "__main__": + main() diff --git a/integration-tests/fixtures/neuron/export_models.py b/integration-tests/fixtures/neuron/export_models.py index 836402ec..d4d0f01c 100644 --- a/integration-tests/fixtures/neuron/export_models.py +++ b/integration-tests/fixtures/neuron/export_models.py @@ -28,15 +28,6 @@ logger = logging.getLogger(__file__) # All model configurations below will be added to the neuron_model_config fixture MODEL_CONFIGURATIONS = { - "gpt2": { - "model_id": "gpt2", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 1024, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, "llama": { "model_id": "unsloth/Llama-3.2-1B-Instruct", "export_kwargs": { @@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = { "auto_cast_type": "fp16", }, }, - "mistral": { - "model_id": "optimum/mistral-1.1b-testing", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 4096, - "num_cores": 2, - "auto_cast_type": "bf16", - }, - }, "qwen2": { "model_id": "Qwen/Qwen2.5-0.5B", "export_kwargs": { diff --git a/integration-tests/neuron/test_generate.py b/integration-tests/neuron/test_generate.py index f0804356..9108ce0e 100644 --- a/integration-tests/neuron/test_generate.py +++ b/integration-tests/neuron/test_generate.py @@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service): ) assert response.details.generated_tokens == 17 greedy_expectations = { - "gpt2": "\n\nDeep learning is a new field of research that has been around for a while", - "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial", - "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", + "llama": " and how does it work?\nDeep learning is a subset of machine learning 
that uses artificial", "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art", } @@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load): assert len(responses) == 4 expectations = { - "gpt2": "Deep learning is a new field of research that has been around for a while", "llama": "Deep learning is a subset of machine learning that uses artificial", - "mistral": "Deep Learning is a type of machine learning that", "qwen2": "Deep Learning is a subset of Machine Learning that is based on", "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art", }
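
For context on the on-device sampling path introduced in generator.py above: the generator now packs per-slot sampling settings into a (batch_size, 3) tensor of [top_k, top_p, temperature] and passes it to prepare_inputs_for_prefill / prepare_inputs_for_decode; greedy requests are emulated on device with top_k=1, top_p=1.0, temperature=1.0, and when on-device sampling is enabled the model returns token ids directly instead of logits, so _generate_token skips the host-side slot.select() call. The sketch below only illustrates that packing logic; build_sampling_params is a hypothetical helper (the patch does this inline in prefill() and decode()), and it assumes nothing beyond torch and slot objects exposing a generation_config.

import torch


def build_sampling_params(slots, on_device_sampling: bool):
    # Illustrative sketch of the inline logic the patch adds to prefill()/decode();
    # not the actual helper from the repository.
    if not on_device_sampling:
        # Without on-device sampling the model still returns logits and
        # token selection stays on the host (slot.select()).
        return None
    # One row per slot: [top_k, top_p, temperature]. Greedy requests were
    # already rewritten to top_k=1, top_p=1.0, temperature=1.0 when they
    # were assigned to a slot, so no special case is needed here.
    sampling_params = torch.zeros(len(slots), 3)
    for i, slot in enumerate(slots):
        sampling_params[i, 0] = slot.generation_config.top_k
        sampling_params[i, 1] = slot.generation_config.top_p
        sampling_params[i, 2] = slot.generation_config.temperature
    return sampling_params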