Removing more dead code.

This commit is contained in:
Nicolas Patry 2024-07-02 16:46:52 +00:00
parent dbf9292afc
commit 24bbd7b822
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
3 changed files with 11 additions and 8 deletions

View File

@ -627,10 +627,11 @@ class CausalLM(Model):
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch return CausalLMBatch
def decode(self, generated_ids: List[int]) -> str: # This is not used anymore
return self.tokenizer.decode( # def decode(self, generated_ids: List[int]) -> str:
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False # return self.tokenizer.decode(
) # generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
# )
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@ -827,6 +827,7 @@ class FlashCausalLM(Model):
aliases=None, aliases=None,
# Used for Santacoder override of config # Used for Santacoder override of config
num_kv_heads=None, num_kv_heads=None,
skip_special_tokens: bool = True,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():

View File

@ -668,10 +668,11 @@ class Seq2SeqLM(Model):
def batch_type(self) -> Type[Seq2SeqLMBatch]: def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str: # Not used anymore
return self.tokenizer.decode( # def decode(self, decoder_ids: List[int]) -> str:
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False # return self.tokenizer.decode(
) # decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
# )
def forward( def forward(
self, self,