# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
#
# MIT License
#
# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
Generic interface to various configurations of the Perceiver Resampler, which simply takes in a series of (potentially
time-indexed) contextual embeddings and "resamples" (compresses) them down to a pre-specified number of latents. Note
that the Perceiver in general resamples based solely on the *long-range* context; there's a nice opportunity here to
prime the Perceiver Resampler with, say, a single layer's worth of language embeddings (the target domain), and use
that to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.

References:
    - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
    - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)

EPS = 1e-5


class IdeficsPerceiverResampler(nn.Module):
    def __init__(
        self,
        prefix,
        config,
        embed_dim: int,
        depth: int,
        n_heads: int,
        head_dim: int,
        n_latents: int,
        weights,
    ) -> None:
        """
        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet, ViT, or
        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed set of `n_latents` latents,
        and returns a Tensor of shape [bsz, n_latents, embed_dim].

        Args:
            config (`IdeficsConfig`): config object
            embed_dim (`int`):
                Dimensionality of the embeddings being fed to the Perceiver Resampler (also the dimensionality of the
                latent embeddings *returned* by it). Could be e.g. the ViT embed_dim, the ResNet pool dim, and so on.
            depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
            n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
            head_dim (`int`): Dimensionality of each head projection in the Transformer block.
            n_latents (`int`):
                Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
        """
        super().__init__()
        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
            embed_dim,
            n_heads,
            head_dim,
            n_latents,
        )
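        # Whether to LayerNorm queries and keys inside each cross-attention block (see IdeficsPerceiverAttention).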
        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver

        # Create Latents for Perceiver
        self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents"))
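
        # Hidden size of each block's MLP: the standard 4x transformer expansion, based on the vision tower's
        # embed_dim when the vision config provides one.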
        self.intermediate_dim = (
            self.embed_dim * 4
            if not hasattr(config.vision_config, "embed_dim")
            else config.vision_config.embed_dim * 4
        )
        # Create Transformer Blocks
        self.blocks = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        IdeficsPerceiverAttention(
                            prefix=f"{prefix}.blocks.{layer_id}.0",
                            config=config,
                            embed_dim=self.embed_dim,
                            n_heads=self.n_heads,
                            head_dim=self.head_dim,
                            qk_layer_norms=self.qk_layer_norms,
                            weights=weights,
                        ),
                        IdeficsMLP(
                            prefix=f"{prefix}.blocks.{layer_id}.1",
                            intermediate_size=self.intermediate_dim,
                            config=config,
                            weights=weights,
                        ),
                    ]
                )
                for layer_id in range(depth)
            ]
        )
        self.layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
        )

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
        # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
        latents = self.latents.repeat(context.shape[0], 1, 1)

        # Feed through Perceiver Attention blocks...
        for attn, ff in self.blocks:
            latents = attn(context, latents) + latents
            latents = ff(latents) + latents

        return self.layer_norm(latents)


class IdeficsPerceiverAttention(nn.Module):
    def __init__(
        self,
        prefix,
        config,
        embed_dim: int,
        n_heads: int,
        head_dim: int,
        qk_layer_norms: bool,
        weights,
    ) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__()
        self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
        self.qk_layer_norms = qk_layer_norms
        # Normalization & Scaling
        self.context_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
        )
        self.latents_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
        )
        if self.qk_layer_norms:
            self.q_layer_norm = nn.LayerNorm.load(
                prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
            )
            self.k_layer_norm = nn.LayerNorm.load(
                prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
            )
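
        # Standard attention scaling: queries are multiplied by 1 / sqrt(head_dim) before the dot product with keys.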
        self.qk_scale = self.head_dim**-0.5
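
        # Tensor parallelism: attention heads are split evenly across the process group, so each shard keeps
        # n_heads // world_size heads. The column-parallel Q/K/V projections below are sharded the same way, and the
        # row-parallel output projection recombines the per-shard results.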
        if n_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
        self.n_heads //= weights.process_group.size()

        # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
        self.q_proj = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
        )

        self.output_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False
        )

    def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        """
        Runs Perceiver cross-attention, with the latents concatenated to the context along the `seq` dimension.

        Args:
            context (`torch.Tensor`):
                Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
            latents (`torch.Tensor`):
                Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.

        Returns:
            `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
            from context.
        """
        context = self.context_layer_norm(context)
        latents = self.latents_layer_norm(latents)
        batch_size, seq_length, embed_dim = context.shape[:3]

        # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
        # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
        q = self.q_proj(latents)
        k = self.k_proj(torch.cat([context, latents], dim=-2))
        v = self.v_proj(torch.cat([context, latents], dim=-2))

        # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
        # =>> for each batch element and head, `attn` is a matrix of shape [n_latents x (len(context) + n_latents)]
        # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
        q, k, v = [
            x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
                1, 2
            )
            for x in (q, k, v)
        ]

        if self.qk_layer_norms:
            q = self.q_layer_norm(q)
            k = self.k_layer_norm(k)

        scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
        stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
        attn = stabilized_scores.softmax(dim=-1)

        # Attend & project back to output...
        resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
        # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
        return self.output_proj(resampled.transpose(1, 2).flatten(-2))


class IdeficsMLP(nn.Module):
    def __init__(
        self,
        prefix,
        intermediate_size,
        config,
        weights,
    ):
        """Simple MLP block with intermediate_size and embedding size"""
        super().__init__()
        self.embed_dim = config.vision_config.embed_dim
        self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
        self.fc = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.fc",
            weights=weights,
            bias=False,
        )
        self.act = nn.ReLU()
        self.c_proj = TensorParallelRowLinear.load(
            config=config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        # LayerNorm, up-projection, ReLU, down-projection
        hidden_states = self.ln(hidden_states)
        hidden_states = self.fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)

        return hidden_states
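

# Rough usage sketch (illustrative only, not executed here). The caller, typically the Idefics modeling code in this
# repo, supplies the checkpoint prefix, the IdeficsConfig, and the sharded weights object; the prefix and variable
# names below are placeholders rather than values defined in this file:
#
#     resampler = IdeficsPerceiverResampler(
#         prefix="model.perceiver_resampler",  # hypothetical checkpoint prefix
#         config=config,                       # IdeficsConfig with `perceiver_config` and `vision_config`
#         embed_dim=embed_dim,
#         depth=depth,
#         n_heads=n_heads,
#         head_dim=head_dim,
#         n_latents=n_latents,
#         weights=weights,                     # provides get_tensor(...) and process_group
#     )
#     # image_hidden_states: [bsz, seq, embed_dim]  ->  resampled: [bsz, n_latents, embed_dim]
#     resampled = resampler(image_hidden_states)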