diff --git a/Dockerfile.neuron b/Dockerfile.neuron
index d22ca222..6228dbb7 100644
--- a/Dockerfile.neuron
+++ b/Dockerfile.neuron
@@ -5,7 +5,7 @@ RUN mkdir -p /tgi
 # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
 FROM alpine AS optimum-neuron
 RUN mkdir -p /optimum-neuron
-ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
+ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
 RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
 
 # Build cargo components (adapted from TGI original Dockerfile)
@@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
 # Install neuronx packages
 RUN apt-get update -y \
     && apt-get install -y --no-install-recommends \
-    aws-neuronx-dkms=2.19.64.0 \
-    aws-neuronx-collectives=2.23.135.0-3e70920f2 \
-    aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
-    aws-neuronx-tools=2.20.204.0 \
+    aws-neuronx-dkms=2.20.28.0 \
+    aws-neuronx-collectives=2.24.59.0-838c7fc8b \
+    aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
+    aws-neuronx-tools=2.22.61.0 \
     libxml2 \
     && rm -rf /var/lib/apt/lists/* \
     && apt-get clean
@@ -125,11 +125,10 @@ RUN pip3 install \
     --index-url https://download.pytorch.org/whl/cpu
 
 RUN pip3 install \
-    neuronx-cc==2.16.372.0 \
-    torch-neuronx==2.5.1.2.4.0 \
-    transformers-neuronx==0.13.322 \
-    neuronx-distributed==0.10.1 \
-    libneuronxla==2.1.681.0 \
+    neuronx-cc==2.17.194.0 \
+    torch-neuronx==2.5.1.2.6.0 \
+    neuronx-distributed==0.11.0 \
+    libneuronxla==2.2.1630.0 \
     --extra-index-url=https://pip.repos.neuron.amazonaws.com
 
 # Install HuggingFace packages
@@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
 # Final image
 FROM neuron
 
-COPY backends/neuron/tgi_env.py /tgi_env.py
+COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
 COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py
index b3887e14..10a4d7a2 100644
--- a/backends/neuron/server/text_generation_server/generator.py
+++ b/backends/neuron/server/text_generation_server/generator.py
@@ -7,7 +7,8 @@ from typing import List, Optional, Tuple
 
 import torch
 from loguru import logger
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from optimum.neuron.configuration_utils import NeuronConfig
 from transformers.generation import GenerationConfig
 
 from optimum.neuron import NeuronModelForCausalLM
@@ -175,6 +176,12 @@ class Slot:
                 self._generation_config.top_p = request.parameters.top_p
             if request.parameters.typical_p != 0:
                 self._generation_config.typical_p = request.parameters.typical_p
+        else:
+            # Set the sampling parameters to emulate greedy decoding when using on-device sampling
+            self._generation_config.temperature = 1.0
+            self._generation_config.top_k = 1
+            self._generation_config.top_p = 1.0
+            self._generation_config.typical_p = 1.0
         if request.parameters.repetition_penalty != 0:
             self._generation_config.repetition_penalty = (
                 request.parameters.repetition_penalty
@@ -211,19 +218,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector
 
-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.
 
         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE
 
     def resume(self):
@@ -340,16 +339,27 @@ class NeuronGenerator(Generator):
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
-        self.rebuild_cache_on_prefill = not self.model.continuous_batching
+        if not isinstance(self.model, NeuronModelForCausalLM):
+            raise ValueError("The model must be a NeuronModelForCausalLM.")
+        if not model.neuron_config.continuous_batching:
+            raise ValueError(
+                "The neuron model must be compiled with continuous_batching=True."
+            )
         # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+        self.slots = [
+            Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+        ]
         self.batch_id = 0
 
+    @property
+    def on_device_sampling(self) -> bool:
+        return getattr(self.model.neuron_config, "on_device_sampling", False)
+
     @property
     def info(self) -> InfoResponse:
         """Returns the expected InfoResponse."""
@@ -371,14 +381,22 @@ class NeuronGenerator(Generator):
             The maximum number of tokens the model supports.
         """
         # Just check that the warmup request parameters match the model capacity
-        batch_size = self.model.batch_size
+        batch_size = self.model.neuron_config.batch_size
         if len(batch.requests) > batch_size:
             raise ValueError(
-                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI.  The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process.  The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI.  The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process.  The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
             )
         self.prefill(batch)
         self.clear()
-        return self.model.batch_size * self.model.max_length
+        return (
+            self.model.neuron_config.batch_size
+            * self.model.neuron_config.sequence_length
+        )
+
+    def max_prefill_length(self) -> int:
+        if hasattr(self.model.neuron_config, "max_context_length"):
+            return self.model.neuron_config.max_context_length
+        return self.model.neuron_config.sequence_length
 
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
@@ -398,7 +416,7 @@ class NeuronGenerator(Generator):
         if len(empty_slots) < len(batch.requests):
             raise ValueError(
                 f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
+                f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
             )
         # Assign each request to an empty slot
         logger.debug(
@@ -412,14 +430,8 @@ class NeuronGenerator(Generator):
             logger.debug(
                 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
             )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
@@ -431,8 +443,10 @@ class NeuronGenerator(Generator):
             inputs.append(slot.cached_text)
             # Apply truncation, making sure we fit into static dimensions
             if slot.truncate == 0:
-                max_length = self.model.max_length
-            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+                max_length = self.max_prefill_length()
+            elif (
+                slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+            ):
                 max_length = slot.truncate
         # Tokenize with padding and truncation
         padded_inputs = self.tokenizer(
@@ -444,13 +458,12 @@ class NeuronGenerator(Generator):
         )
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
+        sampling_params = (
+            torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
+        )
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -464,29 +477,33 @@ class NeuronGenerator(Generator):
                     slot_input_ids,
                     slot.generation_config,
                     self.model,
-                    self.model.max_length,
+                    self.model.neuron_config.sequence_length,
                     tokenizer=self.tokenizer,
                     seed=slot.seed,
                 )
                 slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
                 slot_attention_mask = attention_mask[i]
                 slot.reset(slot_input_ids, slot_attention_mask, selector)
+                if sampling_params is not None:
+                    sampling_params[i, 0] = slot.generation_config.top_k
+                    sampling_params[i, 1] = slot.generation_config.top_p
+                    sampling_params[i, 2] = slot.generation_config.temperature
         # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
         # as they have already been generated and sent back in the last decode.
         model_inputs = self.model.prepare_inputs_for_prefill(
-            input_ids, attention_mask, seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
         )
-        logits = self.model(**model_inputs)[0]
+        tokens_or_logits = self.model(**model_inputs)[0]
         generation, next_batch = self._generate_token(
-            prefill_slots, self.batch_id, logits, input_ids
+            prefill_slots, self.batch_id, tokens_or_logits, input_ids
         )
         self.batch_id += 1
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
@@ -530,12 +547,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
@@ -545,22 +558,32 @@ class NeuronGenerator(Generator):
         for slot in decode_slots:
             max_length = max(max_length, slot.attention_mask.size(-1))
         attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+        sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
         for i, slot in enumerate(decode_slots):
             if slot.state != Slot.State.EMPTY:
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
                 attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+                if sampling_params is not None:
+                    sampling_params[i, 0] = slot.generation_config.top_k
+                    sampling_params[i, 1] = slot.generation_config.top_p
+                    sampling_params[i, 2] = slot.generation_config.temperature
         model_inputs = self.model.prepare_inputs_for_decode(
-            input_ids, attention_mask, seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
+        )
+        tokens_or_logits = self.model(**model_inputs)[0]
+        return self._generate_token(
+            decode_slots, next_batch_id, tokens_or_logits, input_ids
         )
-        logits = self.model(**model_inputs)[0]
-        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
 
     def _generate_token(
         self,
         slots: List[Slot],
         next_batch_id: int,
-        logits: torch.Tensor,
+        tokens_or_logits: torch.Tensor,
         input_ids: torch.LongTensor,
     ) -> Tuple[List[Generation], CachedBatch]:
         generations = []
@@ -569,9 +592,12 @@ class NeuronGenerator(Generator):
             if slot.state != Slot.State.READY:
                 continue
             request_id = slot.request_id
-            next_token_logits = logits[i : i + 1, -1, :]
             slot_input_ids = input_ids[i : i + 1, :]
-            next_token = slot.select(slot_input_ids, next_token_logits)
+            if self.on_device_sampling:
+                next_token = tokens_or_logits[i]
+            else:
+                next_token_logits = tokens_or_logits[i : i + 1, -1, :]
+                next_token = slot.select(slot_input_ids, next_token_logits)
             next_token_text = slot.append(next_token)
             generated_text = None
             finish_reason = None
@@ -622,7 +648,7 @@ class NeuronGenerator(Generator):
 
     def _cached_batch(self, batch_id: int, request_ids: List):
         size = len(request_ids)
-        max_tokens = size * self.model.max_length
+        max_tokens = size * self.model.neuron_config.sequence_length
         return CachedBatch(
             id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
         )
@@ -671,8 +697,16 @@ class NeuronGenerator(Generator):
         Returns:
             A NeuronGenerator.
         """
-        config = AutoConfig.from_pretrained(model_id)
-        neuron_config = getattr(config, "neuron", None)
+        try:
+            neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+        except Exception as e:
+            logger.debug(
+                "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+                model_id,
+                revision,
+                e,
+            )
+            neuron_config = None
         start = time.time()
         if neuron_config is None:
             export_kwargs = get_export_kwargs_from_env()
diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py
index 2151a921..d281b175 100644
--- a/backends/neuron/server/text_generation_server/model.py
+++ b/backends/neuron/server/text_generation_server/model.py
@@ -6,10 +6,12 @@ from typing import Optional
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HF_HUB_CACHE
 from loguru import logger
-from transformers import AutoConfig
 
-from optimum.neuron import NeuronModelForCausalLM
-from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
+
+
+from .tgi_env import check_env_and_neuron_config_compatibility
 
 
 def get_export_kwargs_from_env():
@@ -24,7 +26,6 @@ def get_export_kwargs_from_env():
         num_cores = int(num_cores)
     auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
     return {
-        "task": "text-generation",
         "batch_size": batch_size,
         "sequence_length": sequence_length,
         "num_cores": num_cores,
@@ -32,20 +33,15 @@ def get_export_kwargs_from_env():
     }
 
 
-def is_cached(model_id, neuron_config):
+def is_cached(model_id):
     # Look for cached entries for the specified model
     in_cache = False
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)
     # Look for compatible entries
     for entry in entries:
-        compatible = True
-        for key, value in neuron_config.items():
-            # Only weights can be different
-            if key in ["checkpoint_id", "checkpoint_revision"]:
-                continue
-            if entry[key] != value:
-                compatible = False
-        if compatible:
+        if check_env_and_neuron_config_compatibility(
+            entry, check_compiler_version=True
+        ):
             in_cache = True
             break
     return in_cache
@@ -87,8 +83,16 @@ def fetch_model(
         revision = None
     # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
     # Note that the model may already be present in the cache.
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    neuron_config = getattr(config, "neuron", None)
+    try:
+        neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+    except Exception as e:
+        logger.debug(
+            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+            model_id,
+            revision,
+            e,
+        )
+        neuron_config = None
     if neuron_config is not None:
         if os.path.isdir(model_id):
             return model_id
@@ -99,16 +103,11 @@ def fetch_model(
         log_cache_size()
         return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
     # Model needs to be exported: look for compatible cached entries on the hub
-    export_kwargs = get_export_kwargs_from_env()
-    export_config = NeuronModelForCausalLM.get_export_config(
-        model_id, config, revision=revision, **export_kwargs
-    )
-    neuron_config = export_config.neuron
-    if not is_cached(model_id, neuron_config):
+    if not is_cached(model_id):
         hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
         neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
         error_msg = (
-            f"No cached version found for {model_id} with {neuron_config}."
+            f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
             f"You can start a discussion to request it on {hub_cache_url}"
             f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
         )
@@ -121,8 +120,10 @@ def fetch_model(
     # Prefetch weights, tokenizer and generation config so that they are in cache
     log_cache_size()
     start = time.time()
-    snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+    snapshot_path = snapshot_download(
+        model_id, revision=revision, ignore_patterns="*.bin"
+    )
     end = time.time()
     logger.info(f"Model weights fetched in {end - start:.2f} s.")
     log_cache_size()
-    return model_id
+    return snapshot_path
diff --git a/backends/neuron/tgi_env.py b/backends/neuron/server/text_generation_server/tgi_env.py
old mode 100755
new mode 100644
similarity index 63%
rename from backends/neuron/tgi_env.py
rename to backends/neuron/server/text_generation_server/tgi_env.py
index a7042130..ee97f180
--- a/backends/neuron/tgi_env.py
+++ b/backends/neuron/server/text_generation_server/tgi_env.py
@@ -6,12 +6,11 @@ import os
 import sys
 from typing import Any, Dict, List, Optional
 
-from huggingface_hub import constants
-from transformers import AutoConfig
-
 from optimum.neuron.modeling_decoder import get_available_cores
-from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
 from optimum.neuron.utils.version_utils import get_neuronxcc_version
+from optimum.neuron.utils import map_torch_dtype
 
 
 logger = logging.getLogger(__name__)
@@ -24,15 +23,9 @@ tgi_router_env_vars = [
 ]
 tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
 
-env_config_peering = [
-    ("MAX_BATCH_SIZE", "batch_size"),
-    ("MAX_TOTAL_TOKENS", "sequence_length"),
-    ("HF_AUTO_CAST_TYPE", "auto_cast_type"),
-    ("HF_NUM_CORES", "num_cores"),
-]
 
 # By the end of this script all env var should be specified properly
-env_vars = tgi_server_env_vars + tgi_router_env_vars
+tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
 
 available_cores = get_available_cores()
 neuronxcc_version = get_neuronxcc_version()
@@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
 
 
 def neuron_config_to_env(neuron_config):
+    if isinstance(neuron_config, NeuronConfig):
+        neuron_config = neuron_config.to_dict()
     with open(os.environ["ENV_FILEPATH"], "w") as f:
-        for env_var, config_key in env_config_peering:
-            f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
+        f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
+        f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
+        f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
+        config_key = (
+            "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
+        )
+        auto_cast_type = neuron_config[config_key]
+        f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
         max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
         if not max_input_tokens:
             max_input_tokens = int(neuron_config["sequence_length"]) // 2
@@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):
 
 
 def sort_neuron_configs(dictionary):
-    return -dictionary["num_cores"], -dictionary["batch_size"]
+    return -dictionary["tp_degree"], -dictionary["batch_size"]
 
 
 def lookup_compatible_cached_model(
@@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
 ) -> Optional[Dict[str, Any]]:
     # Reuse the same mechanic as the one in use to configure the tgi server part
     # The only difference here is that we stay as flexible as possible on the compatibility part
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)
 
     logger.debug(
         "Found %d cached entries for model %s, revision %s",
@@ -155,15 +156,15 @@ def lookup_compatible_cached_model(
 
 
 def check_env_and_neuron_config_compatibility(
-    neuron_config: Dict[str, Any], check_compiler_version: bool
+    neuron_config_dict: Dict[str, Any], check_compiler_version: bool
 ) -> bool:
     logger.debug(
         "Checking the provided neuron config %s is compatible with the local setup and provided environment",
-        neuron_config,
+        neuron_config_dict,
     )
 
     # Local setup compat checks
-    if neuron_config["num_cores"] > available_cores:
+    if neuron_config_dict["tp_degree"] > available_cores:
         logger.debug(
             "Not enough neuron cores available to run the provided neuron config"
         )
@@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(
 
     if (
         check_compiler_version
-        and neuron_config["compiler_version"] != neuronxcc_version
+        and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
     ):
         logger.debug(
             "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
             neuronxcc_version,
-            neuron_config["compiler_version"],
+            neuron_config_dict["neuronxcc_version"],
         )
         return False
 
-    for env_var, config_key in env_config_peering:
-        neuron_config_value = str(neuron_config[config_key])
-        env_value = os.getenv(env_var, str(neuron_config_value))
+    batch_size = os.getenv("MAX_BATCH_SIZE", None)
+    if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
+        logger.debug(
+            "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
+            os.getenv("MAX_BATCH_SIZE"),
+            neuron_config_dict["batch_size"],
+        )
+        return False
+    max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
+    if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
+        max_total_tokens
+    ):
+        logger.debug(
+            "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
+            max_total_tokens,
+            neuron_config_dict["sequence_length"],
+        )
+        return False
+    num_cores = os.getenv("HF_NUM_CORES", None)
+    if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
+        logger.debug(
+            "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
+            num_cores,
+            neuron_config_dict["tp_degree"],
+        )
+        return False
+    auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
+    if auto_cast_type is not None:
+        config_key = (
+            "auto_cast_type"
+            if "auto_cast_type" in neuron_config_dict
+            else "torch_dtype"
+        )
+        neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
+        env_value = map_torch_dtype(auto_cast_type)
         if env_value != neuron_config_value:
             logger.debug(
-                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
-                env_var,
-                config_key,
+                "The provided auto cast type and the neuron config param differ (%s != %s)",
                 env_value,
                 neuron_config_value,
             )
             return False
-
     max_input_tokens = int(
         os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
     )
     if max_input_tokens > 0:
-        sequence_length = neuron_config["sequence_length"]
+        if hasattr(neuron_config_dict, "max_context_length"):
+            sequence_length = neuron_config_dict["max_context_length"]
+        else:
+            sequence_length = neuron_config_dict["sequence_length"]
         if max_input_tokens >= sequence_length:
             logger.debug(
                 "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
@@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility(
 
 def get_env_dict() -> Dict[str, str]:
     d = {}
-    for k in env_vars:
+    for k in tgi_env_vars:
         d[k] = os.getenv(k)
     return d
 
 
-def main():
-    """
-    This script determines proper default TGI env variables for the neuron precompiled models to
-    work properly
-    :return:
-    """
-    args = parse_cmdline_and_set_env()
-
-    for env_var in env_vars:
-        if not os.getenv(env_var):
-            break
-    else:
-        logger.info(
-            "All env vars %s already set, skipping, user know what they are doing",
-            env_vars,
+def get_neuron_config_for_model(
+    model_name_or_path: str, revision: Optional[str] = None
+) -> NeuronConfig:
+    try:
+        neuron_config = NeuronConfig.from_pretrained(
+            model_name_or_path, revision=revision
         )
-        sys.exit(0)
-
-    cache_dir = constants.HF_HUB_CACHE
-
-    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
-
-    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
-    neuron_config = getattr(config, "neuron", None)
+    except Exception as e:
+        logger.debug(
+            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+            model_name_or_path,
+            revision,
+            e,
+        )
+        neuron_config = None
     if neuron_config is not None:
         compatible = check_env_and_neuron_config_compatibility(
-            neuron_config, check_compiler_version=False
+            neuron_config.to_dict(), check_compiler_version=False
         )
         if not compatible:
             env_dict = get_env_dict()
@@ -252,17 +276,6 @@ def main():
             logger.error(msg)
             raise Exception(msg)
     else:
-        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
+        neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)
 
-    if not neuron_config:
-        msg = (
-            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
-        ).format(get_env_dict(), available_cores, neuronxcc_version)
-        logger.error(msg)
-        raise Exception(msg)
-
-    neuron_config_to_env(neuron_config)
-
-
-if __name__ == "__main__":
-    main()
+    return neuron_config
diff --git a/backends/neuron/tests/fixtures/model.py b/backends/neuron/tests/fixtures/model.py
index 4b6a1375..ad41fd10 100644
--- a/backends/neuron/tests/fixtures/model.py
+++ b/backends/neuron/tests/fixtures/model.py
@@ -4,14 +4,12 @@ import subprocess
 import sys
 from tempfile import TemporaryDirectory
 
-import huggingface_hub
+import os
 import pytest
 from transformers import AutoTokenizer
 
-from optimum.neuron import NeuronModelForCausalLM
-from optimum.neuron.utils import synchronize_hub_cache
-from optimum.neuron.version import __sdk_version__ as sdk_version
-from optimum.neuron.version import __version__ as version
+
+from optimum.neuron.cache import synchronize_hub_cache
 
 
 logging.basicConfig(
@@ -21,30 +19,14 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__file__)
 
+
 OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
 
+
 # All model configurations below will be added to the neuron_model_config fixture
 MODEL_CONFIGURATIONS = {
-    "gpt2": {
-        "model_id": "gpt2",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 1024,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
     "llama": {
-        "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 2048,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
-    "mistral": {
-        "model_id": "optimum/mistral-1.1b-testing",
+        "model_id": "unsloth/Llama-3.2-1B-Instruct",
         "export_kwargs": {
             "batch_size": 4,
             "sequence_length": 4096,
@@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = {
             "batch_size": 4,
             "sequence_length": 4096,
             "num_cores": 2,
-            "auto_cast_type": "fp16",
+            "auto_cast_type": "bf16",
         },
     },
     "granite": {
@@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = {
 }
 
 
-def get_hub_neuron_model_id(config_name: str):
-    return (
-        f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
-    )
-
-
 def export_model(model_id, export_kwargs, neuron_model_path):
     export_command = [
         "optimum-cli",
@@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path):
 def neuron_model_config(request):
     """Expose a pre-trained neuron model
 
-    The fixture first makes sure the following model artifacts are present on the hub:
-    - exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>,
-    - cached artifacts under optimum-internal-testing/neuron-testing-cache.
-    If not, it will export the model and push it to the hub.
-
-    It then fetches the model locally and return a dictionary containing:
+    The fixture exports a model locally and returns a dictionary containing:
     - a configuration name,
     - the original model id,
     - the export parameters,
-    - the neuron model id,
     - the neuron model local path.
 
     For each exposed model, the local directory is maintained for the duration of the
     test session and cleaned up afterwards.
-    The hub model artifacts are never cleaned up and persist accross sessions.
-    They must be cleaned up manually when the optimum-neuron version changes.
 
     """
     config_name = request.param
     model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
     model_id = model_config["model_id"]
     export_kwargs = model_config["export_kwargs"]
-    neuron_model_id = get_hub_neuron_model_id(config_name)
     with TemporaryDirectory() as neuron_model_path:
-        hub = huggingface_hub.HfApi()
-        if hub.repo_exists(neuron_model_id):
-            logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
-            hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
-        else:
-            export_model(model_id, export_kwargs, neuron_model_path)
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-            tokenizer.save_pretrained(neuron_model_path)
-            del tokenizer
-            # Create the test model on the hub
-            hub.create_repo(neuron_model_id, private=True)
-            hub.upload_folder(
-                folder_path=neuron_model_path,
-                repo_id=neuron_model_id,
-                ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
-            )
-            # Make sure it is cached
-            synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+        export_model(model_id, export_kwargs, neuron_model_path)
+        synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.save_pretrained(neuron_model_path)
+        del tokenizer
         # Add dynamic parameters to the model configuration
         model_config["neuron_model_path"] = neuron_model_path
-        model_config["neuron_model_id"] = neuron_model_id
         # Also add model configuration name to allow tests to adapt their expectations
         model_config["name"] = config_name
         # Yield instead of returning to keep a reference to the temporary directory.
         # It will go out of scope and be released only once all tests needing the fixture
         # have been completed.
         logger.info(f"{config_name} ready for testing ...")
+        os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
         yield model_config
         logger.info(f"Done with {config_name}")
 
diff --git a/backends/neuron/tests/server/test_cached_model.py b/backends/neuron/tests/server/test_cached_model.py
new file mode 100644
index 00000000..73622578
--- /dev/null
+++ b/backends/neuron/tests/server/test_cached_model.py
@@ -0,0 +1,42 @@
+import os
+import pytest
+
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.model import fetch_model, is_cached
+
+
+@pytest.fixture(scope="module")
+def cached_model_id(neuron_model_config) -> str:
+    """
+    Fixture to provide a cached model ID for testing.
+    This assumes the model is already cached in the local environment.
+    """
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    yield neuron_model_config["model_id"]
+    os.environ.pop("MAX_BATCH_SIZE", None)
+    os.environ.pop("MAX_TOTAL_TOKENS", None)
+    os.environ.pop("HF_AUTO_CAST_TYPE", None)
+    os.environ.pop("HF_NUM_CORES", None)
+
+
+def test_model_is_cached(cached_model_id):
+    assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
+
+
+def test_fetch_cached_model(cached_model_id: str):
+    model_path = fetch_model(cached_model_id)
+    assert os.path.exists(
+        model_path
+    ), f"Model {cached_model_id} was not fetched successfully"
+    assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
+
+
+def test_generator_from_cached_model(cached_model_id: str):
+    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
+    assert generator is not None, "Generator could not be created from cached model"
+    assert generator.model is not None, "Generator model is not initialized"
+    assert generator.tokenizer is not None, "Generator tokenizer is not initialized"
diff --git a/backends/neuron/tests/server/test_continuous_batching.py b/backends/neuron/tests/server/test_continuous_batching.py
index 48bb70cc..3d9ab509 100644
--- a/backends/neuron/tests/server/test_continuous_batching.py
+++ b/backends/neuron/tests/server/test_continuous_batching.py
@@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
     """
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    assert generator.model.batch_size > 1
+    assert generator.model.neuron_config.batch_size > 1
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
     tokens = {0: [], 1: []}
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     assert next_batch.size == 1
diff --git a/backends/neuron/tests/server/test_decode.py b/backends/neuron/tests/server/test_decode.py
index 9db5e3ab..b864e3ec 100644
--- a/backends/neuron/tests/server/test_decode.py
+++ b/backends/neuron/tests/server/test_decode.py
@@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
     request = create_request(
         id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
     )
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times
@@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample):
     assert output.finish_reason == 0
     if do_sample:
         expected_text = {
-            "gpt2": " The sun was set",
-            "llama": "George Orwell, 1984",
-            "mistral": "The sky was",
-            "qwen2": " A young woman with",
+            "llama": " I sat alone in the café",
+            "qwen2": " The air was so still",
             "granite": "1984, George Orwell",
         }[config_name]
         assert expected_text in output.text
     else:
         print(output.text)
         expected_text = {
-            "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
-            "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
-            "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
+            "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
             "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
             "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
         }[config_name]
diff --git a/backends/neuron/tests/server/test_prefill.py b/backends/neuron/tests/server/test_prefill.py
index c0155b1a..c9ecd1c8 100644
--- a/backends/neuron/tests/server/test_prefill.py
+++ b/backends/neuron/tests/server/test_prefill.py
@@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
     max_batch_size = 4
-    assert generator.model.batch_size >= max_batch_size
+    assert generator.model.neuron_config.batch_size >= max_batch_size
     for num_requests in [1, max_batch_size]:
         for do_sample in [True, False]:
             mode = "sample" if do_sample else "greedy"
@@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
             )
         )
     # Let's be pessimistic when estimating max_tokens
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
@@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
     assert len(generations) == batch_size
     if do_sample:
         expectations = {
-            "gpt2": [383, " The"],
-            "llama": [10058, " George"],
-            "mistral": [450, " The"],
-            "qwen2": [362, " A"],
+            "llama": [358, " I"],
+            "qwen2": [576, " The"],
             "granite": [308, " ("],
         }[config_name]
     else:
         expectations = {
-            "gpt2": [198, "\n"],
-            "llama": [10058, " George"],
-            "mistral": [13, "\n"],
+            "llama": [578, " The"],
             "qwen2": [358, " I"],
             "granite": [203, "\n"],
         }[config_name]
@@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config):
     config_name = neuron_model_config["name"]
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    batch_size = generator.model.batch_size
+    batch_size = generator.model.neuron_config.batch_size
     # We apply truncation to all requests but the first one
     truncate = [
         None,
@@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config):
     requests = []
     for i in range(batch_size):
         requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
@@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config):
     # Even if the input text is identical for all requests, the first generated token might
     # be different because of the truncation
     expectations = {
-        "gpt2": [" He", " He", "\n", " He"],
-        "llama": [" —", " The", " He", " He"],
-        "mistral": [" He", "\n", " He", " He"],
+        "llama": [" He", "iens", "\x08", " He"],
         "qwen2": [" He", " The", " He", " He"],
         "granite": ["\n", "\n", " I", " He"],
     }[config_name]
     for i, g in enumerate(generations):
         tokens = g.tokens
-        assert tokens.texts[0] == expectations[i]
+        assert (
+            tokens.texts[0] == expectations[i]
+        ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"
diff --git a/backends/neuron/tests/test_entry_point.py b/backends/neuron/tests/test_entry_point.py
new file mode 100644
index 00000000..d4ddf338
--- /dev/null
+++ b/backends/neuron/tests/test_entry_point.py
@@ -0,0 +1,63 @@
+import os
+import pytest
+from tempfile import TemporaryDirectory
+
+from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
+from optimum.neuron.utils import map_torch_dtype
+
+from text_generation_server.tgi_env import (
+    get_neuron_config_for_model,
+    lookup_compatible_cached_model,
+    neuron_config_to_env,
+)
+
+
+def test_get_neuron_config_for_model(neuron_model_config):
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    assert neuron_config is not None
+    assert neuron_config.batch_size == export_kwargs["batch_size"]
+    assert neuron_config.sequence_length == export_kwargs["sequence_length"]
+    assert neuron_config.tp_degree == export_kwargs["num_cores"]
+    if isinstance(neuron_config, NxDNeuronConfig):
+        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+    else:
+        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+
+
+@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
+def test_lookup_compatible_cached_model(model_id: str):
+    neuron_config = lookup_compatible_cached_model(model_id, None)
+    assert neuron_config is not None
+
+
+def test_neuron_config_to_env(neuron_model_config) -> None:
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    with TemporaryDirectory() as temp_dir:
+        os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
+        neuron_config_to_env(neuron_config)
+        with open(os.environ["ENV_FILEPATH"], "r") as env_file:
+            env_content = env_file.read()
+            assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
+            assert (
+                f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
+                in env_content
+            )
+            assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
+            if hasattr(neuron_config, "torch_dtype"):
+                auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
+                    "."
+                )[-1]
+            else:
+                auto_cast_type = neuron_config.auto_cast_type
+            assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh
index b959a795..7965d1da 100755
--- a/backends/neuron/tgi-entrypoint.sh
+++ b/backends/neuron/tgi-entrypoint.sh
@@ -9,7 +9,7 @@ touch $ENV_FILEPATH
 
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 
-${SCRIPT_DIR}/tgi_env.py $@
+${SCRIPT_DIR}/tgi_entry_point.py $@
 
 source $ENV_FILEPATH
 
diff --git a/backends/neuron/tgi_entry_point.py b/backends/neuron/tgi_entry_point.py
new file mode 100755
index 00000000..7e81d0e4
--- /dev/null
+++ b/backends/neuron/tgi_entry_point.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import sys
+
+
+from text_generation_server.tgi_env import (
+    available_cores,
+    get_env_dict,
+    get_neuron_config_for_model,
+    neuron_config_to_env,
+    neuronxcc_version,
+    parse_cmdline_and_set_env,
+    tgi_env_vars,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+    """
+    This script determines proper default TGI env variables for the neuron precompiled models to
+    work properly
+    :return:
+    """
+    args = parse_cmdline_and_set_env()
+
+    for env_var in tgi_env_vars:
+        if not os.getenv(env_var):
+            break
+    else:
+        logger.info(
+            "All env vars %s already set, skipping, user know what they are doing",
+            tgi_env_vars,
+        )
+        sys.exit(0)
+
+    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
+
+    if not neuron_config:
+        msg = (
+            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
+        ).format(get_env_dict(), available_cores, neuronxcc_version)
+        logger.error(msg)
+        raise Exception(msg)
+
+    neuron_config_to_env(neuron_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/integration-tests/fixtures/neuron/export_models.py b/integration-tests/fixtures/neuron/export_models.py
index 836402ec..d4d0f01c 100644
--- a/integration-tests/fixtures/neuron/export_models.py
+++ b/integration-tests/fixtures/neuron/export_models.py
@@ -28,15 +28,6 @@ logger = logging.getLogger(__file__)
 
 # All model configurations below will be added to the neuron_model_config fixture
 MODEL_CONFIGURATIONS = {
-    "gpt2": {
-        "model_id": "gpt2",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 1024,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
     "llama": {
         "model_id": "unsloth/Llama-3.2-1B-Instruct",
         "export_kwargs": {
@@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = {
             "auto_cast_type": "fp16",
         },
     },
-    "mistral": {
-        "model_id": "optimum/mistral-1.1b-testing",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 4096,
-            "num_cores": 2,
-            "auto_cast_type": "bf16",
-        },
-    },
     "qwen2": {
         "model_id": "Qwen/Qwen2.5-0.5B",
         "export_kwargs": {
diff --git a/integration-tests/neuron/test_generate.py b/integration-tests/neuron/test_generate.py
index f0804356..9108ce0e 100644
--- a/integration-tests/neuron/test_generate.py
+++ b/integration-tests/neuron/test_generate.py
@@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service):
     )
     assert response.details.generated_tokens == 17
     greedy_expectations = {
-        "gpt2": "\n\nDeep learning is a new field of research that has been around for a while",
-        "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial",
-        "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
+        "llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial",
         "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
         "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
     }
@@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load):
 
     assert len(responses) == 4
     expectations = {
-        "gpt2": "Deep learning is a new field of research that has been around for a while",
         "llama": "Deep learning is a subset of machine learning that uses artificial",
-        "mistral": "Deep Learning is a type of machine learning that",
         "qwen2": "Deep Learning is a subset of Machine Learning that is based on",
         "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art",
     }