Non local file.

This commit is contained in:
Ubuntu 2023-05-04 13:22:42 +00:00
parent c3d12ae2d4
commit c126ca01d9

View File

@ -36,6 +36,7 @@ import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
# from safetensors.torch import load_file
from safetensors import safe_open
from huggingface_hub import hf_hub_download
HAS_BITS_AND_BYTES = True
try:
@ -185,7 +186,8 @@ class FastLinear(nn.Linear):
else:
raise ValueError("Need a specific class of Linear (TensorParallel, or regular Linear)")
with safe_open("/home/ubuntu/src/GPTQ-for-LLaMa/oasst-4bit-128g.safetensors", framework="pt", device=f"cuda:{rank}") as f:
filename = hf_hub_download("Narsil/oasst-gptq", filename="oasst-4bit-128g.safetensors")
with safe_open(filename, framework="pt", device=f"cuda:{rank}") as f:
if name == 'self_attn.query_key_value':
query_name = f'model.layers.{layer}.self_attn'
self.qlinear.qweight[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.qweight")