Merge pull request #2 from divyashreepathihalli/clip_refactor_sub
Clip refactor sub
divyashreepathihalli authored Feb 9, 2024
2 parents 957b6c8 + 54f02e8 commit 79de15d
Showing 8 changed files with 65 additions and 46 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/actions.yml
@@ -26,7 +26,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
@@ -65,7 +65,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
@@ -110,7 +110,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
2 changes: 1 addition & 1 deletion .github/workflows/nightly.yml
@@ -27,7 +27,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -25,7 +25,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
4 changes: 2 additions & 2 deletions .github/workflows/scorecard.yml
@@ -45,14 +45,14 @@ jobs:
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
- uses: actions/upload-artifact@c7d193f32edcb7bfad88892161225aeda64e9392 # v4.0.0
+ uses: actions/upload-artifact@26f96dfa697d77e81fd5907df203aa23a56210a8 # v4.3.0
with:
name: SARIF file
path: results.sarif
retention-days: 5

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
- uses: github/codeql-action/upload-sarif@012739e5082ff0c22ca6d6ab32e07c36df03c4a4 # v3.22.12
+ uses: github/codeql-action/upload-sarif@b7bf0a3ed3ecfa44160715d7c442788f65f0f923 # v3.23.2
with:
sarif_file: results.sarif
41 changes: 26 additions & 15 deletions keras_cv/models/feature_extractor/clip/clip_encoder.py
@@ -58,14 +58,24 @@ def __init__(
* 0.02
)

- def attention(self, x):
- self.attn_mask = (
+ def attention(self, x, attention_mask=None):
+ mask = (
ops.cast(self.attn_mask, dtype=x.dtype)
if self.attn_mask is not None
else None
)
+ if attention_mask is not None:
+ attention_mask = (
+ ops.cast(attention_mask, dtype=x.dtype)
+ if attention_mask is not None
+ else None
+ )
+ mask = ops.add(self.attn_mask, attention_mask)

- return self.attn(x, attention_mask=self.attn_mask)
+ return self.attn(
+ x,
+ attention_mask=mask,
+ )

def build(self, input_shape):
super().build(input_shape)
@@ -93,8 +103,8 @@ def build(self, input_shape):
)
self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2")

- def call(self, x):
- x = x + self.attention(self.ln_1(x))
+ def call(self, x, attention_mask=None):
+ x = x + self.attention(self.ln_1(x), attention_mask=attention_mask)
x = x + self.mlp(self.ln_2(x))
return x

@@ -109,20 +119,21 @@ def __init__(self, width, layers, heads, attn_mask=None, **kwargs):
self.layers = layers
self.heads = heads
self.attn_mask = attn_mask
- self.resblocks = keras.Sequential(
- [
- ResidualAttention(
- self.width, self.heads, self.layers, self.attn_mask
- )
- for _ in range(self.layers)
- ]
- )
+ self.resblocks = [
+ ResidualAttention(
+ self.width, self.heads, self.layers, self.attn_mask
+ )
+ for _ in range(self.layers)
+ ]

def build(self, input_shape):
super().build(input_shape)
self.resblocks.build()

- def call(self, x):
- return self.resblocks(x)
+ def call(self, x, attention_mask=None):
+ for block in self.resblocks:
+ x = block(x, attention_mask=attention_mask)
+ return x

def compute_output_shape(self, inputs_shape):
return inputs_shape
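Taken together, the clip_encoder.py changes replace the keras.Sequential stack with a plain Python list of ResidualAttention blocks so that a per-call attention_mask can be threaded from CLIPEncoder.call down to each block's attention layer. A minimal, self-contained sketch of that pattern (class names and shapes here are illustrative and not the keras-cv classes; keras.layers.MultiHeadAttention stands in for the attn layer that ResidualAttention defines outside this diff):

import keras
from keras import ops


class TinyBlock(keras.layers.Layer):
    # Stand-in for ResidualAttention: forwards an optional mask to MHA.
    def __init__(self, width, heads, **kwargs):
        super().__init__(**kwargs)
        self.attn = keras.layers.MultiHeadAttention(
            num_heads=heads, key_dim=width // heads
        )

    def call(self, x, attention_mask=None):
        # Self-attention with an optional (batch, query, key) mask.
        return x + self.attn(x, x, attention_mask=attention_mask)


class TinyEncoder(keras.layers.Layer):
    # Stand-in for CLIPEncoder: a list of blocks instead of keras.Sequential,
    # so the per-call attention_mask can be handed to every block.
    def __init__(self, width, heads, num_blocks, **kwargs):
        super().__init__(**kwargs)
        self.blocks = [TinyBlock(width, heads) for _ in range(num_blocks)]

    def call(self, x, attention_mask=None):
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)
        return x


x = ops.ones((2, 5, 16))
mask = ops.ones((2, 5, 5))  # 1 = attend, 0 = do not attend
print(TinyEncoder(16, 4, 2)(x, attention_mask=mask).shape)  # (2, 5, 16)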
16 changes: 6 additions & 10 deletions keras_cv/models/feature_extractor/clip/clip_model.py
@@ -146,21 +146,17 @@ def __init__(
self.image_embeddings = None
self.text_embeddings = None

- def build_attention_mask(self):
- mask = ops.ones((self.context_length, self.context_length))
- # Zero out the lower diagonal
- mask = ops.triu(mask)
- return ops.cast(mask, "float32")
-
def encode_images(self, image):
return self.image_encoder(image)

- def encode_text(self, text):
- return self.text_encoder(text)
+ def encode_text(self, text, attention_mask=None):
+ return self.text_encoder(text, attention_mask=attention_mask)

- def call(self, image, text):
+ def call(self, image, text, attention_mask=None):
self.image_embeddings = self.encode_images(image)
- self.text_embeddings = self.encode_text(text)
+ self.text_embeddings = self.encode_text(
+ text, attention_mask=attention_mask
+ )
normalize_image_features = keras.ops.sqrt(
keras.ops.sum(
keras.ops.power(self.image_embeddings, 2), keepdims=True
3 changes: 1 addition & 2 deletions keras_cv/models/feature_extractor/clip/clip_processor.py
@@ -109,12 +109,11 @@ def process_texts(self, texts, context_length: int = 77):
texts = [texts]

def pack_tokens(text):
- tok, _ = self.packer(
+ return self.packer(
self.tokenizer(text),
sequence_length=context_length,
add_start_value=True,
add_end_value=True,
)
- return tok

return pack_tokens(texts)
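With pack_tokens now returning the packer's full output rather than just the token ids (the old code discarded the second element as `_`, presumably a padding mask), the caller has what it needs to build the attention_mask argument that clip_model.py and clip_text_model.py now accept. A small sketch of deriving such a mask from padded token ids, with the ids and pad value of 0 invented for illustration:

from keras import ops

token_ids = ops.convert_to_tensor([[49406, 320, 1125, 49407, 0, 0]])  # hypothetical ids
attention_mask = ops.cast(ops.not_equal(token_ids, 0), "float32")  # 1 = real token, 0 = pad
print(attention_mask)  # [[1. 1. 1. 1. 0. 0.]]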
37 changes: 25 additions & 12 deletions keras_cv/models/feature_extractor/clip/clip_text_model.py
@@ -25,15 +25,20 @@ def __init__(
)

self.vocab_size = vocab_size
- self.positional_embedding = self.add_weight(
- shape=[self.context_length, transformer_width],
+ self.positional_embedding = keras.layers.Embedding(
+ self.context_length,
+ transformer_width,
name="positional_embedding",
)
+ mask = ops.ones((self.context_length, self.context_length))
+ # Zero out the lower diagonal
+ mask = ops.triu(mask)
+ mask = ops.cast(mask, "float32")
self.encoder = CLIPEncoder(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
- attn_mask=self.build_attention_mask(),
+ attn_mask=mask,
name="clip_encoder",
)
self.ln_final = keras.layers.LayerNormalization(name="ln_final")
@@ -42,24 +47,32 @@ def __init__(
embed_dim, name="text_projector", use_bias=False
)

- def call(self, inputs):
+ def call(self, inputs, attention_mask=None):
token_embedding = self.token_embedding(inputs)
+ position_ids = ops.expand_dims(
+ ops.arange(self.context_length, dtype="int32"), 0
+ )
+ position_embedding = self.positional_embedding(position_ids)
+ position_embedding = ops.tile(
+ position_embedding, repeats=(inputs.shape[0], 1, 1)
+ )
+ attention_mask = ops.cast(attention_mask, dtype="float32")
+ expanded_mask = ops.tile(
+ attention_mask[:, None, None, :], (1, 1, self.context_length, 1)
+ )
+ expanded_mask = (1.0 - expanded_mask) * (-1e8)
encoded_output = self.encoder(
- token_embedding + self.positional_embedding
+ token_embedding + position_embedding, attention_mask=expanded_mask
)
print("encoded_output", encoded_output)
layer_norm = self.ln_final(encoded_output)
indices = ops.expand_dims(
- ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1
+ ops.cast(ops.argmax(inputs, axis=-1), "int32"), axis=-1
)
selected_features = ops.take_along_axis(
layer_norm, indices[:, :, None], axis=1
)
print("pooler output", selected_features)
text_features = self.text_projector(selected_features)
output = ops.squeeze(text_features, axis=1)
return output

- def build_attention_mask(self):
- mask = ops.ones((self.context_length, self.context_length))
- # Zero out the lower diagonal
- mask = ops.triu(mask)
- return ops.cast(mask, "float32")

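To make the new mask handling concrete: the call() above expands a (batch, seq) padding mask into an additive (batch, 1, seq, seq) mask, and ResidualAttention.attention later adds it to the upper-triangular mask built in __init__. A standalone sketch with invented values:

from keras import ops

context_length = 4
attention_mask = ops.convert_to_tensor([[1.0, 1.0, 0.0, 0.0]])  # 1 = token, 0 = padding

# Expansion as in call(): tile to (batch, 1, seq, seq), then map padded
# key positions to a large negative additive value.
expanded_mask = ops.tile(
    attention_mask[:, None, None, :], (1, 1, context_length, 1)
)
expanded_mask = (1.0 - expanded_mask) * (-1e8)

# Mask built in __init__ and handed to CLIPEncoder as attn_mask:
# an upper-triangular matrix of ones, cast to float32.
causal_mask = ops.cast(
    ops.triu(ops.ones((context_length, context_length))), "float32"
)

# ResidualAttention.attention combines the two with ops.add before
# passing the result on to its attention layer.
combined = ops.add(causal_mask, expanded_mask)
print(combined.shape)  # (1, 1, 4, 4)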