Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-23 16:02:10 +00:00
Non local file.
This commit is contained in:
parent c3d12ae2d4
commit c126ca01d9
@@ -36,6 +36,7 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 # from safetensors.torch import load_file
 from safetensors import safe_open
+from huggingface_hub import hf_hub_download

 HAS_BITS_AND_BYTES = True
 try:
@@ -185,7 +186,8 @@ class FastLinear(nn.Linear):
         else:
             raise ValueError("Need a specific class of Linear (TensorParallel, or regular Linear)")

-        with safe_open("/home/ubuntu/src/GPTQ-for-LLaMa/oasst-4bit-128g.safetensors", framework="pt", device=f"cuda:{rank}") as f:
+        filename = hf_hub_download("Narsil/oasst-gptq", filename="oasst-4bit-128g.safetensors")
+        with safe_open(filename, framework="pt", device=f"cuda:{rank}") as f:
             if name == 'self_attn.query_key_value':
                 query_name = f'model.layers.{layer}.self_attn'
                 self.qlinear.qweight[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.qweight")
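For context, the change swaps a hard-coded local path for a checkpoint fetched from the Hugging Face Hub. Below is a minimal, standalone sketch of that pattern using only public APIs (hf_hub_download plus safetensors' safe_open); the tensor name and the "cpu" device are illustrative assumptions rather than values from the commit, which targets f"cuda:{rank}" and uses the repo's own get_slice helper.

# Minimal sketch of the pattern this commit introduces: download the checkpoint
# from the Hub instead of reading a hard-coded local path, then open it lazily.
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Fetch (or reuse from the local cache) the quantized checkpoint from the Hub.
filename = hf_hub_download("Narsil/oasst-gptq", filename="oasst-4bit-128g.safetensors")

# Open the file without loading every tensor into memory. "cpu" is an assumption
# for a quick local check; the commit itself opens on f"cuda:{rank}".
with safe_open(filename, framework="pt", device="cpu") as f:
    # Tensor name is hypothetical; list the real keys with f.keys().
    qweight = f.get_tensor("model.layers.0.self_attn.q_proj.qweight")
    print(qweight.shape, qweight.dtype)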