Commit
Update at 20240926_23h52m36s
heqin-zhu committed Sep 26, 2024
0 parents commit 30e397a
Showing 30 changed files with 100,804 additions and 0 deletions.
20 changes: 20 additions & 0 deletions .gitignore
@@ -0,0 +1,20 @@
.*
!.gitignore
log*
*log
__pycache__/
*.pyc
lightning_logs/
MNIST/
weights/
/backup
*.*seq
BPPM_generater
!BPPM_generater/README.md
*.yaml
model_checkpoints
*results
examples
*txt
*png
*jpg
Empty file added BPfold/__init__.py
Empty file.
290 changes: 290 additions & 0 deletions BPfold/dataset/RNAseq.py
@@ -0,0 +1,290 @@
import os
import yaml
import random
from itertools import product

import numpy as np
import torch
import torch.utils.data as data
from sklearn.model_selection import KFold

from ..util.misc import get_file_name, str_localtime
from ..util.RNA_kit import read_SS, connects2arr, mut_seq
from ..util.postprocess import get_base_index
from ..util.base_pair_motif_energy import BPM_energy
from ..util.base_pair_probability import read_BPPM, gen_BPPM


# canonical base pairs: Watson-Crick (AU/UA, GC/CG) plus wobble (GU/UG)
CANONICAL_PAIRS = {'AU', 'UA', 'GC', 'CG', 'GU', 'UG'}


class RNAseq_data(data.Dataset):
def __init__(self,
data_dir,
index_name='data_index.yaml',
phase='train',
Lmax=600,
Lmin=0,
fold=0,
nfolds=4,
seed=42,
cache_dir='.cache_data',
mask_only=False,
method='CDPfold',
trainall=False,
predict_files=None,
training_set=None,
test_set=None,
use_BPE=True,
use_BPP=True,
normalize_energy=False,
verbose=False,
para_dir='paras/',
*args,
**kargs,
):
        # Set all input args as attributes (note: locals() here also includes
        # `self`, `args`, and `kargs`)
        self.__dict__.update(locals())
self.use_BPE = use_BPE
self.use_BPP = use_BPP
self.phase = phase.lower()
self.verbose = verbose
if self.phase == 'predict':
self.cache_dir = os.path.join(cache_dir, method)
else:
self.cache_dir = os.path.join(data_dir, cache_dir, method)
self.method = method
self.data_dir = data_dir
self.mask_only = mask_only
        os.makedirs(self.cache_dir, exist_ok=True)

# data filter
index_file = os.path.join(data_dir, index_name)
if self.phase == 'predict':
self.Lmax = Lmax
self.file_list = predict_files
else:
if self.phase in {'train', 'validate'}:
with open(index_file) as f:
index_dic = yaml.load(f.read(), Loader=yaml.FullLoader)
all_files = index_dic['train'][:]
if 'validate' in index_dic:
all_files += index_dic['validate']
if trainall:
all_files += index_dic['test']
# train with specific datasets
if self.phase == 'train' and training_set:
all_files = [f for f in all_files if f['dataset'] in training_set]
print(f'training sets: {training_set}')

all_files.sort(key=lambda dic: dic['path'])

# Kfold
if nfolds <= 1: # no kfold
if self.phase == 'validate':
self.file_list = index_dic['test'][:]
else: # train
self.file_list = all_files[:]
random.shuffle(self.file_list)
else: # kfold training
                    kf = KFold(n_splits=nfolds, random_state=seed, shuffle=True)
                    train_idx, valid_idx = list(kf.split(range(len(all_files))))[fold]
                    split = train_idx if self.phase == 'train' else valid_idx
                    self.file_list = [all_files[i] for i in split]
self.Lmax = Lmax
# limit length
                self.file_list = [f for f in self.file_list if Lmin <= f['length'] <= self.Lmax]

elif self.phase in ['test']:
with open(index_file) as f:
index_dic = yaml.load(f.read(), Loader=yaml.FullLoader)
self.file_list = index_dic['test'][:]
if self.phase == 'test' and test_set:
print(f'test sets: {test_set}')
self.file_list = [f for f in self.file_list if f['dataset'] in test_set]
self.Lmax = max([f['length'] for f in self.file_list])
else:
raise NotImplementedError
print(f'phase={self.phase}, num={len(self.file_list)}, nfolds={nfolds}: {index_file}')
print(f'use_BPP={use_BPP}, use_BPE={use_BPE}')
for dic in self.file_list:
dic['path'] = os.path.join(self.data_dir, dic['path'])

if self.use_BPE:
self.normalize_energy = normalize_energy
self.BPM_ene = BPM_energy(path=os.path.join(para_dir, 'key.energy'))
self.base_index = get_base_index() # for matrix embed
self.num_base = len(self.base_index)
self.index_base = {v: k for k, v in self.base_index.items()}
self.token_index = {k: v for k, v in self.base_index.items()}
self.token_index.update({tok: len(self.base_index)+i for i, tok in enumerate(['START', 'END', 'EMPTY'])}) # for sequence embed
        # flags over the NUM_BASE**2 ordered pair channels (same channel order as in
        # seq_embed_matrix); True marks a non-canonical pair such as 'AC'
        self.noncanonical = [self.index_base[i] + self.index_base[j] not in CANONICAL_PAIRS
                             for i, j in product(range(self.num_base), range(self.num_base))]
        self.noncanonical_flag = np.array(self.noncanonical, dtype=bool)
self.to_device_keywords = {'input', 'input_seqmat', 'mask', 'forward_mask', 'BPPM', 'BPEM', 'seq_onehot', 'nc_map',}
        if self.phase != 'predict':
self.to_device_keywords.add('gt')
print(self.token_index)


def __len__(self):
return len(self.file_list)

def prepare_data(self, name, seq, connects=None):
ret = {}
L = len(seq)
ret['ori_seq'] = seq.upper() # AUGC and others
ret['seq'] = mut_seq(ret['ori_seq'], connects) # AUGC
ret['length'] = L

# mask: (Lmax+2)x(Lmax+2)
mask = torch.zeros(self.Lmax + 2, self.Lmax+2, dtype=torch.bool)
for row in range(1, L+1): # not including START and END
mask[row, 1:L+1] = True
# forward_mask: Lmax+2
forward_mask = torch.zeros(self.Lmax + 2, dtype=torch.bool) # START, seq, END
forward_mask[0:L+2] = True # including START and END
ret['mask'] = mask
ret['forward_mask'] = forward_mask

if self.mask_only:
return ret, None

## pad to uniform size=(Lmax+2) when batch-loading
rside_pad = self.Lmax + 1 - L
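        # left pad of 1 leaves room for the START slot; the right pad of rside_pad
        # covers the END slot plus EMPTY padding, giving uniform (Lmax+2)-sized tensors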

# seq_embed seq: Lmax+2
ret['input'] = self.seq_embed_sequence(ret['seq'])

# seq_embed outer product
seqmat, seq_onehot = self.seq_embed_matrix(ret['seq'], return_onehot=True) # N**2xLxL
seqmat_pad = np.pad(seqmat, ((0, 0), (1,rside_pad), (1, rside_pad)), constant_values=0)
ret['input_seqmat'] = torch.FloatTensor(seqmat_pad) # NUM_BASE**2 x (Lmax+2) x (Lmax+2)
seq_onehot_pad = np.pad(seq_onehot, ((1, rside_pad), (0, 0)), constant_values=0)
ret['seq_onehot'] = torch.FloatTensor(seq_onehot_pad) # LxNUM_BASE

# nc_map: noncanonical : (Lmax+2) x (Lmax+2)
nc_map = seqmat[self.noncanonical_flag].sum(axis=0).astype(bool) # LxL
nc_map_pad = np.pad(nc_map, ((1, rside_pad), (1, rside_pad)), constant_values=0)
ret['nc_map'] = torch.FloatTensor(nc_map_pad)

# BPPM: 1x(Lmax+2)x(Lmax+2)
if self.use_BPP:
            bppm = self.load_BPPM(seq=seq, name=name, use_cache=(self.phase != 'predict'))
bppm_pad = np.pad(bppm, ((1, rside_pad), (1, rside_pad)), constant_values=0)
ret['BPPM'] = torch.FloatTensor(bppm_pad).unsqueeze(0)
# ret['BPPM'] = torch.log(ret['BPPM']+1e-5) # Note: Energy ~ log(p)
# ret['BPPM'] = - ret['BPPM']
# BPEM: 1x(Lmax+2)x(Lmax+2)
if self.use_BPE:
bpem = self.BPM_ene.get_energy(ret['seq'], normalize_energy=self.normalize_energy)
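            # get_energy returns a channel-first 3-D array when normalize_energy=True
            # (note the 3-axis padding below), otherwise a 2-D matrix that gets its
            # channel dimension via unsqueeze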
if self.normalize_energy:
bpem_pad = np.pad(bpem, ((0, 0), (1, rside_pad), (1, rside_pad)), constant_values=0)
ret['BPEM'] = torch.FloatTensor(bpem_pad)
else:
bpem_pad = np.pad(bpem, ((1, rside_pad), (1, rside_pad)), constant_values=0)
ret['BPEM'] = torch.FloatTensor(bpem_pad).unsqueeze(0)

y = {k: ret[k] for k in ['mask', 'forward_mask', 'nc_map', 'seq_onehot']}
# gt, contact map: (Lmax+2)x(Lmax+2)
if self.phase != 'predict':
gt = connects2arr(connects)
gt_pad = np.pad(gt, ((1, rside_pad), (1, rside_pad)), constant_values=0)
ret['gt'] = torch.FloatTensor(gt_pad)
y['gt'] = ret['gt']
return ret, y

def __getitem__(self, idx):
info_dic = self.file_list[idx]

dataset = info_dic['dataset'] if 'dataset' in info_dic else 'RNAseq'
seq = name = connects = None

if 'path' in info_dic:
path = info_dic['path']
name, suf = get_file_name(path, return_suf=True)
if suf.lower() in {'.bpseq', '.ct', '.dbn'}:
seq, connects = read_SS(path)
else:
with open(path) as fp:
fp.readline()
seq = fp.readline().strip(' \n')
else:
if 'seq' in info_dic:
seq = info_dic['seq']
else:
raise Exception(f'[Error] seq or path needed: {info_dic}')
name = info_dic['name'] if 'name' in info_dic else str_localtime()
if connects is None and self.phase != 'predict':
raise Exception(f'[Error] Invalid input: {info_dic} at {self.phase} stage, gt SS needed.')
# # load_data
# ret_data = {}
# cache_path = os.path.join(self.cache_dir, name+'.pth')
# if os.path.exists(cache_path):
# ret_data = torch.load(cache_path)
# else:
# ret_data, y = self.prepare_data(name=name, seq=seq, connects=connects)
# if not self.mask_only:
# torch.save(ret_data, cache_path)
ret_data, y = self.prepare_data(name=name, seq=seq, connects=connects)

if self.mask_only:
return {k: ret_data[k] for k in ['mask', 'forward_mask', 'length']}
# update ret dic
ret_data.update({'name': name, 'idx': idx, 'dataset': dataset})
return ret_data, y

def seq_embed_matrix(self, seq, return_onehot=False):
'''
seq: str, len=L, 'AUGC...'
ret: tensor, NUM_BASE**2 x L x L, 0, 1 val
'''
L = len(seq)
# seq onehot L x NUM_BASE
seq_onehot = np.zeros((L, self.num_base), dtype=float)
        for i in range(L):  # must stay consistent with function `postprocess` in ..util.postprocess
            seq_onehot[i][self.base_index[seq[i]]] = 1

        # seq embedding: NUM_BASE*NUM_BASE x L x L; channel n=(i, j) is the outer
        # product of one-hot column i with column j, so entry (a, b) == 1 iff
        # seq[a] is base i and seq[b] is base j
        seq_embed = np.zeros((self.num_base**2, L, L))  # AUGC
        for n, (i, j) in enumerate(product(range(self.num_base), range(self.num_base))):
            seq_embed[n] = np.matmul(seq_onehot[:, i].reshape(-1, 1), seq_onehot[:, j].reshape(1, -1))
if return_onehot:
return seq_embed, seq_onehot
else:
return seq_embed

def seq_embed_sequence(self, seq):
'''
seq: str, len=L, 'AUGC...'
ret: ndarray, Lmax+2, 0-6 val, repr 'start AUGC... end empty...'
'''
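        # Hypothetical example: if base_index maps A,U,G,C to 0..3 (so START=4,
        # END=5, EMPTY=6), then Lmax=5, seq='AUG' encodes to [4, 0, 1, 2, 5, 6, 6]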
ret = [self.token_index['START']]
ret.extend(self.token_index[s] for s in seq)
ret.append(self.token_index['END'])

# Lmax + 2, final length
for i in range(self.Lmax - len(seq)):
ret.append(self.token_index['EMPTY'])
ret = torch.Tensor(ret).int() # float? TODO
return ret


    def load_BPPM(self, seq, name, use_cache=True):
        '''Load the base pair probability matrix for `seq`, generating and caching it on first use.'''
        txt_path = os.path.join(self.cache_dir, name+'.txt')
        npy_path = os.path.join(self.cache_dir, name+'.npy')

if use_cache and os.path.exists(npy_path):
return np.load(npy_path, allow_pickle=True)
else:
if not os.path.exists(txt_path):
try:
if self.phase == 'predict' and self.verbose:
print(f'[Info] Using "{self.method}" to generate BPPM, saving at "{txt_path}"')
gen_BPPM(txt_path, mut_seq(seq), name, self.method)
except Exception as e:
if self.phase == 'predict' and self.verbose:
print(f'[Warning] {e}, using CDPfold instead')
gen_BPPM(txt_path, mut_seq(seq), name, 'CDPfold')
BPPM = read_BPPM(txt_path, len(seq))
np.save(npy_path, BPPM)
return BPPM
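
For orientation, here is a minimal usage sketch (not part of this commit) showing how RNAseq_data could be driven with a PyTorch DataLoader. It assumes a hypothetical data directory containing a `data_index.yaml` index and the referenced `.bpseq`/`.ct`/`.dbn` files; `use_BPP`/`use_BPE` are disabled so no BPPM generator or energy table is needed:

import torch.utils.data as data
from BPfold.dataset.RNAseq import RNAseq_data

# hypothetical paths and settings; adjust to the local data layout
train_set = RNAseq_data(data_dir='data/RNAseq', phase='train', Lmax=600,
                        fold=0, nfolds=4, use_BPP=False, use_BPE=False)
loader = data.DataLoader(train_set, batch_size=4, shuffle=True)
x, y = next(iter(loader))
print(x['input'].shape)  # token ids: batch x (Lmax+2)
print(y['gt'].shape)     # contact maps: batch x (Lmax+2) x (Lmax+2)

Each item is the `(ret_data, y)` pair returned by `__getitem__`: `ret_data` carries the padded inputs (`input`, `input_seqmat`, masks, and optionally `BPPM`/`BPEM`), while `y` repeats the targets needed by the loss.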
7 changes: 7 additions & 0 deletions BPfold/dataset/__init__.py
@@ -0,0 +1,7 @@
from .RNAseq import RNAseq_data


def get_dataset(s):
return {
'rnaseq': RNAseq_data,
}[s.lower()]
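
A brief usage note (illustrative, not part of the commit): the lookup key is lower-cased, so any capitalization of 'rnaseq' resolves to RNAseq_data, while unknown names raise a KeyError.

from BPfold.dataset import get_dataset

dataset_cls = get_dataset('RNAseq')  # case-insensitive: same as get_dataset('rnaseq')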