fix imports after rebase

Mohit Sharma 2024-09-27 15:52:43 +00:00
parent 473d9a892d
commit b2cd1b66ed
15 changed files with 18 additions and 21 deletions
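
Note: every file in this commit appears to follow the same pattern. The rebase left the line "from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE" sitting above the torch imports, and the fix moves it down next to the other text_generation_server.layers imports (or into an existing parenthesized import list). A minimal sketch of that pattern, assembled for illustration rather than copied from any single file below:

# before: the stray import sits above the torch imports
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

# after: the import is grouped with the other text_generation_server.layers imports
import torch
import torch.distributed

from text_generation_server.layers import (
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE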

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -40,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -31,6 +30,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
     Seqlen,
+    PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
     FastLinear,

View File

@@ -15,9 +15,6 @@
 from typing import List, Optional, Tuple, Type
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
-from text_generation_server.utils.import_utils import SYSTEM
 import torch
 import torch.distributed
 from torch import nn
@@ -38,9 +35,11 @@ from text_generation_server.layers.attention import (
     paged_attention,
     reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
 if SYSTEM == "rocm":
@@ -390,8 +389,8 @@ class DeepseekV2MLP(nn.Module):
     def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
         if (
             SYSTEM == "rocm"
-            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
+            and hidden_states.dtype == torch.float16
             and hidden_states.shape[0] == 1
             and not self.quantize
         ):
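
Reading off the + lines above, the ROCm fast-path guard in DeepseekV2MLP.forward now checks the activation before the dtype. A standalone sketch of the resulting predicate, with the module attributes passed in as arguments (the helper name and signature are illustrative, not part of the file):

import torch

def use_rocm_fused_silu(hidden_states: torch.Tensor, hidden_act: str,
                        quantize, system: str) -> bool:
    # Clause order as in the updated guard: backend, activation, dtype, batch size, quantization.
    return (
        system == "rocm"
        and hidden_act == "silu"
        and hidden_states.dtype == torch.float16
        and hidden_states.shape[0] == 1
        and not quantize
    )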

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -41,6 +40,7 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -31,6 +30,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
     Seqlen,
+    PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -39,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )

View File

@@ -321,12 +321,12 @@ class LlamaMLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
-            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
+            and hidden_states.dtype == torch.float16
             and hidden_states.shape[0] == 1
-            and not self.quantize
             and self.hidden_size
             != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
@@ -561,7 +561,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
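
The LlamaMLP hunk above makes the same activation/dtype swap and additionally moves the quantization check to the end, after the hidden_size != 16384 workaround. A standalone sketch of the resulting guard (again an illustrative helper, not code from the file):

import torch

def use_rocm_fused_silu_llama(hidden_states: torch.Tensor, hidden_act: str,
                              hidden_size: int, quantize, system: str) -> bool:
    # Mirrors the updated LlamaMLP guard, including the LLMM_Silu workaround for Llama 3.1 405B.
    return (
        system == "rocm"
        and hidden_act == "silu"
        and hidden_states.dtype == torch.float16
        and hidden_states.shape[0] == 1
        and hidden_size != 16384  # TODO workaround: LLMM_Silu kernel fails for LLama3.1 405B
        and not quantize
    )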

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -42,6 +41,7 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -302,7 +302,6 @@ class MistralMLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
-            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from typing import List, Optional, Tuple, Type
 import torch
@@ -40,6 +39,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -40,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )

View File

@@ -1,4 +1,3 @@
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -20,6 +19,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )

View File

@@ -1,4 +1,3 @@
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -18,6 +17,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,

View File

@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
 from torch import nn
@@ -13,6 +12,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (

View File

@@ -1,4 +1,3 @@
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -19,6 +18,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,

View File

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
@@ -40,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,