From 9b062483953c5dec1f32bb43684bfd88d19b2306 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:37:00 +0200 Subject: [PATCH] fix --- server/text_generation/models/__init__.py | 91 ------------------- .../text_generation_server/models/__init__.py | 26 ++++-- server/text_generation_server/models/bloom.py | 2 +- .../models/galactica.py | 10 +- .../text_generation_server/models/gpt_neox.py | 2 +- server/text_generation_server/models/opt.py | 8 +- server/text_generation_server/models/t5.py | 2 +- server/text_generation_server/utils/hub.py | 5 - 8 files changed, 30 insertions(+), 116 deletions(-) delete mode 100644 server/text_generation/models/__init__.py diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py deleted file mode 100644 index 5b89d275..00000000 --- a/server/text_generation/models/__init__.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch - -from transformers import AutoConfig -from typing import Optional - -from text_generation.models.model import Model -from text_generation.models.causal_lm import CausalLM -from text_generation.models.bloom import BLOOM, BLOOMSharded -from text_generation.models.seq2seq_lm import Seq2SeqLM -from text_generation.models.galactica import Galactica, GalacticaSharded -from text_generation.models.santacoder import SantaCoder -from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded -from text_generation.models.opt import OPT, OPTSharded -from text_generation.models.t5 import T5Sharded - -__all__ = [ - "Model", - "BLOOM", - "BLOOMSharded", - "CausalLM", - "Galactica", - "GalacticaSharded", - "GPTNeox", - "GPTNeoxSharded", - "Seq2SeqLM", - "Galactica", - "GalacticaSharded", - "SantaCoder", - "GPTNeox", - "GPTNeoxSharded", - "OPT", - "OPTSharded", - "T5Sharded", - "get_model", -] - -# The flag below controls whether to allow TF32 on matmul. This flag defaults to False -# in PyTorch 1.12 and later. -torch.backends.cuda.matmul.allow_tf32 = True - -# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
-torch.backends.cudnn.allow_tf32 = True - -# Disable gradients -torch.set_grad_enabled(False) - - -def get_model( - model_id: str, revision: Optional[str], sharded: bool, quantize: bool -) -> Model: - if model_id.startswith("facebook/galactica"): - if sharded: - return GalacticaSharded(model_id, revision, quantize=quantize) - else: - return Galactica(model_id, revision, quantize=quantize) - - if "santacoder" in model_id: - return SantaCoder(model_id, revision, quantize) - - config = AutoConfig.from_pretrained(model_id, revision=revision) - - if config.model_type == "bloom": - if sharded: - return BLOOMSharded(model_id, revision, quantize=quantize) - else: - return BLOOM(model_id, revision, quantize=quantize) - - if config.model_type == "gpt_neox": - if sharded: - return GPTNeoxSharded(model_id, revision, quantize=quantize) - else: - return GPTNeox(model_id, revision, quantize=quantize) - - if config.model_type == "t5": - if sharded: - return T5Sharded(model_id, revision, quantize=quantize) - else: - return Seq2SeqLM(model_id, revision, quantize=quantize) - - if config.model_type == "opt": - if sharded: - return OPTSharded(model_id, revision, quantize=quantize) - else: - return OPT(model_id, revision, quantize=quantize) - - if sharded: - raise ValueError("sharded is not supported for AutoModel") - try: - return CausalLM(model_id, revision, quantize=quantize) - except Exception: - return Seq2SeqLM(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1e06b6dc..c04ae117 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,4 +1,3 @@ -import os import torch from loguru import logger @@ -11,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM +from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.gpt_neox import GPTNeoxSharded @@ -36,7 +36,11 @@ __all__ = [ "GalacticaSharded", "GPTNeoxSharded", "Seq2SeqLM", + "Galactica", + "GalacticaSharded", "SantaCoder", + "OPT", + "OPTSharded", "T5Sharded", "get_model", ] @@ -48,9 +52,11 @@ if FLASH_ATTENTION: __all__.append(FlashLlama) __all__.append(FlashLlamaSharded) -FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \ - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \ - "or install flash attention with `cd server && make install install-flash-attention`" +FLASH_ATT_ERROR_MESSAGE = ( + "{} requires Flash Attention CUDA kernels to be installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" +) # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. 
@@ -64,7 +70,7 @@ torch.set_grad_enabled(False) def get_model( - model_id: str, revision: Optional[str], sharded: bool, quantize: bool + model_id: str, revision: Optional[str], sharded: bool, quantize: bool ) -> Model: if "facebook/galactica" in model_id: if sharded: @@ -100,13 +106,17 @@ def get_model( if sharded: if FLASH_ATTENTION: return FlashLlamaSharded(model_id, revision, quantize=quantize) - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama") - ) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")) else: llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM return llama_cls(model_id, revision, quantize=quantize) + if config.model_type == "opt": + if sharded: + return OPTSharded(model_id, revision, quantize=quantize) + else: + return OPT(model_id, revision, quantize=quantize) + if model_type == "t5": if sharded: return T5Sharded(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 5a6a9c0d..1a961027 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -62,7 +62,7 @@ class BLOOMSharded(BLOOM): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 0022a50d..58daee0b 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import ( TensorParallelRowLinear, ) -from text_generation.models import CausalLMBatch -from text_generation.pb import generate_pb2 -from text_generation.models.opt import OPT, OPTSharded -from text_generation.utils import ( +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.opt import OPT, OPTSharded +from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, initialize_torch_distributed, @@ -192,7 +192,7 @@ class GalacticaSharded(OPTSharded): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index b81976da..fb109ed7 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 3569e77c..85f0ac8c 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -16,8 +16,8 @@ from transformers.models.opt.parallel_layers import ( TensorParallelRowLinear, ) -from 
text_generation.models import CausalLM -from text_generation.utils import ( +from text_generation_server.models import CausalLM +from text_generation_server.utils import ( initialize_torch_distributed, weight_files, ) @@ -54,13 +54,13 @@ class OPTSharded(OPT): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 tokenizer = AutoTokenizer.from_pretrained( - model_id, revision=revision, padding_side="left" + model_id, revision=revision, padding_side="left", truncation_side="left" ) config = AutoConfig.from_pretrained( diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 300b376e..5266eb8d 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index d338fb29..4feec8a1 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -50,7 +50,6 @@ def try_to_load_from_cache( refs_dir = repo_cache / "refs" snapshots_dir = repo_cache / "snapshots" - no_exist_dir = repo_cache / ".no_exist" # Resolve refs (for instance to convert main to the associated commit sha) if refs_dir.is_dir(): @@ -59,10 +58,6 @@ def try_to_load_from_cache( with revision_file.open() as f: revision = f.read() - # Check if file is cached as "no_exist" - if (no_exist_dir / revision / filename).is_file(): - return None - # Check if revision folder exists if not snapshots_dir.exists(): return None
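
The dtype change repeated above in bloom.py, galactica.py, gpt_neox.py, opt.py and t5.py follows a single pattern: prefer bfloat16 on CUDA only when the device actually supports it, and fall back to float32 otherwise (CPU, or GPUs without bf16 support). A minimal sketch of that pattern in isolation, assuming a hypothetical helper name pick_dtype that is not part of this patch:

import torch

def pick_dtype() -> torch.dtype:
    # Mirror the sharded models above: use bfloat16 only when CUDA is
    # available and the device reports bf16 support, otherwise float32.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32

With the "opt" branch restored in get_model, a call such as
get_model("facebook/opt-6.7b", revision=None, sharded=True, quantize=False)
(the model id is only an example) should now dispatch to OPTSharded, or to OPT
when sharded is False, instead of falling through to the generic AutoModel path.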