mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Reload model_type when speculator is found.
This commit is contained in:
parent
6009dadee3
commit
dcb727c232
@ -158,6 +158,8 @@ def get_model(
|
|||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
# Reload model type from parent.
|
||||||
|
model_type = config_dict.get("model_type", None)
|
||||||
is_local = Path(medusa_model_id).exists()
|
is_local = Path(medusa_model_id).exists()
|
||||||
if not is_local:
|
if not is_local:
|
||||||
medusa_config = hf_hub_download(
|
medusa_config = hf_hub_download(
|
||||||
@ -198,6 +200,8 @@ def get_model(
|
|||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
# Reload model type from parent.
|
||||||
|
model_type = config_dict.get("model_type", None)
|
||||||
is_local = Path(mlp_model_id).exists()
|
is_local = Path(mlp_model_id).exists()
|
||||||
extension = ".safetensors"
|
extension = ".safetensors"
|
||||||
if not is_local:
|
if not is_local:
|
||||||
|
Loading…
Reference in New Issue
Block a user