
Commit

feature(xyy): add HPT model and test_hpt
luodi-7 committed Dec 4, 2024
1 parent 188759b commit cbe7dea
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions ding/model/template/hpt.py
@@ -20,7 +20,8 @@ class HPT(nn.Module):
__init__, forward
.. note::
- The model is designed to be flexible and can be adapted for different input dimensions and action spaces.
+ The model is designed to be flexible and can be adapted
+ for different input dimensions and action spaces.
"""

def __init__(self, state_dim: int, action_dim: int):
@@ -33,7 +34,8 @@ def __init__(self, state_dim: int, action_dim: int):
- action_dim (:obj:`int`): The dimension of the action space.
.. note::
- The Policy Stem is initialized with cross-attention, and the Dueling Head is set to process the resulting tokens.
+ The Policy Stem is initialized with cross-attention,
+ and the Dueling Head is set to process the resulting tokens.
"""
super(HPT, self).__init__()
# Initialise Policy Stem
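
The two hunks above spell out the public surface of HPT: construct it with a state dimension and an action dimension, then call forward on a batch of observations. A minimal usage sketch follows, assuming the class is importable from the path shown in this diff; the sizes and the exact structure of the head's output are illustrative, since the forward body is not part of this change.

```python
# Minimal usage sketch (illustrative sizes; assumes torch and DI-engine are installed).
import torch
from ding.model.template.hpt import HPT  # import path inferred from the file name in this diff

model = HPT(state_dim=4, action_dim=2)   # e.g. a CartPole-sized observation/action space
obs = torch.randn(8, 4)                  # batch of 8 observations
out = model(obs)                         # output follows the Dueling Head convention; exact keys are not shown in this diff
```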
@@ -69,13 +71,15 @@ class PolicyStem(nn.Module):
Overview:
The Policy Stem module is responsible for processing input features
and generating latent tokens using a cross-attention mechanism.
- It extracts features from the input and then applies cross-attention to generate a set of latent tokens.
+ It extracts features from the input and then applies cross-attention
+ to generate a set of latent tokens.
Interfaces:
__init__, init_cross_attn, compute_latent, forward
.. note::
- This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction.
+ This module is inspired by the implementation in the Perceiver IO model
+ and uses attention mechanisms for feature extraction.
"""

def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
@@ -85,7 +89,8 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
Arguments:
- feature_dim (:obj:`int`): The dimension of the input features.
- - token_dim (:obj:`int`): The dimension of the latent tokens generated by the attention mechanism.
+ - token_dim (:obj:`int`): The dimension of the latent tokens generated
+   by the attention mechanism.
"""
super().__init__()
# Initialise the feature extraction module
@@ -97,12 +102,14 @@ def init_cross_attn(self):
"""Initialize cross-attention module and learnable tokens."""
token_num = 16
self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
- self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)
+ self.cross_attention = CrossAttention(
+     128, heads=8, dim_head=64, dropout=0.1)
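
The lines above are the heart of the Perceiver IO-style design mentioned in the class note: 16 learnable tokens of width 128 act as queries that cross-attend over however many feature tokens the extractor produces, so the result has a fixed size regardless of the input length. A rough, self-contained illustration of that idea, using torch.nn.MultiheadAttention purely as a stand-in for the CrossAttention class defined later in this file:

```python
# Stand-in illustration: a fixed set of latent tokens summarising a variable-length feature sequence.
import torch
import torch.nn as nn

B, N, D, token_num = 4, 50, 128, 16                          # illustrative sizes
features = torch.randn(B, N, D)                              # (B, N, 128) extracted features
latent = nn.Parameter(torch.randn(1, token_num, D) * 0.02)   # plays the role of self.tokens (0.02 is illustrative, not INIT_CONST)
attn = nn.MultiheadAttention(embed_dim=D, num_heads=8, batch_first=True)  # stand-in, not this file's CrossAttention
queries = latent.expand(B, -1, -1)                           # (B, 16, 128)
out, _ = attn(queries, features, features)                   # (B, 16, 128): fixed-size summary of N inputs
print(out.shape)
```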

def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
- Compute latent representations of the input data using the feature extractor and cross-attention.
+ Compute latent representations of the input data using
+ the feature extractor and cross-attention.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor with shape [B, T, ..., F].
@@ -112,10 +119,12 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
"""
# Using the Feature Extractor
stem_feat = self.feature_extractor(x)
- stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128)
+ stem_feat = stem_feat.reshape(
+     stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128)
# Calculating latent tokens using CrossAttention
stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128)
- stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128)
+ stem_tokens = self.cross_attention(
+     stem_tokens, stem_feat) # (B, 16, 128)
return stem_tokens

def forward(self, x: torch.Tensor) -> torch.Tensor:
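
To make the shapes in compute_latent above concrete: the extracted features are flattened to (B, N, 128), the learnable tokens are tiled across the batch, and cross-attention returns a fixed (B, 16, 128) block. A short walkthrough with made-up sizes; the feature extractor itself is configured in __init__ and is not part of the shown hunk, so a random tensor stands in for its output.

```python
# Shape walkthrough for compute_latent (illustrative sizes: B=4, T=10).
import torch

B, T = 4, 10
stem_feat = torch.randn(B, T, 128)                        # stands in for feature_extractor(x)
stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1])  # (B, N, 128), here N == T
tokens = torch.randn(1, 16, 128)                          # stands in for the learnable self.tokens
stem_tokens = tokens.repeat(len(stem_feat), 1, 1)         # (B, 16, 128)
# self.cross_attention(stem_tokens, stem_feat) then yields (B, 16, 128).
print(stem_feat.shape, stem_tokens.shape)
```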
@@ -165,7 +174,9 @@ class CrossAttention(nn.Module):
dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
"""

- def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
+ def __init__(self, query_dim: int,
+              heads: int = 8,
+              dim_head: int = 64, dropout: float = 0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim
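
With the defaults shown here (heads=8, dim_head=64), inner_dim works out to 512. The query projection maps query_dim to inner_dim, and because forward below splits self.to_kv(context) with .chunk(2, dim=-1), the key/value projection has to produce 2 * inner_dim. The projection layers themselves are outside the shown lines, so the mapping below is inferred bookkeeping rather than quoted code.

```python
# Inferred dimension bookkeeping for the defaults in this hunk.
heads, dim_head = 8, 64
query_dim = 128                         # as passed from init_cross_attn above
inner_dim = dim_head * heads            # 512
context_dim = query_dim                 # 128, as set above
# to_q : query_dim   -> inner_dim       (128 -> 512)
# to_kv: context_dim -> 2 * inner_dim   (128 -> 1024), split later by .chunk(2, dim=-1)
print(inner_dim, 2 * inner_dim)
```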
@@ -178,7 +189,9 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout:

self.dropout = nn.Dropout(dropout)

- def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, x: torch.Tensor,
+             context: torch.Tensor,
+             mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Overview:
Forward pass of the CrossAttention module.
@@ -195,7 +208,8 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T
h = self.heads
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim=-1)
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+ q, k, v = map(lambda t: rearrange(
+     t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

if mask is not None:
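
The shown lines pin down the core of the attention computation: project x to queries and context to keys and values, fold the heads into the batch dimension with einops, and take a scaled dot product. The rest of forward (softmax, weighted sum, unfolding the heads) is not part of this diff, so the sketch below completes it in the standard way under that assumption; the scale is likewise assumed to be dim_head ** -0.5.

```python
# Self-contained sketch of the attention math; the steps after the mask check are assumed, not quoted.
import torch
from einops import rearrange

B, Nq, Nk, heads, dim_head = 2, 16, 50, 8, 64
scale = dim_head ** -0.5                                   # assumed value of self.scale
q = torch.randn(B, Nq, heads * dim_head)                   # would come from self.to_q(x)
k = torch.randn(B, Nk, heads * dim_head)                   # would come from self.to_kv(context).chunk(2, dim=-1)
v = torch.randn(B, Nk, heads * dim_head)
q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=heads) for t in (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale  # (B*heads, Nq, Nk)
attn = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=heads)    # (B, Nq, heads*dim_head)
print(out.shape)                                           # torch.Size([2, 16, 512])
```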
