fix: revert non snapshot changes

drbh 2024-07-29 14:05:55 +00:00
parent 68854d11ef
commit b5f61e92b5
3 changed files with 10 additions and 26 deletions


@@ -26,6 +26,3 @@ install-selective-scan: install-causal-conv1d build-selective-scan
 	cd mamba && pip install .
 
 build-all: build-causal-conv1d build-selective-scan
-
-install-ssm: install-causal-conv1d install-selective-scan
-	@echo "Selective scan model installed"


@@ -609,16 +609,15 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        device_map = (
-            "auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None
-        )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=device_map,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -629,11 +628,6 @@ class CausalLM(Model):
         ):
             model = model.cuda()
-
-        # if device_map is "auto", it's unclear which device the model is on
-        # therefore, we need to get the device the model is on after loading
-        if device_map is not None:
-            device = next(model.parameters()).device
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
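
The hunks above restore the inlined device_map conditional in place of the local device_map variable and drop the post-load device lookup. For context, a minimal standalone sketch of the restored loading pattern, assuming the standard transformers API (device_map="auto" additionally requires accelerate); model_id and dtype below are placeholders, not values from this commit:

import torch
from transformers import AutoModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint, not part of this commit
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    # Shard across GPUs only when more than one is visible; otherwise
    # load normally and place the model explicitly below.
    device_map=(
        "auto"
        if torch.cuda.is_available() and torch.cuda.device_count() > 1
        else None
    ),
)

# Single-GPU case: the model was loaded without a device map, so move it.
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
    model = model.cuda()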


@@ -639,28 +639,21 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
 
-        device_map = (
-            "auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None
-        )
-
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=device_map,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
-
-        # if device_map is "auto", it's unclear which device the model is on
-        # therefore, we need to get the device the model is on after loading
-        if device_map is not None:
-            device = next(model.parameters()).device
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
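
The revert removes the same post-load device detection from both CausalLM and Seq2SeqLM. For reference only, a minimal sketch of the idea behind the deleted lines (the helper name is hypothetical and not part of this codebase):

import torch

def model_input_device(model: torch.nn.Module) -> torch.device:
    # With device_map="auto" the weights can be sharded across devices,
    # so inspecting the first parameter is one way to choose where
    # input tensors should be placed.
    return next(model.parameters()).device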