diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 7144542f..9229bcf2 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -486,8 +486,6 @@ def get_model( model_type = config_dict["model_type"] - if model_type == "gpt_bigcode": - return StarCoder(model_id=model_id, revision=revision, dtype=dtype) kv_cache_dtype = dtype if FLASH_ATTENTION: @@ -871,6 +869,8 @@ def get_model( trust_remote_code=trust_remote_code, ) adapt_transformers_to_gaudi() + if model_type == "gpt_bigcode": + return StarCoder(model_id=model_id, revision=revision, dtype=dtype) if model_type == "bloom": return BLOOM( model_id=model_id,