fix truncation

OlivierDehaene 2023-04-09 09:55:05 +02:00
parent 146e0e27ce
commit 82464709d3
11 changed files with 12 additions and 10 deletions
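
The change is the same in every model loader: the tokenizer is now created with truncation_side="left" in addition to the existing padding_side="left", so a prompt that exceeds the length limit loses tokens from the start rather than the end, keeping the most recent context for generation. A minimal sketch of the resulting behavior (the model id "gpt2" and the max_length value are illustrative only, not taken from this commit):

    from transformers import AutoTokenizer

    # Both sides set to "left", matching the loaders changed below.
    tokenizer = AutoTokenizer.from_pretrained(
        "gpt2", padding_side="left", truncation_side="left"
    )

    # With truncation_side="left", truncation keeps the *end* of the prompt,
    # which is what a left-padded causal model needs to continue from.
    encoded = tokenizer("one two three four five", truncation=True, max_length=3)
    print(tokenizer.decode(encoded["input_ids"]))  # trailing tokens survive, e.g. " three four five"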


@@ -68,7 +68,7 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(


@@ -303,7 +303,7 @@ class CausalLM(Model):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,


@@ -224,7 +224,7 @@ class FlashCausalLM(Model):
             raise NotImplementedError("FlashCausalLM does not support quantization")
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         self.model = (
             model_cls.from_pretrained(


@@ -42,6 +42,7 @@ class FlashLlama(FlashCausalLM):
             model_id,
             revision=revision,
             padding_side="left",
+            truncation_side="left",
         )
         config = AutoConfig.from_pretrained(
@@ -160,6 +161,7 @@ class FlashLlamaSharded(FlashLlama):
             model_id,
             revision=revision,
             padding_side="left",
+            truncation_side="left",
         )
         config = AutoConfig.from_pretrained(


@@ -45,7 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX does not support quantization")
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(


@@ -33,7 +33,7 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder does not support quantization")
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(


@@ -198,7 +198,7 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(


@@ -44,7 +44,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token


@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.add_special_tokens(
             {


@@ -349,7 +349,7 @@ class Seq2SeqLM(Model):
             load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id


@@ -44,7 +44,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         config = AutoConfig.from_pretrained(