support separate lora & state loading
sangkeun00 committed May 29, 2024
1 parent 689908f commit 240b79f
Showing 6 changed files with 53 additions and 23 deletions.
8 changes: 6 additions & 2 deletions logix/batch_info.py
@@ -1,10 +1,14 @@
+from typing import Iterable, Optional, Any
+
+import torch
+
 from logix.utils import nested_dict


 class BatchInfo:
     def __init__(self):
-        self.data_id = None
-        self.mask = None
+        self.data_id: Optional[Iterable[Any]] = None
+        self.mask: Optional[torch.Tensor] = None
         self.log = nested_dict()

     def clear(self):
2 changes: 1 addition & 1 deletion logix/config.py
@@ -9,7 +9,7 @@
 from logix.utils import get_rank


-def init_config_from_yaml(project: str, logix_config: str):
+def init_config_from_yaml(project: str, logix_config: Optional[str] = None):
     config_dict = {}
     if logix_config is not None:
         assert os.path.exists(logix_config), f"{logix_config} doesn't exist!"
54 changes: 40 additions & 14 deletions logix/logix.py
@@ -1,6 +1,6 @@
 import os

-from typing import Optional, Iterable, Dict, Any, List, Union
+from typing import Optional, Iterable, Dict, Any, List, Union, Tuple
 from dataclasses import asdict
 from functools import reduce
 from copy import deepcopy
@@ -86,8 +86,8 @@ def __init__(
     def watch(
         self,
         model: nn.Module,
-        type_filter: List[nn.Module] = None,
-        name_filter: List[str] = None,
+        type_filter: Optional[List[nn.Module]] = None,
+        name_filter: Optional[List[str]] = None,
     ) -> None:
         """
         Sets up modules in the model to be watched based on optional type and name filters.
@@ -151,6 +151,7 @@ def add_lora(
         clear: bool = True,
         type_filter: Optional[List[nn.Module]] = None,
         name_filter: Optional[List[str]] = None,
+        lora_path: Optional[str] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         """
@@ -171,6 +172,13 @@ def add_lora(
             clear (bool, optional):
                 Whether to clear the internal states after adding LoRA. Defaults to
                 `True`.
+            type_filter (Optional[List[nn.Module]], optional):
+                A list of module types to be watched.
+            name_filter (Optional[List[str]], optional):
+                A list of module names to be watched.
+            lora_path (Optional[str], optional):
+                The path to the LoRA state file. If None, the LoRA state is not loaded.
+            lora_config (Optional[LoRAConfig], optional): LoRA configuration.
         """
         if lora_config is not None:
             self.set_lora_config(lora_config)
@@ -187,12 +195,22 @@ def add_lora(
             name_filter=name_filter or self.name_filter,
         )

+        # If lora_path is not none, load lora weights from this path
+        if lora_path is not None:
+            lora_dir = os.path.join(os.path.join(lora_path, "lora"))
+            lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt"))
+            for name in lora_state:
+                assert name in self.model.state_dict(), f"{name} not in model!"
+            model.load_state_dict(lora_state, strict=False)
+
+        # Clear state and logger
         if clear:
             msg = "LogIX will clear the previous Hessian, Storage, and Logging "
             msg += "handlers after adding LoRA for gradient compression.\n"
             get_logger().info(msg)
             self.clear()

+        # (Re-)watch lora-added model
         if watch:
             self.watch(model)

@@ -210,7 +228,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None) -> None:

     def __call__(
         self,
-        data_id: Iterable[Any] = None,
+        data_id: Iterable[Any],
         mask: Optional[torch.Tensor] = None,
     ):
         """
@@ -289,7 +307,7 @@ def build_log_dataloader(
         )
         return self.log_dataloader

-    def get_log(self, copy=False) -> Dict[str, Dict[str, torch.Tensor]]:
+    def get_log(self, copy=False) -> Tuple[str, Dict[str, Dict[str, torch.Tensor]]]:
         """
         Returns the current log, including data identifiers and logged information.
@@ -305,16 +323,18 @@ def get_covariance_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
         """
         return self.state.get_covariance_state()

-    def get_covariance_svd_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
+    def get_covariance_svd_state(
+        self,
+    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], Dict[str, Dict[str, torch.Tensor]]]:
         """
         Returns the SVD of the covariance from the Hessian handler.
         """
         return self.state.get_covariance_svd_state()

     def compute_influence(
         self,
-        src_log: Dict[str, Dict[str, torch.Tensor]],
-        tgt_log: Dict[str, Dict[str, torch.Tensor]],
+        src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
+        tgt_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
         mode: Optional[str] = "dot",
         precondition: Optional[bool] = True,
     ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
@@ -334,10 +354,10 @@ def compute_influence(

     def compute_influence_all(
         self,
-        src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        src_log: Optional[Tuple[str, Dict[str, Dict[str, torch.Tensor]]]] = None,
         loader: Optional[torch.utils.data.DataLoader] = None,
         mode: Optional[str] = "dot",
-        precondition: Optional[str] = True,
+        precondition: Optional[bool] = True,
     ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
         """
         Front-end interface for computing influence scores against all train data in the log.
@@ -357,7 +377,7 @@ def compute_influence_all(

     def compute_self_influence(
         self,
-        src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        src_log: Optional[Tuple[str, Dict[str, Dict[str, torch.Tensor]]]] = None,
         precondition: Optional[bool] = True,
     ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
         """
@@ -400,9 +420,15 @@ def save_lora(self) -> None:
             os.makedirs(log_dir)
         torch.save(lora_state_dict, os.path.join(log_dir, "lora_state_dict.pt"))

-    def initialize_from_log(self) -> None:
+    def initialize_from_log(
+        self, state_path: Optional[str] = None, lora_path: Optional[str] = None
+    ) -> None:
         """
         Load all states from disk.
+
+        Args:
+            state_path (str, optional): Path to the state file.
+            lora_path (str, optional): Path to the LoRA state file.
         """
         # Load logix config
         assert os.path.exists(self.log_dir), f"{self.log_dir} does not exist!"
@@ -411,7 +437,7 @@ def initialize_from_log(self) -> None:
         self.config.load_config(config_file)

         # Load LoRA state
-        lora_dir = os.path.join(self.log_dir, "lora")
+        lora_dir = os.path.join(lora_path or self.log_dir, "lora")
         if os.path.exists(lora_dir) and self.model is not None:
             if not is_lora(self.model):
                 self.add_lora()
@@ -421,7 +447,7 @@ def initialize_from_log(self) -> None:
             self.model.load_state_dict(lora_state, strict=False)

         # Load state
-        self.state.load_state(self.log_dir)
+        self.state.load_state(state_path or self.log_dir)

     def finalize(
         self,
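
Taken together, the logix.py changes let LoRA weights and saved statistics be restored from different run directories. Below is a minimal usage sketch, not part of the commit: the LogIX(project=...) constructor, passing the model positionally to add_lora, and the directory names are assumptions made for illustration.

import torch.nn as nn
from logix import LogIX

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

run = LogIX(project="my_project")          # assumed constructor signature
run.watch(model, type_filter=[nn.Linear])  # register modules to track

# Attach LoRA modules and load their weights from an earlier run; per this
# diff, add_lora reads <lora_path>/lora/lora_state_dict.pt when lora_path is set.
run.add_lora(model, lora_path="./runs/lora_run")

# Alternatively, restore everything from disk, pointing the LoRA weights and the
# saved state at two different directories (each falls back to log_dir if omitted).
run.initialize_from_log(state_path="./runs/state_run", lora_path="./runs/lora_run")
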
9 changes: 5 additions & 4 deletions logix/lora/lora.py
@@ -1,7 +1,8 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional

 import torch.nn as nn

+from logix.config import LoRAConfig
 from logix.state import LogIXState
 from logix.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
 from logix.lora.utils import find_parameter_sharing_group, _get_submodules
@@ -15,7 +16,7 @@ class LoRAHandler:

     _SUPPORTED_MODULES = {nn.Linear, nn.Conv1d, nn.Conv2d}

-    def __init__(self, config: Dict[str, Any], state: LogIXState):
+    def __init__(self, config: LoRAConfig, state: LogIXState):
         self._state = state

         self.init_strategy = config.init
@@ -26,8 +27,8 @@ def __init__(self, config: Dict[str, Any], state: LogIXState):
     def add_lora(
         self,
         model: nn.Module,
-        type_filter: List[nn.Module],
-        name_filter: List[str],
+        type_filter: Optional[List[nn.Module]],
+        name_filter: Optional[List[str]],
     ):
         """
         Add LoRA modules to a model.
1 change: 0 additions & 1 deletion logix/scheduler.py
@@ -1,6 +1,5 @@
 from logix import LogIX
 from logix.statistic import Covariance
-from logix.utils import get_logger


 class LogIXScheduler:
2 changes: 1 addition & 1 deletion logix/utils.py
@@ -236,7 +236,7 @@ def generate_hash_id(self, data: Any) -> List[str]:
             data_id.append(hashlib.sha256(ndarray.tobytes()).hexdigest())
         return data_id

-    def generate_index_id(self, data: Any) -> List[int]:
+    def generate_index_id(self, data: Any) -> List[str]:
         """
         Given data, generate id based on the index.
         """
