This commit is contained in:
OlivierDehaene 2023-03-23 17:51:30 +01:00
parent d199c71a32
commit a87468ad86
3 changed files with 18 additions and 21 deletions

View File

@@ -14,6 +14,7 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     FLASH_NEOX = torch.cuda.is_available()
 except ImportError:
     FLASH_NEOX = False
@@ -48,7 +49,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:

View File

@@ -84,10 +84,9 @@ class FlashNeoXBatch(Batch):
         # Parse batch
         for r in pb.requests:
-            tokenized_input = (
-                tokenizer(r.inputs, return_tensors="pt")["input_ids"]
-                .squeeze(0)
-            )
+            tokenized_input = tokenizer(r.inputs, return_tensors="pt")[
+                "input_ids"
+            ].squeeze(0)

             input_ids.append(tokenized_input)
             all_input_ids.append(tokenized_input.tolist())
@@ -96,9 +95,7 @@ class FlashNeoXBatch(Batch):
             input_lengths.append(input_length)

             # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32)
-            )
+            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))

             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -188,7 +185,6 @@ class FlashNeoXBatch(Batch):
             stopping_criterias=stopping_criterias,
         )

     def __len__(self):
         return len(self.requests)
@@ -319,9 +315,7 @@ class FlashNeoX(Model):
                 logits = out[i].unsqueeze(0)

             # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids, logits
-            )
+            next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
             # Copy to cpu to avoid other copies when indexing and calling .item()
             next_token_id = next_token_id.to("cpu", non_blocking=True)
             logprobs = logprobs.to("cpu")
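The next_token_chooser itself sits outside this hunk; as a hedged sketch of the pattern being reformatted here, with greedy argmax standing in for the real chooser (which also handles sampling, warpers, and penalties), followed by the single CPU copy the comment describes:

import torch

def greedy_next_token(all_input_ids, logits):
    # Stand-in for next_token_chooser: pick the argmax of the last position's
    # logits and return the log-probabilities alongside it.
    logprobs = torch.log_softmax(logits[-1, :], dim=-1)
    next_token_id = torch.argmax(logprobs, dim=-1, keepdim=True)
    return next_token_id, logprobs

logits = torch.randn(1, 10)  # toy logits over a 10-token vocabulary
next_token_id, logprobs = greedy_next_token([3, 7, 1], logits)

# As in the hunk: move both tensors to CPU once so that later indexing and
# .item() calls do not each trigger another device-to-host copy.
next_token_id = next_token_id.to("cpu", non_blocking=True)
logprobs = logprobs.to("cpu")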
@@ -435,9 +429,7 @@ class FlashNeoX(Model):
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32
         )
-        next_batch_cu_seqlens = torch.tensor(
-            next_batch_cu_seqlens, dtype=torch.int32
-        )
+        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
         if len(next_batch_keep_indices) > 1:
             next_batch_input_ids = torch.concat(next_batch_input_ids)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)

View File

@@ -39,7 +39,8 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            assert (len(input_ids) >= 1
-            ), "requires at least a 1 token prefix sequence to seed rng"
+            assert (
+                len(input_ids) >= 1
+            ), "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1]
         else:
@@ -52,15 +53,16 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng.manual_seed(self.hash_key * prev_token)

     def _get_greenlist_ids(
-        self, input_ids: Union[List[int], torch.LongTensor], max_value: int, device: torch.device
+        self,
+        input_ids: Union[List[int], torch.LongTensor],
+        max_value: int,
+        device: torch.device,
     ) -> List[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)

         greenlist_size = int(max_value * self.gamma)
-        vocab_permutation = torch.randperm(
-            max_value, device=device, generator=self.rng
-        )
+        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
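For context on what _get_greenlist_ids does (the reformatting above does not change it): the generator is reseeded deterministically from the previous token, the vocabulary is permuted, and the first gamma fraction of the permutation becomes the "greenlist". A standalone sketch, with illustrative gamma and hash_key values rather than the repository's defaults:

import torch

gamma, hash_key = 0.25, 15485863  # assumed values, for illustration only
rng = torch.Generator(device="cpu")

def greenlist_ids(prev_token: int, vocab_size: int) -> torch.Tensor:
    # Reseeding from the previous token makes the greenlist deterministic
    # for a given prefix, which is what lets a detector recompute it later.
    rng.manual_seed(hash_key * prev_token)
    greenlist_size = int(vocab_size * gamma)
    vocab_permutation = torch.randperm(vocab_size, generator=rng)
    return vocab_permutation[:greenlist_size]

print(greenlist_ids(prev_token=42, vocab_size=50_000)[:5])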
@@ -83,7 +85,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def __call__(
         self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-        greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device)
+        greenlist_ids = self._get_greenlist_ids(
+            input_ids, scores.shape[-1], scores.device
+        )
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
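_calc_greenlist_mask and the final bias step fall outside this hunk; as a simplified sketch of what a watermark processor typically does with the mask afterwards (apply_greenlist_bias and the delta value are assumptions for illustration, not taken from the diff):

import torch

def apply_greenlist_bias(
    scores: torch.FloatTensor, greenlist: torch.Tensor, delta: float = 2.0
) -> torch.FloatTensor:
    # Boolean mask over the vocabulary that is True for greenlist token ids.
    green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
    green_tokens_mask[..., greenlist] = True
    # Add a fixed bias to greenlist logits so sampling favours those tokens.
    return scores + delta * green_tokens_mask

scores = torch.randn(1, 10)
biased = apply_greenlist_bias(scores, torch.tensor([1, 4, 7]))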