Skip to content

Commit

Permalink
return singular values
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Nov 20, 2023
1 parent 0c6d13d commit 6d5e888
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
16 changes: 10 additions & 6 deletions analog/lora/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@ def init_weight(self, init_strategy: str = "random", hessian=None):
nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.analog_lora_C.weight, a=math.sqrt(5))
elif init_strategy == "pca":
top_r_singular_vector_forward = compute_top_k_singular_vectors(
hessian[FORWARD], self.rank
)
top_r_singular_vector_backward = compute_top_k_singular_vectors(
hessian[BACKWARD], self.rank
)
(
top_r_singular_vector_forward,
top_r_singular_value_forward,
) = compute_top_k_singular_vectors(hessian[FORWARD], self.rank)
(
top_r_singular_vector_backward,
top_r_singular_value_backward,
) = compute_top_k_singular_vectors(hessian[BACKWARD], self.rank)
# top_r_singular_vector_forward /= top_r_singular_value_forward.unsqueeze(0)
# top_r_singular_vector_backward /= top_r_singular_value_backward.unsqueeze(0)
self.analog_lora_A.weight.data.copy_(top_r_singular_vector_forward.T)
self.analog_lora_C.weight.data.copy_(top_r_singular_vector_backward)
3 changes: 2 additions & 1 deletion analog/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ def compute_top_k_singular_vectors(matrix, k):
"""
U, S, Vh = torch.linalg.svd(matrix)
top_k_singular_vectors = U[:, :k]
return top_k_singular_vectors
top_k_singular_values = S[:k]
return top_k_singular_vectors, top_k_singular_values

0 comments on commit 6d5e888

Please sign in to comment.