diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index d82a7f80..6d95af94 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -385,8 +385,8 @@ def make_tokenizer_optional(tokenizer):
                     return int(i)
             all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')] for inner_text in text]
-            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length-len(tokens)) + tokens for tokens in all_tokens]),
-                    "attention_mask": torch.tensor([[0] * (max_length-len(tokens)) + [1]*len(tokens) for tokens in all_tokens])}
+            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length-len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32),
+                    "attention_mask": torch.tensor([[0] * (max_length-len(tokens)) + [1]*len(tokens) for tokens in all_tokens], dtype=torch.int32)}

         def decode(
             self,
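
For context, a minimal standalone sketch of the left-padding behavior this hunk touches; the pad_token_id, max_length, and token values below are illustrative assumptions, not values from the repository. Passing dtype=torch.int32 pins both tensors to 32-bit integers instead of the int64 default that torch.tensor infers for Python integer lists.

# Minimal sketch of the left-padding and attention-mask construction changed above.
# pad_token_id=0, max_length=6, and the token ids are assumed for illustration only.
import torch

pad_token_id = 0
max_length = 6
all_tokens = [[101, 7592, 102], [101, 102]]  # pre-tokenized inputs of varying length

batch = {
    # Left-pad each sequence to max_length; dtype=torch.int32 keeps the tensor
    # in 32-bit integers rather than the default torch.int64.
    "input_ids": torch.tensor(
        [[pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens],
        dtype=torch.int32,
    ),
    # Zeros over the padding positions, ones over the real tokens.
    "attention_mask": torch.tensor(
        [[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens],
        dtype=torch.int32,
    ),
}

print(batch["input_ids"].dtype)   # torch.int32
print(batch["attention_mask"])    # tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1]], dtype=torch.int32)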