mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:02:13 +00:00
48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
|
import torch
|
||
|
from dataclasses import dataclass
|
||
|
from typing import List, Optional, Type
|
||
|
|
||
|
from text_generation_server.models import CausalLM
|
||
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
||
|
|
||
|
|
||
|
@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
    """Causal-LM batch whose KV cache entries are fused key/value tensors.

    Each element of ``past_key_values`` is a single tensor in which the key
    and the value are concatenated along the last dimension.  The two
    methods below convert between that fused layout and separate lists of
    key tensors and value tensors.
    """

    # One fused key/value tensor per cache entry (presumably one per layer --
    # TODO confirm against the model implementation).
    past_key_values: Optional[List[torch.Tensor]]

    def detach_kv_cache(self):
        """Split the fused cache into keys and values and drop it from the batch.

        The last dimension of every cached tensor is cut into two equal
        halves: the first half is treated as the key, the second as the
        value.  The half size is taken from the first cache entry and
        reused for all of them.

        Returns:
            A ``(past_keys, past_values)`` tuple of two lists, in the same
            order as the entries of ``past_key_values``.

        Note:
            The ``past_key_values`` attribute is deleted from the instance
            once the split has succeeded; ``attach_kv_cache`` sets it again.
        """
        half = self.past_key_values[0].size(dim=-1) // 2

        # Split each fused tensor once and keep both halves together.
        halves = [
            fused.split((half, half), dim=-1) for fused in self.past_key_values
        ]
        del self.past_key_values

        past_keys = [key for key, _ in halves]
        past_values = [value for _, value in halves]
        return past_keys, past_values

    def attach_kv_cache(self, past_keys, past_values):
        """Rebuild ``past_key_values`` from separate key and value tensors.

        Keys and values are paired positionally and each pair is
        concatenated along the last dimension, producing the fused layout
        that ``detach_kv_cache`` takes apart.

        Args:
            past_keys: Iterable of key tensors.
            past_values: Iterable of value tensors, paired with ``past_keys``
                by position.
        """
        self.past_key_values = [
            torch.cat(pair, dim=-1) for pair in zip(past_keys, past_values)
        ]
|
||
|
|
||
|
|
||
|
class StarCoder(CausalLM):
    """``CausalLM`` variant that uses ``StarCoderCausalLMBatch`` as its batch type."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Forward all arguments unchanged to the ``CausalLM`` constructor.

        Args:
            model_id: Passed through as ``model_id``.
            revision: Passed through as ``revision``; defaults to ``None``.
            dtype: Passed through as ``dtype``; defaults to ``None``.
        """
        super().__init__(model_id=model_id, revision=revision, dtype=dtype)

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Return the batch class associated with this model."""
        return StarCoderCausalLMBatch
|