black

Commit a87468ad86 (parent d199c71a32)
Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00

The one-word commit message refers to the black Python formatter: the diff below reformats the server code to black's style (88-character line limit), joining multi-line calls that now fit on one line and splitting lines that exceed the limit.
@@ -14,6 +14,7 @@ from text_generation_server.models.t5 import T5Sharded

 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+
     FLASH_NEOX = torch.cuda.is_available()
 except ImportError:
     FLASH_NEOX = False
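For context, the try/except above is an optional-dependency flag: the flash_neox module needs custom CUDA kernels that may not be installed, and even when they are, the fast path only makes sense with a GPU. A minimal runnable sketch of the same pattern, using the module names from the diff (on a machine without the package the import simply fails and the flag stays off):

import torch

try:
    # Custom-kernel model; this import fails where the kernels are not built.
    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded

    # Even with the kernels present, only enable the path on a CUDA machine.
    FLASH_NEOX = torch.cuda.is_available()
except ImportError:
    FLASH_NEOX = False

print(f"flash neox enabled: {FLASH_NEOX}")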
@@ -48,7 +49,7 @@ torch.set_grad_enabled(False)


 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
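The surrounding get_model function dispatches on the model id; this hunk only reflows its signature. A hedged sketch of that dispatch shape, where the returned strings are illustrative placeholders rather than the repository's actual model classes:

from typing import Optional

def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> str:
    # Special-cased checkpoints get dedicated classes; strings stand in
    # for those classes in this sketch.
    if "facebook/galactica" in model_id:
        return "GalacticaSharded" if sharded else "Galactica"
    return "CausalLM"  # assumed generic fallback

print(get_model("facebook/galactica-120b", None, True, False))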
@@ -84,10 +84,9 @@ class FlashNeoXBatch(Batch):

         # Parse batch
         for r in pb.requests:
-            tokenized_input = (
-                tokenizer(r.inputs, return_tensors="pt")["input_ids"]
-                .squeeze(0)
-            )
+            tokenized_input = tokenizer(r.inputs, return_tensors="pt")[
+                "input_ids"
+            ].squeeze(0)
             input_ids.append(tokenized_input)
             all_input_ids.append(tokenized_input.tolist())

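The reformatted statement does one thing per request: tokenize the input text and drop the batch dimension so each sequence is stored flat. A runnable sketch, using gpt2 purely as a stand-in tokenizer checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint

# Same expression as the hunk: index out input_ids, then squeeze away the
# batch dimension so the sequence is a 1-D tensor.
tokenized_input = tokenizer("Hello world", return_tensors="pt")["input_ids"].squeeze(0)
print(tokenized_input.shape)  # torch.Size([2]): batch dim removed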
@@ -96,9 +95,7 @@ class FlashNeoXBatch(Batch):
             input_lengths.append(input_length)

             # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32)
-            )
+            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))

             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
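position_ids and cu_seqlens exist because flash attention consumes packed (unpadded) batches: all sequences are concatenated into one tensor, and cu_seqlens records where each sequence starts and ends. A self-contained sketch of the bookkeeping this loop performs, with made-up lengths:

import torch

input_lengths = [3, 5, 2]  # example per-request prompt lengths (assumed)

position_ids, cu_seqlens = [], [0]
cumulative_length = 0
for input_length in input_lengths:
    # Fresh 0..n-1 positions per sequence, as in the hunk.
    position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
    # Add cumulative lengths of all previous inputs.
    cu_seqlens.append(cumulative_length + input_length)
    cumulative_length += input_length

print(cu_seqlens)  # [0, 3, 8, 10]: sequence i occupies cu_seqlens[i]:cu_seqlens[i+1]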
@@ -188,7 +185,6 @@ class FlashNeoXBatch(Batch):
             stopping_criterias=stopping_criterias,
         )

-
     def __len__(self):
         return len(self.requests)

@@ -319,9 +315,7 @@ class FlashNeoX(Model):
             logits = out[i].unsqueeze(0)

             # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids, logits
-            )
+            next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
             # Copy to cpu to avoid other copies when indexing and calling .item()
             next_token_id = next_token_id.to("cpu", non_blocking=True)
             logprobs = logprobs.to("cpu")
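The joined line samples the next token; the lines after it are the performance-relevant part: one explicit device-to-host copy replaces the implicit copy that each later indexing or .item() call on a GPU tensor would otherwise trigger. A sketch of just that pattern:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
next_token_id = torch.tensor([42], device=device)  # stand-in sampled token id

# Copy to cpu once to avoid other copies when indexing and calling .item().
next_token_id = next_token_id.to("cpu", non_blocking=True)
print(next_token_id.item())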
@@ -435,9 +429,7 @@ class FlashNeoX(Model):
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32
         )
-        next_batch_cu_seqlens = torch.tensor(
-            next_batch_cu_seqlens, dtype=torch.int32
-        )
+        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
         if len(next_batch_keep_indices) > 1:
             next_batch_input_ids = torch.concat(next_batch_input_ids)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
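This block rebuilds batch-level tensors after finished requests are filtered out. Note the guard: concatenation only happens when more than one request survives, since a single surviving request can keep its tensors as-is. A sketch of that guard with toy tensors:

import torch

# Per-request pieces that survived filtering (toy values).
next_batch_input_ids = [torch.tensor([11]), torch.tensor([12])]
next_batch_keep_indices = [0, 2]

if len(next_batch_keep_indices) > 1:
    # Flatten the per-request ids back into one packed tensor.
    next_batch_input_ids = torch.concat(next_batch_input_ids)
print(next_batch_input_ids)  # tensor([11, 12])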
@@ -39,7 +39,8 @@ class WatermarkLogitsProcessor(LogitsProcessor):

     def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            assert (len(input_ids) >= 1
+            assert (
+                len(input_ids) >= 1
             ), "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1]
         else:
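_seed_rng is what makes the watermark deterministic: before every greenlist draw, the generator is reseeded from the last prefix token, so the same prefix always yields the same green tokens. A runnable sketch (the hash_key value here is illustrative, not necessarily this file's constant):

import torch

hash_key = 15485863  # illustrative multiplier; the real constant may differ
rng = torch.Generator(device="cpu")

input_ids = [464, 3290]  # example prefix token ids
assert len(input_ids) >= 1, "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1]
rng.manual_seed(hash_key * prev_token)  # deterministic given the prefix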
@@ -52,15 +53,16 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng.manual_seed(self.hash_key * prev_token)

     def _get_greenlist_ids(
-        self, input_ids: Union[List[int], torch.LongTensor], max_value: int, device: torch.device
+        self,
+        input_ids: Union[List[int], torch.LongTensor],
+        max_value: int,
+        device: torch.device,
     ) -> List[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)

         greenlist_size = int(max_value * self.gamma)
-        vocab_permutation = torch.randperm(
-            max_value, device=device, generator=self.rng
-        )
+        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids

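_get_greenlist_ids then draws the greenlist itself: permute the whole vocabulary with the seeded generator and keep the first gamma fraction. A self-contained sketch with a toy vocabulary (the gamma value is assumed; it is self.gamma in the real class):

import torch

gamma = 0.5     # assumed green fraction
max_value = 10  # toy vocabulary size
rng = torch.Generator(device="cpu")
rng.manual_seed(15485863 * 3290)  # as seeded from the prefix above

greenlist_size = int(max_value * gamma)
vocab_permutation = torch.randperm(max_value, device="cpu", generator=rng)
greenlist_ids = vocab_permutation[:greenlist_size]
print(greenlist_ids)  # 5 token ids deemed "green" for this prefix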
@@ -83,7 +85,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def __call__(
         self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-        greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device)
+        greenlist_ids = self._get_greenlist_ids(
+            input_ids, scores.shape[-1], scores.device
+        )
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
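__call__ ties the pieces together: greenlist ids come from the prefix and the current vocabulary size, then a mask biases those logits upward. The mask-and-bias step below is a hedged reconstruction consistent with the watermarking scheme, not necessarily this file's exact _calc_greenlist_mask code, and delta is an assumed bias strength:

import torch

scores = torch.zeros(1, 10)              # [batch, vocab] stand-in logits
greenlist_ids = torch.tensor([1, 4, 7])  # as if returned by _get_greenlist_ids
delta = 2.0                              # assumed watermark bias strength

green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
green_tokens_mask[:, greenlist_ids] = True
scores = torch.where(green_tokens_mask, scores + delta, scores)
print(scores)  # green token logits raised by delta, others untouched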