From 4dc99bd5961c1549b835d62035ec3d41a0db2dba Mon Sep 17 00:00:00 2001
From: Phil Wang <lucidrains@gmail.com>
Date: Mon, 2 Dec 2024 09:48:44 -0800
Subject: [PATCH] just tempt some student into trying it

---
 README.md                       |  9 +++++++++
 alphafold3_pytorch/attention.py | 19 +++++++++++++++++++
 pyproject.toml                  |  2 +-
 3 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index a91001ef..0f3db20c 100644
--- a/README.md
+++ b/README.md
@@ -514,3 +514,12 @@ docker run -v .:/data --gpus all -it af3
     url = {https://api.semanticscholar.org/CorpusID:273532030}
 }
 ```
+
+```bibtex
+@inproceedings{Duvvuri2024LASERAW,
+    title = {LASER: Attention with Exponential Transformation},
+    author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273849947}
+}
+```
diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index 57921371..d0f15621 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -40,6 +40,9 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def softclamp(t, value):
     return (t / value).tanh() * value
 
@@ -181,6 +184,7 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
+        laser = False,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False
@@ -222,6 +226,10 @@ def __init__(
             self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
             nn.init.normal_(self.memory_kv, std = 0.02)
 
+        # laser attention
+
+        self.laser = laser
+
         # gating of value
         # allows attention to attend to nothing
 
@@ -262,6 +270,12 @@ def forward(
 
         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
 
+        # maybe laser
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True)
+            v = (v - v_max).exp()
+
         # attention
 
         out = self.attend(
@@ -272,6 +286,11 @@ def forward(
             memory_kv = self.memory_kv
         )
 
+        # maybe laser
+
+        if self.laser:
+            out = log(out) + v_max
+
         # merge heads
 
         out = self.merge_heads(out)
diff --git a/pyproject.toml b/pyproject.toml
index 2084cd59..428da8d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.6.8"
+version = "0.6.9"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" },
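
For anyone reading the LASER change straight out of the diff: the new `laser` flag exponentiates the values before the existing attention call and takes the log of the result afterwards, subtracting the per-sequence max first and adding it back at the end so the `exp` cannot overflow (the same idea as the log-sum-exp trick). Below is a minimal standalone sketch of that transformation in plain PyTorch; the function name `laser_attention`, the use of `F.scaled_dot_product_attention`, and the tensor shapes are illustrative assumptions, not the repo's actual `Attention.attend` code path.

```python
import torch
import torch.nn.functional as F

def laser_attention(q, k, v, eps = 1e-20):
    # sketch only: q, k, v assumed to be (batch, heads, seq, dim_head)

    # subtract the max over the sequence dimension so exp() stays in a safe range
    v_max = v.amax(dim = -2, keepdim = True)
    v = (v - v_max).exp()

    # ordinary scaled dot-product attention over the exponentially transformed values
    out = F.scaled_dot_product_attention(q, k, v)

    # map back to log space and restore the subtracted max
    return out.clamp(min = eps).log() + v_max

q = k = v = torch.randn(1, 8, 16, 64)
out = laser_attention(q, k, v)  # (1, 8, 16, 64)
```

Since the softmax weights and the exponentiated values are both non-negative, `out` stays non-negative, and the `clamp(min = eps)` before the log (mirroring the patch's `log` helper) only guards against exact zeros.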