diff --git a/logix/batch_info.py b/logix/batch_info.py
index a08b15cd..98a16547 100644
--- a/logix/batch_info.py
+++ b/logix/batch_info.py
@@ -1,10 +1,14 @@
+from typing import Iterable, Optional, Any
+
+import torch
+
 from logix.utils import nested_dict
 
 
 class BatchInfo:
     def __init__(self):
-        self.data_id = None
-        self.mask = None
+        self.data_id: Optional[Iterable[Any]] = None
+        self.mask: Optional[torch.Tensor] = None
         self.log = nested_dict()
 
     def clear(self):
diff --git a/logix/config.py b/logix/config.py
index a19ce692..75fd675c 100644
--- a/logix/config.py
+++ b/logix/config.py
@@ -9,7 +9,7 @@
 from logix.utils import get_rank
 
 
-def init_config_from_yaml(project: str, logix_config: str):
+def init_config_from_yaml(project: str, logix_config: Optional[str] = None):
     config_dict = {}
     if logix_config is not None:
         assert os.path.exists(logix_config), f"{logix_config} doesn't exist!"
diff --git a/logix/logix.py b/logix/logix.py
index 8d750e2d..9c9853d6 100644
--- a/logix/logix.py
+++ b/logix/logix.py
@@ -1,6 +1,6 @@
 import os
-from typing import Optional, Iterable, Dict, Any, List, Union
+from typing import Optional, Iterable, Dict, Any, List, Union, Tuple
 from dataclasses import asdict
 from functools import reduce
 from copy import deepcopy
 
@@ -86,8 +86,8 @@ def __init__(
     def watch(
         self,
         model: nn.Module,
-        type_filter: List[nn.Module] = None,
-        name_filter: List[str] = None,
+        type_filter: Optional[List[nn.Module]] = None,
+        name_filter: Optional[List[str]] = None,
     ) -> None:
         """
         Sets up modules in the model to be watched based on optional type and name filters.
@@ -151,6 +151,7 @@ def add_lora(
         clear: bool = True,
         type_filter: Optional[List[nn.Module]] = None,
         name_filter: Optional[List[str]] = None,
+        lora_path: Optional[str] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         """
@@ -171,6 +172,13 @@ def add_lora(
             clear (bool, optional):
                 Whether to clear the internal states after adding LoRA.
                 Defaults to `True`.
+            type_filter (Optional[List[nn.Module]], optional):
+                A list of module types to be watched.
+            name_filter (Optional[List[str]], optional):
+                A list of module names to be watched.
+            lora_path (Optional[str], optional):
+                The path to the LoRA state file. If None, the LoRA state is not loaded.
+            lora_config (Optional[LoRAConfig], optional): LoRA configuration.
         """
         if lora_config is not None:
             self.set_lora_config(lora_config)
@@ -187,12 +195,22 @@ def add_lora(
             name_filter=name_filter or self.name_filter,
         )
 
+        # If lora_path is not none, load lora weights from this path
+        if lora_path is not None:
+            lora_dir = os.path.join(os.path.join(lora_path, "lora"))
+            lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt"))
+            for name in lora_state:
+                assert name in self.model.state_dict(), f"{name} not in model!"
+            model.load_state_dict(lora_state, strict=False)
+
         # Clear state and logger
         if clear:
             msg = "LogIX will clear the previous Hessian, Storage, and Logging "
             msg += "handlers after adding LoRA for gradient compression.\n"
             get_logger().info(msg)
             self.clear()
+
+        # (Re-)watch lora-added model
         if watch:
             self.watch(model)
 
@@ -210,7 +228,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None) -> None:
 
     def __call__(
         self,
-        data_id: Iterable[Any] = None,
+        data_id: Iterable[Any],
         mask: Optional[torch.Tensor] = None,
     ):
         """
@@ -289,7 +307,7 @@ def build_log_dataloader(
         )
         return self.log_dataloader
 
-    def get_log(self, copy=False) -> Dict[str, Dict[str, torch.Tensor]]:
+    def get_log(self, copy=False) -> Tuple[str, Dict[str, Dict[str, torch.Tensor]]]:
         """
         Returns the current log, including data identifiers and logged information.
 
@@ -305,7 +323,9 @@ def get_covariance_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
         """
         return self.state.get_covariance_state()
 
-    def get_covariance_svd_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
+    def get_covariance_svd_state(
+        self,
+    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], Dict[str, Dict[str, torch.Tensor]]]:
         """
         Returns the SVD of the covariance from the Hessian handler.
         """
@@ -313,8 +333,8 @@ def get_covariance_svd_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
 
     def compute_influence(
         self,
-        src_log: Dict[str, Dict[str, torch.Tensor]],
-        tgt_log: Dict[str, Dict[str, torch.Tensor]],
+        src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
+        tgt_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
         mode: Optional[str] = "dot",
         precondition: Optional[bool] = True,
     ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
@@ -334,10 +354,10 @@ def compute_influence(
 
     def compute_influence_all(
         self,
-        src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        src_log: Optional[Tuple[str, Dict[str, Dict[str, torch.Tensor]]]] = None,
         loader: Optional[torch.utils.data.DataLoader] = None,
         mode: Optional[str] = "dot",
-        precondition: Optional[str] = True,
+        precondition: Optional[bool] = True,
     ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
         """
         Front-end interface for computing influence scores against all train data in the log.
@@ -357,7 +377,7 @@ def compute_influence_all(
 
     def compute_self_influence(
         self,
-        src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        src_log: Optional[Tuple[str, Dict[str, Dict[str, torch.Tensor]]]] = None,
         precondition: Optional[bool] = True,
     ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
         """
@@ -400,9 +420,15 @@ def save_lora(self) -> None:
             os.makedirs(log_dir)
         torch.save(lora_state_dict, os.path.join(log_dir, "lora_state_dict.pt"))
 
-    def initialize_from_log(self) -> None:
+    def initialize_from_log(
+        self, state_path: Optional[str] = None, lora_path: Optional[str] = None
+    ) -> None:
         """
         Load all states from disk.
+
+        Args:
+            state_path (str, optional): Path to the state file.
+            lora_path (str, optional): Path to the LoRA state file.
         """
         # Load logix config
         assert os.path.exists(self.log_dir), f"{self.log_dir} does not exist!"
@@ -411,7 +437,7 @@ def initialize_from_log(self) -> None:
         self.config.load_config(config_file)
 
         # Load LoRA state
-        lora_dir = os.path.join(self.log_dir, "lora")
+        lora_dir = os.path.join(lora_path or self.log_dir, "lora")
         if os.path.exists(lora_dir) and self.model is not None:
             if not is_lora(self.model):
                 self.add_lora()
@@ -421,7 +447,7 @@ def initialize_from_log(self) -> None:
             self.model.load_state_dict(lora_state, strict=False)
 
         # Load state
-        self.state.load_state(self.log_dir)
+        self.state.load_state(state_path or self.log_dir)
 
     def finalize(
         self,
diff --git a/logix/lora/lora.py b/logix/lora/lora.py
index f9fc974f..13c44c36 100644
--- a/logix/lora/lora.py
+++ b/logix/lora/lora.py
@@ -1,7 +1,8 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
 
 import torch.nn as nn
 
+from logix.config import LoRAConfig
 from logix.state import LogIXState
 from logix.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
 from logix.lora.utils import find_parameter_sharing_group, _get_submodules
@@ -15,7 +16,7 @@ class LoRAHandler:
 
     _SUPPORTED_MODULES = {nn.Linear, nn.Conv1d, nn.Conv2d}
 
-    def __init__(self, config: Dict[str, Any], state: LogIXState):
+    def __init__(self, config: LoRAConfig, state: LogIXState):
         self._state = state
 
         self.init_strategy = config.init
@@ -26,8 +27,8 @@ def __init__(self, config: Dict[str, Any], state: LogIXState):
     def add_lora(
         self,
         model: nn.Module,
-        type_filter: List[nn.Module],
-        name_filter: List[str],
+        type_filter: Optional[List[nn.Module]],
+        name_filter: Optional[List[str]],
     ):
         """
         Add LoRA modules to a model.
diff --git a/logix/scheduler.py b/logix/scheduler.py
index cb94a3bd..43f38a5a 100644
--- a/logix/scheduler.py
+++ b/logix/scheduler.py
@@ -1,6 +1,5 @@
 from logix import LogIX
 from logix.statistic import Covariance
-from logix.utils import get_logger
 
 
 class LogIXScheduler:
diff --git a/logix/utils.py b/logix/utils.py
index 44177123..07131000 100644
--- a/logix/utils.py
+++ b/logix/utils.py
@@ -236,7 +236,7 @@ def generate_hash_id(self, data: Any) -> List[str]:
             data_id.append(hashlib.sha256(ndarray.tobytes()).hexdigest())
         return data_id
 
-    def generate_index_id(self, data: Any) -> List[int]:
+    def generate_index_id(self, data: Any) -> List[str]:
         """
         Given data, generate id based on the index.
         """
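Below is a minimal usage sketch of the options introduced in this patch: restoring saved LoRA weights via the new `lora_path` argument of `add_lora`, and pointing `initialize_from_log` at explicit `state_path`/`lora_path` directories. The `LogIX(project=...)` construction, the toy model, and the `./prev_run` directory are assumptions for illustration only, not part of the patch.

```python
import torch.nn as nn

from logix import LogIX

# Hypothetical toy model and project; only the keyword arguments
# exercised below come from this patch.
model = nn.Linear(16, 2)
run = LogIX(project="my_project")  # constructor arguments assumed

# watch() registers the model so add_lora() can validate loaded weights
# against self.model.state_dict().
run.watch(model)

# New in this patch: load previously saved LoRA weights while adding LoRA
# modules. Expects ./prev_run/lora/lora_state_dict.pt, the layout written
# by save_lora().
run.add_lora(model, lora_path="./prev_run")

# Also new: load Hessian/LoRA state from explicit directories instead of
# defaulting to the current log_dir.
run.initialize_from_log(state_path="./prev_run", lora_path="./prev_run")
```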