From d794d048d66ce804aa8b63fbdd476ac450b494b4 Mon Sep 17 00:00:00 2001 From: dujing Date: Tue, 17 Dec 2024 11:52:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0lora=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosyvoice/loralib/__init__.py | 4 + cosyvoice/loralib/layers.py | 655 ++++++++++++++++++ cosyvoice/loralib/utils.py | 342 +++++++++ cosyvoice/tokenizer/phoneme_tokenizer.py | 2 +- .../cosyvoice/conf/cosyvoice_phoneme.yaml | 2 +- 5 files changed, 1003 insertions(+), 2 deletions(-) create mode 100644 cosyvoice/loralib/__init__.py create mode 100644 cosyvoice/loralib/layers.py create mode 100644 cosyvoice/loralib/utils.py diff --git a/cosyvoice/loralib/__init__.py b/cosyvoice/loralib/__init__.py new file mode 100644 index 0000000..7fd9ad2 --- /dev/null +++ b/cosyvoice/loralib/__init__.py @@ -0,0 +1,4 @@ +name = "lora" + +from .layers import * +from .utils import * \ No newline at end of file diff --git a/cosyvoice/loralib/layers.py b/cosyvoice/loralib/layers.py new file mode 100644 index 0000000..f8bc083 --- /dev/null +++ b/cosyvoice/loralib/layers.py @@ -0,0 +1,655 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +from typing import Optional, List + +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + +class Embedding(nn.Embedding, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + r: int = 0, + lora_alpha: int = 1, + merge_weights: bool = True, + lora_init_weights: str = "normal", + **kwargs + ): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, + merge_weights=merge_weights) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings))) + self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) + self.scaling = self.lora_alpha / self.r if "noscale" not in lora_init_weights else 1.0 + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + # init + nn.Embedding.reset_parameters(self) + self.lora_init_weights = lora_init_weights + if "pissa" not in lora_init_weights: + self.init_parameters() + + def init_parameters(self): + if hasattr(self, 'lora_A'): + if self.lora_init_weights=="normal": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.zeros_(self.lora_B) + nn.init.normal_(self.lora_A) + elif "pissa" in self.lora_init_weights: + if self.lora_init_weights[:5]=="pissa": + V, S, Uh = torch.linalg.svd(self.weight.data, full_matrices=False) + Vr = V[:, : self.r] + Sr = S[: self.r] + Sr /= self.scaling + Uhr = Uh[: self.r] + elif len(self.lora_init_weights.split("_niter_")) == 2: + # print("debug embedding:", 
self.weight.data.size(), self.r, int(self.lora_init_weights.split("_niter_")[-1])) + Vr, Sr, Ur = torch.svd_lowrank(self.weight.data, self.r, niter=int(self.lora_init_weights.split("_niter_")[-1])) + Sr /= self.scaling + Uhr = Ur.t() + else: + assert(0) + + lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr + lora_B = Vr @ torch.diag(torch.sqrt(Sr)) + self.lora_A.data = lora_A + self.lora_B.data = lora_B + dtype = self.weight.dtype + weight = self.weight.data - self.scaling * lora_B @ lora_A + weight = weight.to(dtype) + self.weight.data = weight + + if self.lora_init_weights[:5]=="pissa": + del V, S, Uh, Vr, Sr, Uhr + elif len(self.lora_init_weights.split("_niter_")) == 2: + del Vr, Sr, Ur, Uhr + else: + assert(0) + else: + assert(0) + + def unmerge_parameters(self): + if self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling + self.merged = False + return + + def merge_parameters(self): + if not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling + self.merged = True + return + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.r > 0 and not self.merged: + result = nn.Embedding.forward(self, x) + after_A = F.embedding( + x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) + result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return nn.Embedding.forward(self, x) + + +class Linear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + lora_init_weights: str = "normal", + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + assert(r>0) + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r if "noscale" not in lora_init_weights else 1.0 + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + if "cachewnorm" in lora_init_weights: + self.weight_cache = nn.Parameter((self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling) + # init + nn.Linear.reset_parameters(self) + self.lora_init_weights = lora_init_weights + if "pissa" not in lora_init_weights: + self.init_parameters() + + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def init_parameters(self): + if hasattr(self, 'lora_A'): + if self.lora_init_weights=="normal": + # initialize B 
the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + elif "pissa" in self.lora_init_weights: + if self.lora_init_weights[:5]=="pissa": + V, S, Uh = torch.linalg.svd(self.weight.data, full_matrices=False) + Vr = V[:, : self.r] + Sr = S[: self.r] + Sr /= self.scaling + Uhr = Uh[: self.r] + elif len(self.lora_init_weights.split("_niter_")) == 2: + # print("debug linear:", self.weight.data.size(), self.r, int(self.lora_init_weights.split("_niter_")[-1])) + Vr, Sr, Ur = torch.svd_lowrank(self.weight.data, self.r, niter=int(self.lora_init_weights.split("_niter_")[-1])) + Sr /= self.scaling + Uhr = Ur.t() + else: + assert(0) + + lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr + lora_B = Vr @ torch.diag(torch.sqrt(Sr)) + self.lora_A.data = lora_A + self.lora_B.data = lora_B + dtype = self.weight.dtype + if "cachewnorm" in self.lora_init_weights: + self.weight_cache.data = (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + if "rmwnorm" in self.lora_init_weights: ### NOTE 去除weight_norm后就可以减去svd分解 + weight = self.weight.data - self.scaling * lora_B @ lora_A + weight = weight.to(dtype) + self.weight.data = weight + + if self.lora_init_weights[:5]=="pissa": + del V, S, Uh, Vr, Sr, Uhr + elif len(self.lora_init_weights.split("_niter_")) == 2: + del Vr, Sr, Ur, Uhr + else: + assert(0) + + def T(self, w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + def unmerge_parameters(self): + if self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= self.T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + return + + def merge_parameters(self): + if not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += self.T(self.lora_B @ self.lora_A) * self.scaling + self.merged = True + return + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= self.T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += self.T(self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + if self.r > 0 and not self.merged: + if "cachewnorm" in self.lora_init_weights and not "rmwnorm" in self.lora_init_weights: + result = F.linear(x, T(self.weight), bias=self.bias) + result -= F.linear(x, T(self.weight_cache), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + else: ### NOTE 常规的或去除wnorm的pissa + result = F.linear(x, T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + +class ConvLoRA(nn.Module, LoRALayer): + def __init__( + self, + in_channels=0, + out_channels=0, + kernel_size=0, + r=0, + lora_alpha=1, + lora_dropout=0., + merge_weights=True, + lora_init_weights="normal", + **kwargs + ): + # print("?super(ConvLoRA, self).__init__()开始") + # super(ConvLoRA, self).__init__() + # nn.Module.__init__(self) + # 
print("?super(ConvLoRA, self).__init__()完毕") + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + assert isinstance(kernel_size, int), kernel_size + # Actual trainable parameters + if r > 0: + self.r = self.r * kernel_size if "fixconvk" not in lora_init_weights else self.r + if "pissa" in lora_init_weights: + weight = self.weight.data + n,c,h = weight.size(0), weight.size(1), weight.size(2) + w = weight.size(3) if len(weight.size())>=4 else 1 + self.r = min(n, self.r, c*h*w) + lora_init_weights = "normal" if self.r==1 else lora_init_weights + + # if "fullconv" in lora_init_weights: + # self.lora_A = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) + # self.lora_B = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) + # else: + self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, in_channels*kernel_size**(self.weight.dim()-2)))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_channels//self.groups, self.r))) + + self.scaling = self.lora_alpha / self.r if "noscale" not in lora_init_weights else 1.0 + if "cachewnorm" in lora_init_weights: + self.weight_cache = nn.Parameter((self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling) + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + # init + self.reset_parameters() + self.lora_init_weights = lora_init_weights + if "pissa" not in lora_init_weights: + self.init_parameters() + + self.merged = False + + def init_parameters(self): + if hasattr(self, 'lora_A'): + if self.lora_init_weights=="normal": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + elif "pissa" in self.lora_init_weights: + weight = self.weight.data + n,c,h = weight.size(0), weight.size(1), weight.size(2) + w = weight.size(3) if len(weight.size())>=4 else -1 + weight = weight.view(n, -1) + if self.lora_init_weights[:5]=="pissa": + V, S, Uh = torch.linalg.svd(weight, full_matrices=False) + Vr = V[:, : self.r] + Sr = S[: self.r] + Sr /= self.scaling + Uhr = Uh[: self.r] + elif len(self.lora_init_weights.split("_niter_")) == 2: + Vr, Sr, Ur = torch.svd_lowrank(weight, self.r, niter=int(self.lora_init_weights.split("_niter_")[-1])) + Sr /= self.scaling + Uhr = Ur.t() + else: + assert(0) + + lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr + lora_B = Vr @ torch.diag(torch.sqrt(Sr)) + if self.lora_A.data.size() != lora_A.size() or self.lora_B.data.size() != lora_B.size(): + print("self.weight.data.size(), weight.size()", self.weight.data.size(), weight.size()) + print("V.size(), S.size(), Uh.size()", V.size(), S.size(), Uh.size()) + print("self.lora_A={},self.lora_B={},lora_A={},lora_B={}".format(self.lora_A.data.size(), self.lora_B.data.size(), lora_A.size(), lora_B.size())) + print("self.r,c,h,w", self.r,c,h,w) + assert(0) + # else: + # print("lora_A={},lora_B={}".format(lora_A.size(), lora_B.size())) + + self.lora_A.data = lora_A#.view(self.r,c,h,w) if w>=1 else lora_A.view(self.r,c,h) + self.lora_B.data = lora_B#[:,:,None,None] + dtype = self.weight.dtype + if "cachewnorm" in self.lora_init_weights: + self.weight_cache.data = (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + if "rmwnorm" in self.lora_init_weights: + weight = self.weight.data - (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + weight = weight.to(dtype) + self.weight.data = weight + + if self.lora_init_weights[:5]=="pissa": + del V, S, 
Uh, Vr, Sr, Uhr + elif len(self.lora_init_weights.split("_niter_")) == 2: + del Vr, Sr, Ur, Uhr + else: + assert(0) + else: + assert(0) + + def unmerge_parameters(self): + if self.merged: + # Make sure that the weights are not merged + if self.r > 0: + # Make sure that the weights are not merged + self.weight = self.weight.to(self.lora_B.device) + self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = False + return + + def merge_parameters(self): + if not self.merged: + # Merge the weights and mark it + if self.r > 0: + # Merge the weights and mark it + self.weight = self.weight.to(self.lora_B.device) + self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = True + return + + def train(self, mode=True): + super(ConvLoRA, self).train(mode) + if mode: + if self.merge_weights and self.merged: + if self.r > 0: + # Make sure that the weights are not merged + self.weight = self.weight.to(self.lora_B.device) + self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + if self.r > 0: + # Merge the weights and mark it + self.weight = self.weight.to(self.lora_B.device) + self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + if "cachewnorm" in self.lora_init_weights and not "rmwnorm" in self.lora_init_weights: + return self._conv_forward( + x, + self.weight - self.weight_cache + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling, + self.bias + ) + else: + return self._conv_forward( + x, + self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling, + self.bias + ) + return self._conv_forward(x, self.weight, self.bias) + +# Can Extend to other ones like this +class Conv1d(ConvLoRA, nn.Conv1d): + def __init__(self, + *args, + r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, lora_init_weights="normal", + **kwargs + ): + # super(Conv1d, self).__init__(*args, **kwargs) + # print("nn.Conv1d.__init__开始") + nn.Conv1d.__init__(self, *args, **kwargs) + # print("nn.Conv1d.__init__完毕") + # print("ConvLoRA.__init__开始") + ConvLoRA.__init__(self, + *args, + r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights, lora_init_weights=lora_init_weights, + **kwargs + ) + # print("ConvLoRA.__init__完毕") + +class ConvTransposeLoRA(nn.Module, LoRALayer): + def __init__( + self, + in_channels=0, + out_channels=0, + kernel_size=0, + r=0, + lora_alpha=1, + lora_dropout=0., + merge_weights=True, + lora_init_weights="normal", + **kwargs + ): + # print("?super(ConvLoRA, self).__init__()开始") + # super(ConvLoRA, self).__init__() + # nn.Module.__init__(self) + # print("?super(ConvLoRA, self).__init__()完毕") + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + assert isinstance(kernel_size, int), kernel_size + # Actual trainable parameters + if r > 0: + self.r = self.r * kernel_size if "fixconvk" not in lora_init_weights else self.r + if "pissa" in lora_init_weights: + weight = self.weight.data + n,c,h = weight.size(0), weight.size(1), weight.size(2) + w = weight.size(3) if len(weight.size())>=4 else 1 + self.r = min(n, self.r, c*h*w) + lora_init_weights = "normal" if self.r==1 else lora_init_weights + + self.lora_A = nn.Parameter( + self.weight.new_zeros((self.r, 
out_channels//self.groups*kernel_size**(self.weight.dim()-2))) + ) + self.lora_B = nn.Parameter( + self.weight.new_zeros((in_channels, self.r)) + ) + self.scaling = self.lora_alpha / self.r if "noscale" not in lora_init_weights else 1.0 + if "cachewnorm" in lora_init_weights: + self.weight_cache = nn.Parameter((self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling) + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + self.reset_parameters() + self.lora_init_weights = lora_init_weights + if "pissa" not in lora_init_weights: + self.init_parameters() + + self.merged = False + + def init_parameters(self): + if hasattr(self, 'lora_A'): + if self.lora_init_weights=="normal": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + elif "pissa" in self.lora_init_weights: + weight = self.weight.data + n,c,h = weight.size(0), weight.size(1), weight.size(2) + w = weight.size(3) if len(weight.size())>=4 else -1 + weight = weight.view(n, -1) + if self.lora_init_weights[:5]=="pissa": + V, S, Uh = torch.linalg.svd(weight, full_matrices=False) + Vr = V[:, : self.r] + Sr = S[: self.r] + Sr /= self.scaling + Uhr = Uh[: self.r] + elif len(self.lora_init_weights.split("_niter_")) == 2: + # print("debug conv:", self.weight.data.size(), self.r, int(self.lora_init_weights.split("_niter_")[-1])) + Vr, Sr, Ur = torch.svd_lowrank(weight, self.r, niter=int(self.lora_init_weights.split("_niter_")[-1])) + Sr /= self.scaling + Uhr = Ur.t() + else: + assert(0) + + lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr + lora_B = Vr @ torch.diag(torch.sqrt(Sr)) + if self.lora_A.data.size() != lora_A.size() or self.lora_B.data.size() != lora_B.size(): + print("self.weight.data.size(), weight.size()", self.weight.data.size(), weight.size()) + print("V.size(), S.size(), Uh.size()", V.size(), S.size(), Uh.size()) + print("self.lora_A={},self.lora_B={},lora_A={},lora_B={}".format(self.lora_A.data.size(), self.lora_B.data.size(), lora_A.size(), lora_B.size())) + print("self.r,c,h,w", self.r,c,h,w) + assert(0) + # else: + # print("lora_A={},lora_B={}".format(lora_A.size(), lora_B.size())) + + self.lora_A.data = lora_A#.view(self.r,c,h,w) if w>=1 else lora_A.view(self.r,c,h) + self.lora_B.data = lora_B#[:,:,None,None] + dtype = self.weight.dtype + if "cachewnorm" in self.lora_init_weights: + self.weight_cache.data = (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + if "rmwnorm" in self.lora_init_weights: + weight = self.weight.data - (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + weight = weight.to(dtype) + self.weight.data = weight + + if self.lora_init_weights[:5]=="pissa": + del V, S, Uh, Vr, Sr, Uhr + elif len(self.lora_init_weights.split("_niter_")) == 2: + del Vr, Sr, Ur, Uhr + else: + assert(0) + else: + assert(0) + + def unmerge_parameters(self): + if self.merged: + # Make sure that the weights are not merged + if self.r > 0: + # Make sure that the weights are not merged + self.weight = self.weight.to(self.lora_B.device) + self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = False + return + + def merge_parameters(self): + if not self.merged: + # Merge the weights and mark it + if self.r > 0: + # Merge the weights and mark it + self.weight = self.weight.to(self.lora_B.device) + self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = True + return + + def 
train(self, mode=True): + super(ConvTransposeLoRA, self).train(mode) + if mode: + if self.merge_weights and self.merged: + if self.r > 0: + # Make sure that the weights are not merged + self.weight = self.weight.to(self.lora_B.device) + self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + if self.r > 0: + # Merge the weights and mark it + self.weight = self.weight.to(self.lora_B.device) + self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + self.merged = True + + def forward(self, x, output_size = None): + output_padding = self._output_padding( + input = x, + output_size = output_size, + stride=self.stride, + padding=self.padding, + kernel_size=self.kernel_size, + num_spatial_dims=self.num_spatial_dims, ### NOTE 非常坑!torch 1.12这里多了个参数!!!! + dilation=self.dilation + ) + + if self.r > 0 and not self.merged: + + if self.num_spatial_dims==1: + if "cachewnorm" in self.lora_init_weights and not "rmwnorm" in self.lora_init_weights: + return F.conv_transpose1d( + x, + self.weight - self.weight_cache + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation + ) + else: + return F.conv_transpose1d( + x, + self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation + ) + else: + assert(0), "还没写好" + else: + return F.conv_transpose1d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation + ) + +class ConvTranspose1d(ConvTransposeLoRA, nn.ConvTranspose1d): + def __init__(self, + *args, + r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, lora_init_weights="normal", + **kwargs + ): + self.num_spatial_dims = 1 + # super(ConvTranspose1d, self).__init__(*args, **kwargs) + # print("nn.ConvTranspose1d.__init__开始") + nn.ConvTranspose1d.__init__(self, *args, **kwargs) + # print("nn.ConvTranspose1d.__init__完毕") + # print("ConvLoRA.__init__开始") + ConvTransposeLoRA.__init__(self, + *args, + r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights, lora_init_weights=lora_init_weights, + **kwargs + ) + # print("ConvLoRA.__init__完毕") \ No newline at end of file diff --git a/cosyvoice/loralib/utils.py b/cosyvoice/loralib/utils.py new file mode 100644 index 0000000..36605c1 --- /dev/null +++ b/cosyvoice/loralib/utils.py @@ -0,0 +1,342 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ +import torch +import torch.nn as nn + +from typing import Dict + +from torch.nn.utils import weight_norm, remove_weight_norm +from .layers import LoRALayer, Linear, Embedding, Conv1d, ConvTranspose1d + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none', debug=False) -> None: + for n, p in model.named_parameters(): + if 'lora_' not in n: + p.requires_grad = False + if debug: + print("检查梯度:", n, p.requires_grad) + if bias == 'none': + return + elif bias == 'all': + for n, p in model.named_parameters(): + if 'bias' in n: + p.requires_grad = True + elif bias == 'lora_only': + for m in model.modules(): + if isinstance(m, LoRALayer) and \ + hasattr(m, 'bias') and \ + m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: + my_state_dict = model.state_dict() + if bias == 'none': + return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or "weight_cache" in k} + elif bias == 'all': + return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k or "weight_cache" in k } + elif bias == 'lora_only': + to_return = {} + for k in my_state_dict: + if 'lora_' in k or "weight_cache" in k: + to_return[k] = my_state_dict[k] + bias_name = k.split('lora_')[0]+'bias' + if bias_name in my_state_dict: + to_return[bias_name] = my_state_dict[bias_name] + return to_return + else: + raise NotImplementedError + +def getModelSize_lora(model, hps): + param_size_dict = {} + for name, module in model.named_modules(): + if not isinstance(module, LoRALayer): + continue + + parent_name = name.split(".")[0] + if parent_name not in param_size_dict: + param_size_dict[parent_name] = 0 + + if hasattr(module, 'lora_A'): + param = module.lora_A + param_size_dict[parent_name] += param.nelement() * param.element_size() / 1024 / 1024 + param = module.lora_B + param_size_dict[parent_name] += param.nelement() * param.element_size() / 1024 / 1024 + if hps.lora_bias == "lora_only": + if hasattr(module, 'bias'): + if module.bias is not None: + param = module.bias + param_size_dict[parent_name] += param.nelement() * param.element_size() / 1024 / 1024 + + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + # buffer_sum += buffer.nelement() + + # all_size = (param_size + buffer_size) / 1024 / 1024 + all_size = 0 + for key, param_size in param_size_dict.items(): + print("lora节点{}大小为:{:.3f}MB".format(key, param_size)) + all_size += param_size + print("lora模型总大小为:{:.3f}MB".format(all_size)) + if "enc_q" in param_size_dict: + print("lora模型去除enc_q总大小为:{:.3f}MB".format(all_size-param_size_dict["enc_q"])) + return + +def adjust_r(name, hps, lora_r): + lora_max = hps.lora_max #8 + lora_min = hps.lora_min #2 + # lora_lambda_flow = getattr(hps, "lora_lambda_flow", 1.0) + # lora_lambda_dec = getattr(hps, "lora_lambda_dec", 1.0) + father_type = name.split(".")[0] + lora_lambda = getattr(hps, "lora_lambda_{}".format(father_type), 1.0) + + if father_type in ["flow", "dec", "dur", "enc_p", "enc_q"]: + lora_r = int(lora_r * lora_lambda) + + return lora_r + +def replace_specific_layer_4lora(model, hps): + lora_r = hps.lora_r #16 + lora_alpha = hps.lora_alpha #32 + lora_dropout = hps.lora_dropout #0.01 + lora_max = hps.lora_max #8 + lora_mid_scale = hps.lora_mid_scale #16 + lora_min = hps.lora_min #2 + lora_init_weights = 
hps.lora_init_weights + + unique_dim = 4 + + # Recursively visit all modules and submodules + for name, module in model.named_modules(): + # Check if the module is an instance of the specified layers + if isinstance(module, torch.nn.Linear): + out_features, in_features = module.weight.shape + device = module.weight.device + dtype = module.weight.dtype + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + lora_init_weights_local = "normal" if out_features<=unique_dim or in_features<=unique_dim else lora_init_weights + lora_r_new = adjust_r(name, hps, lora_r) + # print("{} {} r:{} in_feat:{} out_feat:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, in_features, out_features), device) + localnet = Linear(in_features, out_features, bias=False if module.bias is None else True, + r=lora_r_new, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=hps.merge_weights, + lora_init_weights=lora_init_weights_local) + localnet.to(dtype) + localnet.to(device) + set_layer_from_name(model, name, localnet) + elif isinstance(module, torch.nn.Embedding): + num_embeddings, embedding_dim = module.weight.shape + device = module.weight.device + dtype = module.weight.dtype + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + # lora_init_weights_local = "normal" if num_embeddings<=unique_dim or embedding_dim<=unique_dim else lora_init_weights + lora_init_weights_local = "normal" + lora_r_new = adjust_r(name, hps, lora_r) + # print("{} {} r:{} num_emb:{} emb_dim:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, num_embeddings, embedding_dim), device) + localnet = Embedding(num_embeddings, embedding_dim, + r=lora_r_new, lora_alpha=lora_alpha, merge_weights=hps.merge_weights, + lora_init_weights=lora_init_weights_local) + localnet.to(dtype) + localnet.to(device) + set_layer_from_name(model, name, localnet) + elif isinstance(module, torch.nn.Conv1d): + # device = "cuda:0" + device = module.weight.device ### NOTE 很恶心 + dtype = module.weight.dtype + lora_r_new = min(lora_max, lora_r, max(((module.in_channels + module.out_channels)//2)//lora_mid_scale, lora_min)) + lora_r_new = adjust_r(name, hps, lora_r_new) + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + lora_init_weights_local = lora_init_weights + if "noconvk" in hps.lora_init_weights: + lora_init_weights_local = "normal" if module.kernel_size[0]>1 else lora_init_weights + if getattr(module, "weight_g", None) is not None and "nownorm" in hps.lora_init_weights: + lora_init_weights_local = "normal" + # print("{} {} r:{} in_c:{} out_c:{} ker:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, module.in_channels, module.out_channels, module.kernel_size[0]), device) + localnet = Conv1d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size[0], + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + device=module.weight.device, + dtype=module.weight.dtype, + bias=False if module.bias is None else True, + r=lora_r_new, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=hps.merge_weights, + lora_init_weights=lora_init_weights_local) + ######## NOTE 判断是否有weight_norm 会修改weight!!! 
######## + if getattr(module, "weight_g", None) is not None: + localnet = weight_norm(localnet) + ########################################### + localnet.to(dtype) + localnet.to(device) + set_layer_from_name(model, name, localnet) + elif isinstance(module, torch.nn.ConvTranspose1d): + # print(name, "反卷积还没写好") + # device = "cuda:0" + device = module.weight.device ### NOTE 很恶心 + dtype = module.weight.dtype + lora_r_new = min(lora_max, lora_r, max(((module.in_channels + module.out_channels)//2)//lora_mid_scale, lora_min)) + lora_r_new = adjust_r(name, hps, lora_r_new) + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + lora_init_weights_local = lora_init_weights + if "noconvk" in hps.lora_init_weights: + lora_init_weights_local = "normal" if module.kernel_size[0]>1 else lora_init_weights + # print("{} type:{} init:{} r:{} in_c:{} out_c:{} ker:{} s:{} d:{} g:{}".format(name, type(module), lora_init_weights_local, lora_r_new, module.in_channels, module.out_channels, module.kernel_size[0], module.stride, module.dilation, module.groups)) + if getattr(module, "weight_g", None) is not None and "nownorm" in hps.lora_init_weights: + lora_init_weights_local = "normal" + # print("{} {} r:{} in_c:{} out_c:{} trans ker:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, module.in_channels, module.out_channels, module.kernel_size[0]), device) + localnet = ConvTranspose1d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size[0], + stride=module.stride, + padding=module.padding, + output_padding=module.output_padding, + dilation=module.dilation, + groups=module.groups, + device=module.weight.device, + dtype=module.weight.dtype, + bias=False if module.bias is None else True, + r=lora_r_new, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=hps.merge_weights, + lora_init_weights=lora_init_weights_local) + ######## NOTE 判断是否有weight_norm 会修改weight!!! 
######## + if getattr(module, "weight_g", None) is not None: + localnet = weight_norm(localnet) + ########################################### + localnet.to(dtype) + localnet.to(device) + set_layer_from_name(model, name, localnet) + elif isinstance(module, torch.nn.Conv2d): + assert(0) + pass + + ### module_list=(torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.ConvTranspose1d) + return model + +def handle_pissa_weight(model, hps, action="init"): + lora_r = hps.lora_r #16 + lora_alpha = hps.lora_alpha #32 + lora_dropout = hps.lora_dropout #0.01 + lora_max = hps.lora_max #8 + lora_mid_scale = hps.lora_mid_scale #16 + lora_min = hps.lora_min #2 + lora_init_weights = hps.lora_init_weights + + if "pissa" not in lora_init_weights: + print("WARNING:使用的不是pissa lora?") + return model + + unique_dim = 4 + + # Recursively visit all modules and submodules + for name, module in model.named_modules(): + # Check if the module is an instance of the specified layers + if isinstance(module, torch.nn.Linear): + if action=="init": + out_features, in_features = module.weight.shape + device = module.weight.device + dtype = module.weight.dtype + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + lora_init_weights_local = "normal" if out_features<=unique_dim or in_features<=unique_dim else lora_init_weights + lora_r_new = adjust_r(name, hps, lora_r) + # print("{} type:{} init:{} r:{} in_feat:{} out_feat:{}".format(name, type(module), lora_init_weights_local, lora_r_new, in_features, out_features)) + # print("{} {} r:{} in_feat:{} out_feat:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, in_features, out_features)) + print("init {} {}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "")) + module.init_parameters() + elif action=="unmerge": + module.unmerge_parameters() + elif action=="merge": + module.merge_parameters() + elif isinstance(module, torch.nn.Embedding): + if action=="init": + num_embeddings, embedding_dim = module.weight.shape + device = module.weight.device + dtype = module.weight.dtype + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + # lora_init_weights_local = "normal" if num_embeddings<=unique_dim or embedding_dim<=unique_dim else lora_init_weights + lora_init_weights_local = "normal" + lora_r_new = adjust_r(name, hps, lora_r) + # print("{} {} r:{} num_emb:{} emb_dim:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, num_embeddings, embedding_dim)) + print("init {} {}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "")) + module.init_parameters() + elif action=="unmerge": + module.unmerge_parameters() + elif action=="merge": + module.merge_parameters() + elif isinstance(module, torch.nn.Conv1d): + if action=="init": + device = module.weight.device ### NOTE 很恶心 + dtype = module.weight.dtype + lora_r_new = min(lora_max, lora_r, max(((module.in_channels + module.out_channels)//2)//lora_mid_scale, lora_min)) + lora_r_new = adjust_r(name, hps, lora_r_new) + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + lora_init_weights_local = lora_init_weights + if "noconvk" in hps.lora_init_weights: + lora_init_weights_local = "normal" if module.kernel_size[0]>1 else lora_init_weights + if getattr(module, "weight_g", None) is not None and "nownorm" in hps.lora_init_weights: + lora_init_weights_local = "normal" + # print("init {} {} {}".format(name, lora_init_weights_local.upper() if 
lora_init_weights_local!='normal' else "", device)) + if getattr(module, "weight_g", None) is not None and "rmwnorm" in lora_init_weights: + module = remove_weight_norm(module) + module.init_parameters() + elif action=="unmerge": + module.unmerge_parameters() + elif action=="merge": + module.merge_parameters() + elif action=="rmwnorm": + if getattr(module, "weight_g", None) is not None and "rmwnorm" in lora_init_weights: + module = remove_weight_norm(module) + elif isinstance(module, torch.nn.ConvTranspose1d): + if action=="init": + device = module.weight.device ### NOTE 很恶心 + dtype = module.weight.dtype + lora_r_new = min(lora_max, lora_r, max(((module.in_channels + module.out_channels)//2)//lora_mid_scale, lora_min)) + lora_r_new = adjust_r(name, hps, lora_r_new) + + ### NOTE ### NOTE ### NOTE ### NOTE ### NOTE + lora_init_weights_local = lora_init_weights + if "noconvk" in hps.lora_init_weights: + lora_init_weights_local = "normal" if module.kernel_size[0]>1 else lora_init_weights + # print("{} type:{} init:{} r:{} in_c:{} out_c:{} ker:{} s:{} d:{} g:{}".format(name, type(module), lora_init_weights_local, lora_r_new, module.in_channels, module.out_channels, module.kernel_size[0], module.stride, module.dilation, module.groups)) + if getattr(module, "weight_g", None) is not None and "nownorm" in hps.lora_init_weights: + lora_init_weights_local = "normal" + # print("{} {} r:{} in_c:{} out_c:{} trans ker:{}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", lora_r_new, module.in_channels, module.out_channels, module.kernel_size[0])) + print("init {} {} {}".format(name, lora_init_weights_local.upper() if lora_init_weights_local!='normal' else "", device)) + if getattr(module, "weight_g", None) is not None and "rmwnorm" in lora_init_weights: + module = remove_weight_norm(module) + module.init_parameters() + elif action=="unmerge": + module.unmerge_parameters() + elif action=="merge": + module.merge_parameters() + elif action=="rmwnorm": + if getattr(module, "weight_g", None) is not None and "rmwnorm" in lora_init_weights: + module = remove_weight_norm(module) + elif isinstance(module, torch.nn.Conv2d): + assert(0) + pass + + ### module_list=(torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.ConvTranspose1d) + return model + +def set_layer_from_name(net, name, target_layer): + tokens = name.strip().split('.') + layer = net + for t in tokens[:-1]: + if not t.isnumeric(): + layer = getattr(layer, t) + else: + layer = layer[int(t)] + setattr(layer, tokens[-1], target_layer) \ No newline at end of file diff --git a/cosyvoice/tokenizer/phoneme_tokenizer.py b/cosyvoice/tokenizer/phoneme_tokenizer.py index 925dab2..a306bab 100644 --- a/cosyvoice/tokenizer/phoneme_tokenizer.py +++ b/cosyvoice/tokenizer/phoneme_tokenizer.py @@ -29,7 +29,7 @@ def __init__(self, if mode == 'inference': # hard code frontend model, import from local path TTS_root = "/data/megastore/Projects/DuJing/code/TTS" sys.path.append(TTS_root) - from tts.init_text_frontend import init_text_frontend + from tts.frontend.init_text_frontend import init_text_frontend if self.cn_frontend_model is None: self.cn_frontend_model = init_text_frontend('hntts') if self.en_frontend_model is None: diff --git a/examples/tts_vc/cosyvoice/conf/cosyvoice_phoneme.yaml b/examples/tts_vc/cosyvoice/conf/cosyvoice_phoneme.yaml index 8c1b3ae..af3da6c 100644 --- a/examples/tts_vc/cosyvoice/conf/cosyvoice_phoneme.yaml +++ b/examples/tts_vc/cosyvoice/conf/cosyvoice_phoneme.yaml @@ -157,7 
+157,7 @@ get_tokenizer: !name:cosyvoice.tokenizer.phoneme_tokenizer.get_tokenizer
 tokenize: !name:cosyvoice.dataset.processor_kaldidata.tokenize_phoneme
     get_tokenizer: !ref <get_tokenizer>
 filter: !name:cosyvoice.dataset.processor_kaldidata.filter
-    max_length: 4000 # 100 frame per second
+    max_length: 3600 # 100 frame per second
     min_length: 0
     token_max_length: 512 # phoneme length
     token_min_length: 1
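
The new cosyvoice.loralib package is meant to be wired into an existing training script. Below is a minimal sketch of that wiring, not the exact CosyVoice integration: the hps namespace carries only the fields that loralib/utils.py reads, the small Sequential model stands in for the pretrained network being adapted, and the checkpoint filename is a placeholder.

from types import SimpleNamespace

import torch
import torch.nn as nn

from cosyvoice.loralib import (lora_state_dict, mark_only_lora_as_trainable,
                               replace_specific_layer_4lora)

# Hyper-parameters consumed by replace_specific_layer_4lora / adjust_r; values are illustrative.
hps = SimpleNamespace(
    lora_r=16, lora_alpha=32, lora_dropout=0.01,  # rank, alpha (scaling = alpha / r), dropout
    lora_max=8, lora_min=2, lora_mid_scale=16,    # rank clamping applied to conv layers
    lora_init_weights="normal",                   # or a "pissa" / "pissa_niter_N" variant
    merge_weights=True,                           # eval() folds lora_B @ lora_A into W, train() unfolds it
    lora_bias="none",                             # 'none' | 'all' | 'lora_only'
)

# Stand-in for the pretrained model to adapt.
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 80))

model = replace_specific_layer_4lora(model, hps)        # swap Linear/Embedding/Conv1d/ConvTranspose1d in place
mark_only_lora_as_trainable(model, bias=hps.lora_bias)  # freeze everything except lora_A / lora_B

# ... fine-tune as usual ...

torch.save(lora_state_dict(model, bias=hps.lora_bias), "lora_only.pt")

Because lora_state_dict keeps only the lora_* tensors (plus weight_cache and, depending on bias=, the bias terms), the saved file holds just the adapter parameters. Note that replace_specific_layer_4lora builds fresh layers rather than copying the old weights, so the pretrained checkpoint has to be (re)loaded after the replacement; strict=False loading works because the wrapped layers keep the original weight/bias parameter names.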
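
For inference the wrappers are re-applied to the base model and the LoRA-only checkpoint is loaded on top. A sketch continuing the snippet above (in a real run the pretrained base weights would also be restored):

# Rebuild the (stand-in) base model and wrap it the same way as during training.
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 80))
model = replace_specific_layer_4lora(model, hps)

# strict=False because the file holds only lora_* tensors; every other key is intentionally absent.
model.load_state_dict(torch.load("lora_only.pt", map_location="cpu"), strict=False)

# With merge_weights=True, eval() folds lora_B @ lora_A into the frozen weight once
# (see train() in layers.py), so the forward pass runs at the original cost.
model.eval()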