diff --git a/.gitignore b/.gitignore
index 19604d426..a4f9823a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea
 target
 router/tokenizer.json
+server/flash-attention
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 440cd6a9b..2f259b674 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -1,4 +1,5 @@
 # coding=utf-8
+
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -112,6 +113,9 @@ class FastLinear(nn.Linear):
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
         self.quantized = False
+        self.weight = self.weight.to(device="meta")
+        if bias:
+            self.bias = self.bias.to(device="meta")
         self.qlinear = None
 
     def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
@@ -154,8 +158,12 @@ class FastLinear(nn.Linear):
                 outfeatures=self.out_features,
                 bias=bool(self.bias),
             )
-            rank = int(os.getenv("RANK"))
-            world_size = int(os.getenv("WORLD_SIZE"))
+            try:
+                rank = int(os.getenv("RANK"))
+                world_size = int(os.getenv("WORLD_SIZE"))
+            except:
+                rank = 0
+                world_size = 1
 
             def get_row_slice(f, name):
                 slice_ = f.get_slice(name)
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 50c311a82..3ac458e1c 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -49,12 +49,13 @@ class FlashLlama(FlashCausalLM):
         )
 
         # We do not use from_pretrained as we modified the model internal module layout
-        try:
-            filenames = weight_files(model_id, revision, ".bin")
-        # Local files not found
-        except LocalEntryNotFoundError:
-            hub_files = weight_hub_files(model_id, revision, ".bin")
-            filenames = download_weights(hub_files, model_id, revision)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        # try:
+        #     filenames = weight_files(model_id, revision, ".bin")
+        #     # Local files not found
+        # except LocalEntryNotFoundError:
+        #     hub_files = weight_hub_files(model_id, revision, ".bin")
+        #     filenames = download_weights(hub_files, model_id, revision)
 
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
@@ -78,66 +79,77 @@ class FlashLlama(FlashCausalLM):
         dtype: torch.dtype,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, map_location="cpu")
-            for key, value in state_dict.items():
-                value = value.to(device if not quantize else "cpu").to(dtype)
+            with safe_open(
+                filename, framework="pt", device=str(device)
+            ) as f:
 
-                layer_name = ".".join(key.split(".")[:4])
+                for key in f.keys():
+                    # tmp
+                    if "_proj" in key:
+                        continue
+                    value = f.get_tensor(key)
+                    value = value.to(device if not quantize else "cpu").to(dtype)
 
-                # Fused qkv
-                if "q_proj" in key or "k_proj" in key or "v_proj" in key:
-                    final_key = layer_name + ".query_key_value.weight"
+                    layer_name = ".".join(key.split(".")[:4])
 
-                # Fused gate and up projs
-                elif "gate_proj" in key or "up_proj" in key:
-                    final_key = layer_name + ".gate_up_proj.weight"
-                else:
-                    final_key = key
+                    # Fused qkv
+                    if "q_proj" in key or "k_proj" in key or "v_proj" in key:
+                        final_key = layer_name + ".query_key_value.weight"
 
-                module_name, param_name = final_key.rsplit(".", 1)
-                module = model.get_submodule(module_name)
-
-                try:
-                    current_parameter_tensor = module._parameters[param_name]
-                except KeyError:
-                    current_parameter_tensor = None
-
-                if current_parameter_tensor is not None:
-                    if current_parameter_tensor.device == torch.device("meta"):
-                        # Init qkv
-                        if "query_key_value" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (value.shape[0] * 3, value.shape[1])
-                            )
-                        # Init gate and up proj
-                        elif "gate_up_proj" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (value.shape[0] * 2, value.shape[1])
-                            )
-
-                    # Copy to correct slice
-                    if "q_proj" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "k_proj" in key:
-                        module._parameters[param_name][
-                            value.shape[0] : value.shape[0] * 2
-                        ] = value
-                    elif "v_proj" in key:
-                        module._parameters[param_name][value.shape[0] * 2 :] = value
-                    elif "gate_proj" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "up_proj" in key:
-                        module._parameters[param_name][value.shape[0] :] = value
+                    # Fused gate and up projs
+                    elif "gate_proj" in key or "up_proj" in key:
+                        final_key = layer_name + ".gate_up_proj.weight"
                     else:
-                        if current_parameter_tensor.shape != value.shape:
-                            raise ValueError(
-                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
-                            )
-                        module._parameters[param_name] = value
-                else:
-                    module._buffers[param_name] = value
+                        final_key = key
 
-                del value
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "query_key_value" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (value.shape[0] * 3, value.shape[1])
+                                )
+                            # Init gate and up proj
+                            elif "gate_up_proj" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (value.shape[0] * 2, value.shape[1])
+                                )
+
+                        # Copy to correct slice
+                        if "q_proj" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "k_proj" in key:
+                            module._parameters[param_name][
+                                value.shape[0] : value.shape[0] * 2
+                            ] = value
+                        elif "v_proj" in key:
+                            module._parameters[param_name][value.shape[0] * 2 :] = value
+                        elif "gate_proj" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "up_proj" in key:
+                            module._parameters[param_name][value.shape[0] :] = value
+                        else:
+                            if current_parameter_tensor.shape != value.shape:
+                                raise ValueError(
+                                    f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
+                                )
+                            module._parameters[param_name] = value
+                    else:
+                        module._buffers[param_name] = value
+
+                    del value
+
+
+        torch.cuda.empty_cache()
+        model.post_load_weights(quantize)
 
         uninitialized_parameters = []
         for n, p in model.named_parameters():
@@ -148,9 +160,6 @@ class FlashLlama(FlashCausalLM):
                 f"found uninitialized parameters in model: {uninitialized_parameters}"
             )
 
-        torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
-
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
@@ -214,11 +223,14 @@ class FlashLlamaSharded(FlashLlama):
         rank: int,
         world_size: int,
     ):
+
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device)
            ) as f:
                 for name in f.keys():
+                    if "_proj" in name:
+                        continue
                     slice_ = f.get_slice(name)
                     layer_name = ".".join(name.split(".")[:4])
@@ -312,6 +324,10 @@ class FlashLlamaSharded(FlashLlama):
                     else:
                         module._buffers[param_name] = tensor
+
+        torch.cuda.empty_cache()
+        model.post_load_weights(quantize)
+
         uninitialized_parameters = []
         for n, p in model.named_parameters():
             if p.data.device == torch.device("meta"):
                 uninitialized_parameters.append(n)
@@ -320,6 +336,3 @@
             raise RuntimeError(
                 f"found uninitialized parameters in model: {uninitialized_parameters}"
             )
-
-        torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
diff --git a/server/text_generation_server/quant/quant_linear.py b/server/text_generation_server/quant/quant_linear.py
index cdfe010f3..714c9296a 100644
--- a/server/text_generation_server/quant/quant_linear.py
+++ b/server/text_generation_server/quant/quant_linear.py
@@ -318,7 +318,7 @@ class QuantLinear(nn.Module):
         self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
         self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
         if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16).cuda())
         else:
             self.bias = None