Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)
fix: revert non snapshot changes

parent 68854d11ef
commit b5f61e92b5
@@ -26,6 +26,3 @@ install-selective-scan: install-causal-conv1d build-selective-scan
 	cd mamba && pip install .
 
 build-all: build-causal-conv1d build-selective-scan
-
-install-ssm: install-causal-conv1d install-selective-scan
-	@echo "Selective scan model installed"
@@ -609,16 +609,15 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        device_map = (
-            "auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None
-        )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=device_map,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
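
For reference, the expression inlined above enables multi-GPU sharding only when more than one CUDA device is visible. A minimal standalone sketch of the same pattern, assuming transformers and accelerate are installed; the "gpt2" checkpoint and float16 dtype are placeholders, not the values TGI passes:

    import torch
    from transformers import AutoModelForCausalLM

    # Use accelerate's "auto" placement to shard across GPUs only when more
    # than one CUDA device is visible; otherwise device_map=None loads the
    # model on CPU and leaves placement to the caller.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder checkpoint
        torch_dtype=torch.float16,  # placeholder dtype
        device_map=(
            "auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None
        ),
    )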
@@ -629,11 +628,6 @@ class CausalLM(Model):
         ):
             model = model.cuda()
 
-        # if device_map is "auto", it's unclear which device the model is on
-        # therefore, we need to get the device the model is on after loading
-        if device_map is not None:
-            device = next(model.parameters()).device
-
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
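
The lines removed above implement a device-recovery idiom: with device_map="auto", weight placement is only decided at load time, so the model's device is read back from a loaded parameter afterwards. A minimal sketch of that idiom under the same assumptions (transformers plus accelerate installed; "gpt2" is a placeholder checkpoint):

    import torch
    from transformers import AutoModelForCausalLM

    # With device_map="auto", accelerate decides placement while loading,
    # so query a parameter afterwards to learn where inputs must be sent.
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
    device = next(model.parameters()).device  # e.g. cuda:0, or cpu
    inputs = torch.tensor([[0]]).to(device)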

@@ -639,28 +639,21 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
 
-        device_map = (
-            "auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None
-        )
-
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=device_map,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
 
-        # if device_map is "auto", it's unclear which device the model is on
-        # therefore, we need to get the device the model is on after loading
-        if device_map is not None:
-            device = next(model.parameters()).device
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,