From c126ca01d9bd4044f65331a02dc994dbfc3cff41 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Thu, 4 May 2023 13:22:42 +0000
Subject: [PATCH] Non local file.

---
 .../models/custom_modeling/flash_llama_modeling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b9440052..440cd6a9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -36,6 +36,7 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 # from safetensors.torch import load_file
 from safetensors import safe_open
+from huggingface_hub import hf_hub_download
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -185,7 +186,8 @@ class FastLinear(nn.Linear):
         else:
             raise ValueError("Need a specific class of Linear (TensorParallel, or regular Linear)")
 
-        with safe_open("/home/ubuntu/src/GPTQ-for-LLaMa/oasst-4bit-128g.safetensors", framework="pt", device=f"cuda:{rank}") as f:
+        filename = hf_hub_download("Narsil/oasst-gptq", filename="oasst-4bit-128g.safetensors")
+        with safe_open(filename, framework="pt", device=f"cuda:{rank}") as f:
             if name == 'self_attn.query_key_value':
                 query_name = f'model.layers.{layer}.self_attn'
                 self.qlinear.qweight[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.qweight")
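
Note (not part of the patch): a minimal standalone sketch of the pattern the hunks above introduce, namely fetching the GPTQ checkpoint from the Hugging Face Hub instead of reading a hardcoded local path, then opening it with safetensors. The repo id and filename come from the patch; the tensor name and layer index below are illustrative, following the model.layers.{layer}.self_attn.* layout referenced in the second hunk.

    from huggingface_hub import hf_hub_download
    from safetensors import safe_open

    # hf_hub_download fetches the file from the "Narsil/oasst-gptq" repo,
    # caches it locally, and returns the resolved local path.
    filename = hf_hub_download("Narsil/oasst-gptq", filename="oasst-4bit-128g.safetensors")

    # safe_open reads tensors lazily, placing them directly on the requested device.
    with safe_open(filename, framework="pt", device="cuda:0") as f:
        # Illustrative tensor name (layer 0's q_proj quantized weight); the real
        # loader in the patch slices per-layer tensors into the fused qkv buffer.
        qweight = f.get_tensor("model.layers.0.self_attn.q_proj.qweight")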