from typing import Optional, List

import torch

from text_generation_server.models import CausalLM

# Fill-in-the-middle (FIM) control tokens for SantaCoder. These must be the
# literal token strings: an empty string cannot serve as a special token and
# would make all four constants indistinguishable from one another.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
# End-of-document token; also registered below as the padding token.
EOD = "<|endoftext|>"


class SantaCoder(CausalLM):
    """Causal LM for SantaCoder checkpoints.

    Behaves like ``CausalLM``, but registers the end-of-document and FIM
    control tokens as special tokens on the tokenizer, uses ``EOD`` as the
    padding token, and keeps special tokens when decoding.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the model via ``CausalLM`` and register SantaCoder's tokens.

        Args:
            model_id: Model identifier forwarded to ``CausalLM``.
            revision: Optional model revision forwarded to ``CausalLM``.
            use_medusa: Optional Medusa setting forwarded to ``CausalLM``.
            dtype: Optional torch dtype forwarded to ``CausalLM``.
            trust_remote_code: Forwarded to ``CausalLM``.
        """
        super().__init__(
            model_id=model_id,
            revision=revision,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
        # NOTE(review): assumes CausalLM.__init__ sets self.tokenizer — confirm.
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    EOD,
                    FIM_PREFIX,
                    FIM_MIDDLE,
                    FIM_SUFFIX,
                    FIM_PAD,
                ],
                "pad_token": EOD,
            }
        )

    def decode(self, generated_ids: List[int]) -> str:
        """Decode ``generated_ids`` to text, keeping special tokens.

        Special tokens are not skipped because they are used for custom
        parsing rules of the generated text. Tokenization-space cleanup is
        disabled so the decoded text is returned unaltered.

        Args:
            generated_ids: Token ids to decode.

        Returns:
            The decoded string.
        """
        return self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )