From e2fa96a91c4f60ad9938660b26a623739445e2f9 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Fri, 23 May 2025 09:48:05 +0000
Subject: [PATCH] fix(neuron): neuron config is not stored in config anymore

---
 .../server/text_generation_server/generator.py | 15 ++++++++++++---
 .../server/text_generation_server/model.py     | 14 ++++++++++++--
 backends/neuron/tgi_env.py                     | 16 +++++++++++++---
 3 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py
index 77746512..0878c1bd 100644
--- a/backends/neuron/server/text_generation_server/generator.py
+++ b/backends/neuron/server/text_generation_server/generator.py
@@ -7,7 +7,8 @@ from typing import List, Optional, Tuple
 
 import torch
 from loguru import logger
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from optimum.neuron.configuration_utils import NeuronConfig
 from transformers.generation import GenerationConfig
 
 from optimum.neuron import NeuronModelForCausalLM
@@ -663,8 +664,16 @@ class NeuronGenerator(Generator):
         Returns:
             A NeuronGenerator.
         """
-        config = AutoConfig.from_pretrained(model_id)
-        neuron_config = getattr(config, "neuron", None)
+        try:
+            neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+        except Exception as e:
+            logger.debug(
+                "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+                model_id,
+                revision,
+                e,
+            )
+            neuron_config = None
         start = time.time()
         if neuron_config is None:
             export_kwargs = get_export_kwargs_from_env()
diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py
index fe6a00eb..25173ec2 100644
--- a/backends/neuron/server/text_generation_server/model.py
+++ b/backends/neuron/server/text_generation_server/model.py
@@ -10,6 +10,7 @@ from transformers import AutoConfig
 
 from optimum.neuron import NeuronModelForCausalLM
 from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
 
 
 def get_export_kwargs_from_env():
@@ -87,8 +88,16 @@ def fetch_model(
         revision = None
     # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
     # Note that the model may already be present in the cache.
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    neuron_config = getattr(config, "neuron", None)
+    try:
+        neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+    except Exception as e:
+        logger.debug(
+            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+            model_id,
+            revision,
+            e,
+        )
+        neuron_config = None
     if neuron_config is not None:
         if os.path.isdir(model_id):
             return model_id
@@ -100,6 +109,7 @@ def fetch_model(
         return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
     # Model needs to be exported: look for compatible cached entries on the hub
     export_kwargs = get_export_kwargs_from_env()
+    config = AutoConfig.from_pretrained(model_id, revision=revision)
     export_config = NeuronModelForCausalLM.get_export_config(
         model_id, config, revision=revision, **export_kwargs
     )
diff --git a/backends/neuron/tgi_env.py b/backends/neuron/tgi_env.py
index 12ede1e9..f7a89269 100755
--- a/backends/neuron/tgi_env.py
+++ b/backends/neuron/tgi_env.py
@@ -7,10 +7,10 @@ import sys
 from typing import Any, Dict, List, Optional
 
 from huggingface_hub import constants
-from transformers import AutoConfig
 
 from optimum.neuron.modeling_decoder import get_available_cores
 from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
 from optimum.neuron.utils.version_utils import get_neuronxcc_version
 
 
@@ -238,8 +238,18 @@ def main():
 
     logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
 
-    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
-    neuron_config = getattr(config, "neuron", None)
+    try:
+        neuron_config = NeuronConfig.from_pretrained(
+            args.model_id, revision=args.revision
+        )
+    except Exception as e:
+        logger.debug(
+            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+            args.model_id,
+            args.revision,
+            e,
+        )
+        neuron_config = None
     if neuron_config is not None:
         compatible = check_env_and_neuron_config_compatibility(
             neuron_config, check_compiler_version=False
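All three hunks apply the same change: the neuron export parameters are no longer read from a "neuron" entry in the model's config.json (via AutoConfig), but loaded as a standalone NeuronConfig, falling back to None (and hence to the export path) when loading fails. Below is a minimal sketch of that shared fallback, assuming only the optimum.neuron.configuration_utils.NeuronConfig import used in the patch and that from_pretrained raises when the checkpoint has no exported neuron config, as the try/except blocks imply; the helper name load_neuron_config_or_none is illustrative and not part of the patch:

    from typing import Optional

    from loguru import logger

    from optimum.neuron.configuration_utils import NeuronConfig


    def load_neuron_config_or_none(model_id: str, revision: Optional[str] = None):
        """Return the checkpoint's NeuronConfig, or None if it has not been exported yet."""
        try:
            return NeuronConfig.from_pretrained(model_id, revision=revision)
        except Exception as e:
            # No standalone neuron config found: the caller falls back to exporting the model.
            logger.debug("No NeuronConfig for {} (revision {}): {}", model_id, revision, e)
            return None

A None result is handled exactly as before the patch, e.g. by taking the export branch that calls get_export_kwargs_from_env().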