Fix HF_HUB_OFFLINE=1 for Gaudi backend (#3193)

* Fix `HF_HUB_OFFLINE=1` for Gaudi backend

* Fix HF cache default value in server.rs

* Format
regisss · 2025-05-06 02:47:53 -06:00 · committed by GitHub
parent 7253be349a · commit f208ba6afc
4 changed files with 33 additions and 20 deletions
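
Editor's note (context, not part of the commit): with HF_HUB_OFFLINE=1 the Hugging Face client must serve every file from the local cache and never touch the network, so any code path that issues a Hub request fails even when the model is fully cached. The commit therefore replaces the optimum-habana helpers that resolve the repository through the Hub with TGI's own weight_files(), which lists cached *.safetensors locally. A minimal sketch of the offline behavior (the model id is only an example):

# Sketch: HF_HUB_OFFLINE must be set before huggingface_hub is imported,
# because the flag is read at import time.
import os
os.environ["HF_HUB_OFFLINE"] = "1"

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import LocalEntryNotFoundError

try:
    # Never hits the network: returns the cached path or raises.
    path = hf_hub_download("gpt2", filename="config.json")
    print("served from cache:", path)
except LocalEntryNotFoundError:
    print("file not cached and offline mode forbids downloading")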

File: Makefile (Gaudi backend)

@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)

 run-local-dev-container:
 	docker run -it \
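
Aside: the --ulimit nofile=4096 flag raises the build container's open-file limit. The commit does not state the motivation; a plausible reason (my assumption) is that dependency installation opens more files than the default soft limit allows. A quick way to inspect the limit from Python on Unix:

# Print the soft/hard file-descriptor limits of the current process.
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"open-file limit: soft={soft}, hard={hard}")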

File: text_generation_server/models/causal_lm.py (Gaudi backend)

@@ -4,6 +4,7 @@ import bisect
 from dataclasses import dataclass
 from functools import wraps
 import itertools
+import json
 import math
 import os
 import tempfile
@@ -17,15 +18,12 @@ from loguru import logger
 from opentelemetry import trace
 import text_generation_server.habana_quantization_env as hq_env
+from text_generation_server.utils import weight_files
 import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from text_generation_server.utils.chunks import concat_text_chunks
-from optimum.habana.checkpoint_utils import (
-    get_repo_root,
-    model_on_meta,
-    write_checkpoints_json,
-)
+from optimum.habana.checkpoint_utils import model_on_meta
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
@@ -708,6 +706,9 @@ class CausalLM(Model):
         if hq_env.is_quantization_enabled:
             htorch.core.hpu_set_env()

+        # Get weight files
+        weight_files(model_id, revision=revision, extension=".safetensors")
+
         if world_size > 1:
             os.environ.setdefault(
                 "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
@@ -715,8 +716,6 @@ class CausalLM(Model):
             model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
-            get_repo_root(model_id)
-
             # Check support for rope scaling
             model_kwargs = {}
             config = AutoConfig.from_pretrained(model_id)
@@ -868,7 +867,6 @@ class CausalLM(Model):
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
                 model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
         else:
-            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                 model = AutoModelForCausalLM.from_pretrained(
@@ -884,7 +882,16 @@ class CausalLM(Model):
         if load_to_meta:
             # model loaded to meta is managed differently
             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            checkpoint_files = [
+                str(f)
+                for f in weight_files(
+                    model_id, revision=revision, extension=".safetensors"
+                )
+            ]
+            data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+            json.dump(data, checkpoints_json)
+            checkpoints_json.flush()

             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
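
This hunk is the heart of the fix: write_checkpoints_json(model_id, ...) resolved the repository through the Hub, which breaks under HF_HUB_OFFLINE=1, whereas weight_files() returns the locally cached *.safetensors paths. The DeepSpeed checkpoints descriptor is now written by hand; a standalone sketch of what it contains (paths are illustrative, in practice they come from weight_files()):

# Build the checkpoints JSON that deepspeed.init_inference consumes via
# its "checkpoint" kwarg, exactly as the new backend code does.
import json
import tempfile

checkpoint_files = [
    "/data/hub/models--org--model/snapshots/abc/model-00001-of-00002.safetensors",
    "/data/hub/models--org--model/snapshots/abc/model-00002-of-00002.safetensors",
]
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}

checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
json.dump(data, checkpoints_json)
checkpoints_json.flush()
print(checkpoints_json.name)  # passed as ds_inference_kwargs["checkpoint"]

The same replacement is mirrored in vlm_causal_lm.py below.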

File: text_generation_server/models/vlm_causal_lm.py (Gaudi backend)

@@ -1,3 +1,4 @@
+import json
 import re
 import torch
 import os
@@ -12,6 +13,7 @@ import tempfile
 import copy
 from text_generation_server.models import Model
 from transformers import PreTrainedTokenizerBase
+from text_generation_server.utils import weight_files
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
@@ -43,11 +45,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
 )
-from optimum.habana.checkpoint_utils import (
-    get_repo_root,
-    model_on_meta,
-    write_checkpoints_json,
-)
+from optimum.habana.checkpoint_utils import model_on_meta

 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -840,6 +838,9 @@ class VlmCausalLM(Model):
         if hq_env.is_quantization_enabled:
             htorch.core.hpu_set_env()

+        # Get weight files
+        weight_files(model_id, revision=revision, extension=".safetensors")
+
         if world_size > 1:
             os.environ.setdefault(
                 "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
@@ -847,8 +848,6 @@ class VlmCausalLM(Model):
             model = self.get_deepspeed_model(model_class, model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
-            get_repo_root(model_id)
-
             # Check support for rope scaling
             model_kwargs = {}
             config = AutoConfig.from_pretrained(model_id)
@@ -1000,7 +999,6 @@ class VlmCausalLM(Model):
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
                 model = model_class.from_config(config, torch_dtype=dtype)
         else:
-            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                 model = model_class.from_pretrained(
@@ -1019,7 +1017,15 @@ class VlmCausalLM(Model):
         if load_to_meta:
             # model loaded to meta is managed differently
             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            checkpoint_files = [
+                str(f)
+                for f in weight_files(
+                    model_id, revision=revision, extension=".safetensors"
+                )
+            ]
+            data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+            json.dump(data, checkpoints_json)
+            checkpoints_json.flush()

             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)

File: server.rs (router)

@@ -1578,7 +1578,7 @@ pub async fn run(
         let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
             .map_err(|_| ())
             .map(|cache_dir| Cache::new(cache_dir.into()))
-            .unwrap_or_else(|_| Cache::default());
+            .unwrap_or_else(|_| Cache::from_env());
         tracing::warn!("Offline mode active using cache defaults");
         Type::Cache(cache)
     } else {
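
The router-side change swaps Cache::default() for Cache::from_env() so that the fallback cache location honors the environment (e.g. HF_HOME) instead of a hard-coded default. A Python analogue of the intended lookup (a sketch, not the router's code; hf-hub's Cache::from_env is assumed to behave like the Python client here):

# Resolve the HF hub cache directory from the environment, mirroring the
# precedence the router relies on: explicit HUGGINGFACE_HUB_CACHE, then
# HF_HOME/hub, then the built-in default.
import os

def hf_cache_dir() -> str:
    if cache := os.environ.get("HUGGINGFACE_HUB_CACHE"):
        return cache
    if hf_home := os.environ.get("HF_HOME"):
        return os.path.join(hf_home, "hub")
    return os.path.expanduser("~/.cache/huggingface/hub")

print(hf_cache_dir())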