Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

commit a87468ad86
parent d199c71a32

    black
server/text_generation_server/models/__init__.py

@@ -14,6 +14,7 @@ from text_generation_server.models.t5 import T5Sharded
 
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+
     FLASH_NEOX = torch.cuda.is_available()
 except ImportError:
     FLASH_NEOX = False
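The hunk above guards an optional fast path: the flash-attention NeoX model only imports when the custom CUDA kernels are installed. A minimal, self-contained sketch of the same guarded-import pattern (it runs even without the package, since the except branch absorbs the failure):

import torch

try:
    # Optional fast path: importable only when the custom CUDA kernels are built.
    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded

    # Flash attention also needs a GPU, so the flag checks CUDA availability.
    FLASH_NEOX = torch.cuda.is_available()
except ImportError:
    # CPU-only install or missing extension: disable the fast path entirely.
    FLASH_NEOX = False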
@@ -48,7 +49,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
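get_model can then dispatch on that flag. This is only a hedged sketch of the idea, continuing from the snippet above; GPTNeoxSharded and CausalLM stand in for whatever fallback classes the repository actually uses:

from typing import Optional

def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
):
    # Illustrative dispatch only: the real function inspects the model config,
    # not just the model id string.
    if "neox" in model_id.lower():
        if sharded:
            cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
        else:
            cls = FlashNeoX if FLASH_NEOX else CausalLM
        return cls(model_id, revision, quantize=quantize)
    raise ValueError(f"Unsupported model: {model_id}")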
server/text_generation_server/models/flash_neox.py

@@ -84,10 +84,9 @@ class FlashNeoXBatch(Batch):
 
         # Parse batch
         for r in pb.requests:
-            tokenized_input = (
-                tokenizer(r.inputs, return_tensors="pt")["input_ids"]
-                .squeeze(0)
-            )
+            tokenized_input = tokenizer(r.inputs, return_tensors="pt")[
+                "input_ids"
+            ].squeeze(0)
             input_ids.append(tokenized_input)
             all_input_ids.append(tokenized_input.tolist())
 
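The reformatted call also shows why .squeeze(0) is there: Hugging Face tokenizers return a batch dimension of size one for a single string, while the batch stores each request as a flat 1-D tensor. A quick illustration (the checkpoint name is just an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
enc = tokenizer("Hello world", return_tensors="pt")
print(enc["input_ids"].shape)             # e.g. torch.Size([1, 2]), batched
print(enc["input_ids"].squeeze(0).shape)  # e.g. torch.Size([2]), per-request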
@@ -96,9 +95,7 @@ class FlashNeoXBatch(Batch):
             input_lengths.append(input_length)
 
             # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32)
-            )
+            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
 
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
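cu_seqlens is what lets flash-attention-style kernels work on a ragged batch: all requests are packed into one token dimension, and the cumulative lengths mark where each request starts and ends. A small self-contained illustration with made-up lengths:

import torch

input_lengths = [3, 5, 2]                 # three requests of different lengths
cu_seqlens, cumulative_length = [0], 0
for input_length in input_lengths:
    cumulative_length += input_length
    cu_seqlens.append(cumulative_length)  # request i spans cu_seqlens[i]:cu_seqlens[i+1]
print(torch.tensor(cu_seqlens, dtype=torch.int32))  # tensor([ 0,  3,  8, 10], ...)

# Per-request position ids restart at zero, matching the torch.arange above.
position_ids = [torch.arange(0, n, dtype=torch.int32) for n in input_lengths]
print(torch.cat(position_ids))            # tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1], ...)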
@@ -188,7 +185,6 @@ class FlashNeoXBatch(Batch):
             stopping_criterias=stopping_criterias,
         )
 
-
     def __len__(self):
         return len(self.requests)
 
@@ -319,9 +315,7 @@ class FlashNeoX(Model):
             logits = out[i].unsqueeze(0)
 
             # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids, logits
-            )
+            next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
             # Copy to cpu to avoid other copies when indexing and calling .item()
             next_token_id = next_token_id.to("cpu", non_blocking=True)
             logprobs = logprobs.to("cpu")
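The comment in this hunk is about avoiding repeated device synchronization: indexing a CUDA tensor or calling .item() on it forces a transfer each time, so the code moves the whole tensor to the CPU once and reads from the copy. A rough illustration of the pattern (CPU tensors used so it runs anywhere):

import torch

logits = torch.randn(4, 50432)            # stand-in for one decode step's output
next_token_id = logits.argmax(dim=-1)
next_token_id = next_token_id.to("cpu")   # on GPU this would be the single transfer
tokens = [next_token_id[i].item() for i in range(len(next_token_id))]  # cheap host reads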
@@ -435,9 +429,7 @@ class FlashNeoX(Model):
             next_batch_position_ids = torch.tensor(
                 next_batch_position_ids, dtype=torch.int32
             )
-            next_batch_cu_seqlens = torch.tensor(
-                next_batch_cu_seqlens, dtype=torch.int32
-            )
+            next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
             if len(next_batch_keep_indices) > 1:
                 next_batch_input_ids = torch.concat(next_batch_input_ids)
                 next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
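When some requests in the batch finish, the survivors' tensors are filtered and re-packed; the past key values for the packed layout are concatenated along the token dimension (dim=1 here). A shape-only sketch with purely illustrative dimensions:

import torch

# Hypothetical per-request KV caches: (2 for key/value, tokens, hidden).
past_per_request = [torch.zeros(2, 4, 8), torch.zeros(2, 6, 8)]
next_batch_past_key_values = torch.concat(past_per_request, dim=1)
print(next_batch_past_key_values.shape)   # torch.Size([2, 10, 8])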
server/text_generation_server/utils/watermark.py

@@ -39,7 +39,8 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            assert (len(input_ids) >= 1
+            assert (
+                len(input_ids) >= 1
             ), "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1]
         else:
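_seed_rng accepts either a plain list of token ids or a tensor; the hunk shows only the list branch. A minimal sketch of the whole helper, where the tensor branch is an assumption inferred from the list branch:

from typing import List, Union
import torch

def seed_rng(rng: torch.Generator, hash_key: int,
             input_ids: Union[List[int], torch.LongTensor]) -> None:
    if isinstance(input_ids, list):
        assert (
            len(input_ids) >= 1
        ), "requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1]
    else:
        # Assumed symmetric handling for the tensor case (not shown in the diff).
        prev_token = int(input_ids[-1])
    # Seeding from the last prefix token keeps generation and detection in sync.
    rng.manual_seed(hash_key * prev_token)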
@@ -52,15 +53,16 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng.manual_seed(self.hash_key * prev_token)
 
     def _get_greenlist_ids(
-        self, input_ids: Union[List[int], torch.LongTensor], max_value: int, device: torch.device
+        self,
+        input_ids: Union[List[int], torch.LongTensor],
+        max_value: int,
+        device: torch.device,
     ) -> List[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
 
         greenlist_size = int(max_value * self.gamma)
-        vocab_permutation = torch.randperm(
-            max_value, device=device, generator=self.rng
-        )
+        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
 
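_get_greenlist_ids implements the greenlist step of Kirchenbauer et al.'s watermarking scheme: reseed the generator from the prefix, shuffle the vocabulary, and keep the first gamma fraction. A self-contained sketch; the gamma and hash_key defaults here are illustrative, not necessarily the repository's:

import torch

def greenlist_ids(prev_token: int, vocab_size: int,
                  gamma: float = 0.5, hash_key: int = 15485863) -> torch.Tensor:
    rng = torch.Generator(device="cpu")
    rng.manual_seed(hash_key * prev_token)                # deterministic per prefix token
    vocab_permutation = torch.randperm(vocab_size, generator=rng)
    return vocab_permutation[: int(vocab_size * gamma)]   # the "green" subset

print(greenlist_ids(prev_token=42, vocab_size=50432).shape)  # torch.Size([25216])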
@@ -83,7 +85,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def __call__(
         self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-        greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device)
+        greenlist_ids = self._get_greenlist_ids(
+            input_ids, scores.shape[-1], scores.device
+        )
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
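__call__ ends by turning the greenlist into a mask over the scores; the usual final step of the scheme (not visible in this hunk) adds a positive bias delta to the green logits so sampling favors them. A hedged sketch of that application:

import torch

def apply_watermark_bias(scores: torch.FloatTensor,
                         green_tokens_mask: torch.BoolTensor,
                         delta: float = 2.0) -> torch.FloatTensor:
    # Boost only the green tokens; delta sets watermark strength (value illustrative).
    scores[green_tokens_mask] = scores[green_tokens_mask] + delta
    return scores

scores = torch.zeros(1, 8)
mask = torch.zeros(1, 8, dtype=torch.bool)
mask[0, :4] = True                         # pretend the first half is green
print(apply_watermark_bias(scores, mask))  # first four logits raised by delta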