From 6b6dec9ea1499f144ae7b450b8cf528f47d5ea40 Mon Sep 17 00:00:00 2001 From: jkaniecki <153085639+jkaniecki@users.noreply.github.com> Date: Wed, 21 Feb 2024 14:24:41 +0100 Subject: [PATCH] Transparent tokenizer uses explicit int32 (#31) (#60) Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com> --- server/text_generation_server/utils/tokens.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index d82a7f80..6d95af94 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -385,8 +385,8 @@ def make_tokenizer_optional(tokenizer): return int(i) all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')] for inner_text in text] - return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length-len(tokens)) + tokens for tokens in all_tokens]), - "attention_mask": torch.tensor([[0] * (max_length-len(tokens)) + [1]*len(tokens) for tokens in all_tokens])} + return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length-len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32), + "attention_mask": torch.tensor([[0] * (max_length-len(tokens)) + [1]*len(tokens) for tokens in all_tokens], dtype=torch.int32)} def decode( self,