diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 0dfd8078..368060a0 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -19,8 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder, FlashSantacoderSharded
-
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )
 
     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -83,7 +85,9 @@ def get_model(
     if "bigcode" in model_id:
         if sharded:
             if not FLASH_ATTENTION:
-                raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder"))
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
             return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 1a961027..00c91ffe 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -88,10 +88,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -104,6 +105,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -153,7 +155,7 @@ class BLOOMSharded(BLOOM):
                             f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
 
                    if quantize:
                        if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index aba538ea..8679826b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -209,7 +209,9 @@ class FlashMQAttention(torch.nn.Module):
             self.num_heads = self.num_heads // process_group.size()
             self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
             self.c_proj = TensorParallelRowLinear(
-                hidden_size, hidden_size, process_group=process_group,
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
             )
 
     def forward(
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index a5f688c0..a8b38465 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -64,11 +64,12 @@ class FlashNeoXSharded(FlashNeoX):
             model,
             filenames,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class FlashNeoXSharded(FlashNeoX):
         model,
         filenames: List[str],
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -138,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
 
                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index fd9de934..1f16f37f 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -58,11 +58,7 @@ class FlashSantacoder(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config)
 
         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
-            config.architectures[0].startswith("GPT2")
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
         self.model = model.eval()
 
@@ -77,7 +73,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
-        transpose: bool
+        transpose: bool,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
@@ -179,7 +175,9 @@ class FlashSantacoderSharded(FlashSantacoder):
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
         if quantize:
-            raise NotImplementedError("FlashSantacoderSharded does not support quantization")
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
@@ -247,7 +245,9 @@ class FlashSantacoderSharded(FlashSantacoder):
                         block_size = size // world_size
                         start = rank * block_size
                         stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
                     elif isinstance(module, TensorParallelRowLinear):
                         if param_name == "weight":
                             dim = 0 if transpose else 1
@@ -255,7 +255,11 @@ class FlashSantacoderSharded(FlashSantacoder):
                             block_size = size // world_size
                             start = rank * block_size
                             stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
                         else:
                             tensor = slice_[:]
                             # XXX: Hack for Rowlinear to add the bias only once.
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 396cc4f6..dc78aa8b 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                             f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
 
                     if quantize:
                         if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index fb109ed7..489615e1 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
 
                     if quantize:
                         if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 85f0ac8c..8e5527c0 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                             f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
 
                     if quantize:
                         if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 5266eb8d..b9f77015 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
 
                     if quantize:
                         if not HAS_BITS_AND_BYTES:
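Every hunk above applies the same pattern: the target dtype is threaded into each sharded load_weights implementation and each tensor is cast with .contiguous().to(dtype) as it is copied into the model, instead of loading the whole model in the checkpoint dtype and converting it afterwards with model.eval().to(dtype). Below is a minimal sketch of that pattern with hypothetical names (load_weights_cast and MyModel are not part of the repository); it assumes plain torch.load checkpoints whose keys match the model's module names.

# Sketch only: hypothetical helper illustrating per-tensor dtype casting at load time.
from typing import List

import torch


def load_weights_cast(
    model: torch.nn.Module,
    filenames: List[str],
    device: torch.device,
    dtype: torch.dtype,
):
    for filename in filenames:
        # Load each shard on CPU first, as the FlashSantacoder loader in this diff does.
        state_dict = torch.load(filename, map_location="cpu")
        for name, tensor in state_dict.items():
            # Cast one tensor at a time, so peak memory stays close to the
            # final (e.g. float16) model size rather than the checkpoint dtype.
            tensor = tensor.contiguous().to(device=device, dtype=dtype)
            module_name, param_name = name.rsplit(".", 1)
            module = model.get_submodule(module_name)
            # Direct assignment into _parameters mirrors how the sharded loaders
            # in this diff attach the loaded tensors.
            module._parameters[param_name] = tensor


# Usage sketch: the assembled model is only switched to eval mode, never cast as a whole.
# model = MyModel(config)
# load_weights_cast(model, filenames, device=torch.device("cuda:0"), dtype=torch.float16)
# model = model.eval()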