Mirror of https://github.com/huggingface/text-generation-inference.git
Fix HF_HUB_OFFLINE=1 for Gaudi backend (#3193)

* Fix `HF_HUB_OFFLINE=1` for Gaudi backend
* Fix HF cache default value in server.rs
* Format
parent 7253be349a
commit f208ba6afc
@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
 
 run-local-dev-container:
 	docker run -it \
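
The only Makefile change passes `--ulimit nofile=4096` to `docker build`, presumably to raise the open-file limit available while the image is built. An illustrative (POSIX-only) Python check of the limit a process actually sees, not part of the diff:

import resource

# Print the soft and hard open-file limits for the current process; a process
# started under --ulimit nofile=4096 (e.g. a RUN step of the build) should
# report a soft limit of at least 4096.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"RLIMIT_NOFILE: soft={soft}, hard={hard}")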
@@ -4,6 +4,7 @@ import bisect
 from dataclasses import dataclass
 from functools import wraps
 import itertools
+import json
 import math
 import os
 import tempfile
@@ -17,15 +18,12 @@ from loguru import logger
 from opentelemetry import trace
 
 import text_generation_server.habana_quantization_env as hq_env
+from text_generation_server.utils import weight_files
 import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from text_generation_server.utils.chunks import concat_text_chunks
-from optimum.habana.checkpoint_utils import (
-    get_repo_root,
-    model_on_meta,
-    write_checkpoints_json,
-)
+from optimum.habana.checkpoint_utils import model_on_meta
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
@@ -708,6 +706,9 @@ class CausalLM(Model):
         if hq_env.is_quantization_enabled:
             htorch.core.hpu_set_env()
 
+        # Get weight files
+        weight_files(model_id, revision=revision, extension=".safetensors")
+
         if world_size > 1:
             os.environ.setdefault(
                 "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
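
The new pre-fetch step calls `weight_files` right after quantization setup so the `.safetensors` shards are resolved to local paths up front; with `HF_HUB_OFFLINE=1` this should come straight from the Hugging Face cache rather than the Hub. A minimal sketch of that behaviour (the model id is a placeholder, and the assumption is that `weight_files` fails rather than downloads when the shards are not already cached):

import os

from text_generation_server.utils import weight_files

# Assumption: with the cache already populated, offline resolution succeeds;
# otherwise weight_files is expected to raise instead of reaching the Hub.
os.environ["HF_HUB_OFFLINE"] = "1"

model_id = "org/some-model"  # hypothetical model id
shards = weight_files(model_id, revision=None, extension=".safetensors")
print([str(p) for p in shards])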
@@ -715,8 +716,6 @@ class CausalLM(Model):
             model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
-            get_repo_root(model_id)
-
             # Check support for rope scaling
             model_kwargs = {}
             config = AutoConfig.from_pretrained(model_id)
@@ -868,7 +867,6 @@ class CausalLM(Model):
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
                 model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
         else:
-            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                 model = AutoModelForCausalLM.from_pretrained(
@@ -884,7 +882,16 @@ class CausalLM(Model):
         if load_to_meta:
             # model loaded to meta is managed differently
             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            checkpoint_files = [
+                str(f)
+                for f in weight_files(
+                    model_id, revision=revision, extension=".safetensors"
+                )
+            ]
+            data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+            json.dump(data, checkpoints_json)
+            checkpoints_json.flush()
+
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
 
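
Instead of `write_checkpoints_json` (which resolves the checkpoint files itself and so may try to reach the Hub), the DeepSpeed checkpoint manifest is now built directly from the locally resolved shard paths. A standalone sketch of the manifest shape the new code writes, using hypothetical shard paths:

import json
import tempfile

# Hypothetical local shard paths, as weight_files(...) would return them.
checkpoint_files = [
    "/data/hub/models--org--model/snapshots/abc123/model-00001-of-00002.safetensors",
    "/data/hub/models--org--model/snapshots/abc123/model-00002-of-00002.safetensors",
]

# Same manifest shape as in the diff; DeepSpeed later reads it via
# ds_inference_kwargs["checkpoint"] = checkpoints_json.name.
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
json.dump(data, checkpoints_json)
checkpoints_json.flush()
print(checkpoints_json.name)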
@@ -1,3 +1,4 @@
+import json
 import re
 import torch
 import os
@@ -12,6 +13,7 @@ import tempfile
 import copy
 from text_generation_server.models import Model
 from transformers import PreTrainedTokenizerBase
+from text_generation_server.utils import weight_files
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
@@ -43,11 +45,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
 )
-from optimum.habana.checkpoint_utils import (
-    get_repo_root,
-    model_on_meta,
-    write_checkpoints_json,
-)
+from optimum.habana.checkpoint_utils import model_on_meta
 
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -840,6 +838,9 @@ class VlmCausalLM(Model):
         if hq_env.is_quantization_enabled:
             htorch.core.hpu_set_env()
 
+        # Get weight files
+        weight_files(model_id, revision=revision, extension=".safetensors")
+
         if world_size > 1:
             os.environ.setdefault(
                 "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
@@ -847,8 +848,6 @@ class VlmCausalLM(Model):
             model = self.get_deepspeed_model(model_class, model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
-            get_repo_root(model_id)
-
             # Check support for rope scaling
             model_kwargs = {}
             config = AutoConfig.from_pretrained(model_id)
@@ -1000,7 +999,6 @@ class VlmCausalLM(Model):
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
                 model = model_class.from_config(config, torch_dtype=dtype)
         else:
-            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                 model = model_class.from_pretrained(
@@ -1019,7 +1017,15 @@ class VlmCausalLM(Model):
         if load_to_meta:
             # model loaded to meta is managed differently
             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            checkpoint_files = [
+                str(f)
+                for f in weight_files(
+                    model_id, revision=revision, extension=".safetensors"
+                )
+            ]
+            data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+            json.dump(data, checkpoints_json)
+            checkpoints_json.flush()
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
 
@@ -1578,7 +1578,7 @@ pub async fn run(
         let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
             .map_err(|_| ())
             .map(|cache_dir| Cache::new(cache_dir.into()))
-            .unwrap_or_else(|_| Cache::default());
+            .unwrap_or_else(|_| Cache::from_env());
         tracing::warn!("Offline mode active using cache defaults");
         Type::Cache(cache)
     } else {
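
On the router side, the offline branch now falls back to `Cache::from_env()` instead of `Cache::default()` when `HUGGINGFACE_HUB_CACHE` is unset, so the cache location is presumably derived from the environment (e.g. `HF_HOME`) rather than a fixed path. A rough Python sketch of that fallback order (the exact precedence is an assumption, not taken from the diff):

import os
from pathlib import Path


def resolve_hub_cache() -> Path:
    # Illustrative order only: explicit HUGGINGFACE_HUB_CACHE, then an
    # HF_HOME-derived location, then the conventional user cache directory.
    explicit = os.environ.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return Path(explicit)
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return Path(hf_home) / "hub"
    return Path.home() / ".cache" / "huggingface" / "hub"


print(resolve_hub_cache())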