fix: fix style

Author: baptiste
Date:   2025-02-25 13:16:11 +00:00
Parent: c08005a4cd
Commit: 31535bcde2

11 changed files with 18 additions and 40 deletions

View File

@@ -7,9 +7,8 @@ from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
-from typing import Optional, List, Dict
+from typing import List, Dict
 # Needed to properly setup habana_frameworks
-import text_generation_server.habana_quantization_env as hq_env
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -31,6 +30,7 @@ from text_generation_server.utils.adapter import (
     load_and_merge_adapters,
     AdapterInfo,
 )
+from text_generation_server.adapters.lora import LoraWeights
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

View File

@@ -737,7 +737,7 @@ class CausalLM(Model):
         else:
             if LAZY_MODE == 0:
                 # It is said that "keep_input_mutations" is safe for inference to be done
-                dbg_trace("TORCH COMPILE", f"Torch compiling of model")
+                dbg_trace("TORCH COMPILE", "Torch compiling of model")
                 model.model = torch.compile(
                     model.model,
                     backend="hpu_backend",
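The style change itself only drops the f-prefix from a literal that has no placeholders (the pyflakes/ruff F541 rule). For context on the call the hunk touches, here is a minimal, CPU-runnable sketch of the compile pattern: the toy `nn.Sequential` module and the "inductor" backend are illustrative assumptions; only the Gaudi code path above uses `backend="hpu_backend"` with `options={"keep_input_mutations": True}`.

```python
import torch
import torch.nn as nn

# Toy stand-in for model.model; the real object is a transformers model.
toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

# Stock backend so the sketch runs anywhere; on Gaudi the code in the diff
# passes backend="hpu_backend", options={"keep_input_mutations": True}.
compiled = torch.compile(toy, backend="inductor")

with torch.no_grad():
    print(compiled(torch.randn(4, 16)).shape)  # torch.Size([4, 1])
```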
@@ -932,7 +932,7 @@ class CausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
-        if bypass_hpu_graph != None:
+        if bypass_hpu_graph is not None:
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
         kwargs.update(self.kwargs)
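The `!= None` to `is not None` change follows PEP 8: comparisons to singletons should use identity, because `__eq__` can be overloaded and equality may then lie about whether a value is actually missing. A minimal sketch of the pitfall, with a deliberately permissive class made up for illustration:

```python
class AlwaysEqual:
    # A type with a permissive __eq__; `x != None` is then misleading.
    def __eq__(self, other):
        return True

x = AlwaysEqual()
print(x != None)      # False -> equality claims the value is "None-like"
print(x is not None)  # True  -> identity correctly says an object is present
```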
@@ -1303,7 +1303,7 @@ class CausalLM(Model):
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch])
-        except:
+        except Exception:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
@@ -1331,7 +1331,7 @@ class CausalLM(Model):
             for seq_len in prefill_seqlen_list:
                 batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                 _, prefill_batch, _ = self.generate_token([batch])
-        except:
+        except Exception:
             prefill_batch_size_list.sort()
             prefill_seqlen_list.sort()
             raise RuntimeError(
@@ -1384,7 +1384,7 @@ class CausalLM(Model):
             del decode_batch
             batches.clear()
-        except:
+        except Exception:
             raise RuntimeError(
                 f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
                 f"You need to decrease `--max-batch-total-tokens`"

View File

@@ -14,13 +14,11 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
 import torch
 import torch.utils.checkpoint
-from torch import nn
-from transformers.activations import ACT2FN
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )

View File

@@ -1,5 +1,4 @@
 import inspect
-from loguru import logger
 import torch
 from abc import ABC, abstractmethod
@@ -13,7 +12,6 @@ from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
 from text_generation_server.pb import generate_pb2
-import time
 BASE_MODEL_ADAPTER_ID = "__base_model__"

View File

@@ -1,7 +1,5 @@
-from loguru import logger
 import torch
 from dataclasses import dataclass
-import os
 from typing import List, Optional, Type
 from text_generation_server.models import CausalLM

View File

@@ -5,8 +5,6 @@ import time
 import math
 from PIL import Image
 from io import BytesIO
-import base64
-import numpy
 from opentelemetry import trace
 from loguru import logger
 from typing import Iterable, Optional, Tuple, List, Type, Dict
@@ -15,7 +13,6 @@ import tempfile
 import copy
 from text_generation_server.models import Model
 from transformers import PreTrainedTokenizerBase
-from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
@@ -34,7 +31,6 @@ import text_generation_server.habana_quantization_env as hq_env
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from text_generation_server.utils import (
     HeterogeneousNextTokenChooser,
-    StoppingCriteria,
     make_tokenizer_optional,
     is_tokenizer_transparent,
     pad_next_token_chooser_parameters,
@@ -47,8 +43,6 @@ from optimum.habana.checkpoint_utils import get_ds_injection_policy
 from transformers import (
     AutoTokenizer,
-    AutoModel,
-    PreTrainedTokenizerBase,
     AutoConfig,
 )
 from optimum.habana.checkpoint_utils import (
@@ -59,7 +53,6 @@ from optimum.habana.checkpoint_utils import (
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
-    Batch,
     Tokens,
     Generation,
     GeneratedText,
@@ -116,7 +109,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
-        from loguru import logger
         return "<image>" * num_features
     elif config.model_type == "paligemma":
@@ -604,7 +596,7 @@ class VlmCausalLM(Model):
             if LAZY_MODE == 0:
                 # It is said that "keep_input_mutations" is safe for inference to be done
                 dbg_trace(
-                    "TORCH COMPILE", f'Torch compiling of model')
+                    "TORCH COMPILE", 'Torch compiling of model')
                 model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
             model = hq_env.setup_quantization(model)
@@ -790,7 +782,7 @@ class VlmCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
-        if bypass_hpu_graph != None:
+        if bypass_hpu_graph is not None:
             hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
         kwargs.update(self.kwargs)
@@ -1118,7 +1110,7 @@ class VlmCausalLM(Model):
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-        except:
+        except Exception:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
@@ -1158,7 +1150,7 @@ class VlmCausalLM(Model):
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-        except:
+        except Exception:
             raise RuntimeError(
                 f"Not enough memory to handle following prefill and decode warmup."
                 f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
@@ -1209,7 +1201,7 @@ class VlmCausalLM(Model):
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
                 max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
                 MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
-            except :
+            except Exception:
                 raise RuntimeError(
                     f"Not enough memory to handle batch_size({batch_size}) decode warmup."
                     f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"

View File

@@ -2,7 +2,6 @@
 import asyncio
 import os
-import sys
 import torch
 import time
 import signal

View File

@@ -1,10 +1,8 @@
 import os
 from pathlib import Path
 from loguru import logger
-import sys
 from text_generation_server import server
 import argparse
-from typing import List
 from text_generation_server.utils.adapter import parse_lora_adapters

View File

@@ -1,6 +1,5 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
-import text_generation_server.habana_quantization_env
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
@@ -21,9 +20,6 @@ from text_generation_server.utils.tokens import (
     FinishReason,
     Sampling,
     Greedy,
-    make_tokenizer_optional,
-    is_tokenizer_transparent,
-    pad_next_token_chooser_parameters,
 )
 __all__ = [

View File

@@ -44,9 +44,7 @@ class FakeGroup:
 def initialize_torch_distributed():
-    import habana_frameworks.torch.core as htcore
-    rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     options = None
@@ -69,7 +67,7 @@ def initialize_torch_distributed():
         raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).")
     else:
         try:
-            import oneccl_bindings_for_pytorch
+            import oneccl_bindings_for_pytorch  # noqa: F401
            backend = "ccl"
            if os.getenv("CCL_WORKER_COUNT", None) is None:

View File

@@ -705,8 +705,8 @@ def make_tokenizer_optional(tokenizer):
         ):
             assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer"
             assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer"
-            assert return_token_type_ids == False, "inccorrect input arguments when calling TransparentTokenizer"
-            assert truncation == True, "inccorrect input arguments when calling TransparentTokenizer"
+            assert not return_token_type_ids, "inccorrect input arguments when calling TransparentTokenizer"
+            assert truncation, "inccorrect input arguments when calling TransparentTokenizer"

         def str_token_to_int(i):
             if i == '?':
@@ -727,7 +727,8 @@ def make_tokenizer_optional(tokenizer):
             clean_up_tokenization_spaces: bool = None,
             **kwargs,
         ) -> str:
-            return ','.join(str(i) for i in to_py_obj(token_ids))
+            # I don't think this method is used anywhere and should be removed when doing refactoring
+            return ','.join(str(i) for i in to_py_obj(token_ids))  # noqa: F821

     import os
     if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":