fix truncation

OlivierDehaene 2023-04-09 09:55:05 +02:00
parent 146e0e27ce
commit 82464709d3
11 changed files with 12 additions and 10 deletions
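
The change is identical across all eleven tokenizer loaders: truncation_side="left" is passed alongside the existing padding_side="left", so that when an over-long prompt is truncated, the oldest tokens are dropped rather than the most recent ones next to the generation point. A minimal sketch of the behavior difference, assuming the "gpt2" tokenizer purely for illustration (nothing below is taken from this commit):

from transformers import AutoTokenizer

text = "one two three four five six seven eight"

# Default: truncation_side="right" keeps the start of the text.
right = AutoTokenizer.from_pretrained("gpt2")
# This commit's setting: truncation_side="left" keeps the end of the text.
left = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

ids_right = right(text, truncation=True, max_length=4)["input_ids"]
ids_left = left(text, truncation=True, max_length=4)["input_ids"]

print(right.decode(ids_right))  # expected: the beginning ("one two three four")
print(left.decode(ids_left))    # expected: the end ("five six seven eight")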

@@ -68,7 +68,7 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(

@@ -303,7 +303,7 @@ class CausalLM(Model):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,

@@ -224,7 +224,7 @@ class FlashCausalLM(Model):
             raise NotImplementedError("FlashCausalLM does not support quantization")
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         self.model = (
             model_cls.from_pretrained(

@@ -42,6 +42,7 @@ class FlashLlama(FlashCausalLM):
             model_id,
             revision=revision,
             padding_side="left",
+            truncation_side="left",
         )
         config = AutoConfig.from_pretrained(
@@ -160,6 +161,7 @@ class FlashLlamaSharded(FlashLlama):
             model_id,
             revision=revision,
             padding_side="left",
+            truncation_side="left",
         )
         config = AutoConfig.from_pretrained(

@@ -45,7 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX does not support quantization")
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(

@@ -33,7 +33,7 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder does not support quantization")
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(

@@ -198,7 +198,7 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(

@@ -44,7 +44,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token

@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.add_special_tokens(
             {

@@ -349,7 +349,7 @@ class Seq2SeqLM(Model):
             load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id

@@ -44,7 +44,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
         config = AutoConfig.from_pretrained(
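
Note that truncation_side only takes effect where the tokenizer is later invoked with truncation enabled; setting it at load time just changes which side that call cuts from. A hedged, self-contained sketch of such a batched call site (the model name, prompts, and length limit below are illustrative, not taken from this repository):

from transformers import AutoTokenizer

# Illustrative stand-in for a served model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2", padding_side="left", truncation_side="left"
)
tokenizer.pad_token = tokenizer.eos_token

prompts = ["a very long prompt " * 50, "short prompt"]
encoded = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,      # left padding right-aligns prompts for generation
    truncation=True,   # this is what activates truncation_side="left"
    max_length=16,     # illustrative limit
)
print(encoded["input_ids"].shape)  # torch.Size([2, 16])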