From 495be6cd6aa588e1c70f86ae6cc5377b965ac813 Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Thu, 30 May 2024 16:46:25 -0400 Subject: [PATCH 01/52] Sketch the creation of KinForests from precomputed restype stencils --- tmol/chemical/patched_chemdb.py | 1 + tmol/chemical/restypes.py | 6 + tmol/database/chemical/__init__.py | 4 + tmol/database/default/chemical/chemical.yaml | 23 ++ tmol/kinematics/builder.py | 2 + tmol/kinematics/check_fold_forest.py | 15 +- tmol/pose/packed_block_types.py | 17 + tmol/pose/pose_kinematics.py | 68 +++- ...st_create_scan_orering_from_block_types.py | 383 ++++++++++++++++++ tmol/tests/pose/test_pose_stack_kinematics.py | 54 +++ 10 files changed, 562 insertions(+), 11 deletions(-) create mode 100644 tmol/tests/kinematics/test_create_scan_orering_from_block_types.py diff --git a/tmol/chemical/patched_chemdb.py b/tmol/chemical/patched_chemdb.py index 2f5039e34..1276cbd3e 100644 --- a/tmol/chemical/patched_chemdb.py +++ b/tmol/chemical/patched_chemdb.py @@ -510,6 +510,7 @@ def do_patch(res, variant, resgraph, patchgraph, marked): icoors=res.icoors, properties=res.properties, chi_samples=res.chi_samples, + default_jump_connection_atom=res.default_jump_connection_atom, ) # 1. 
remove atoms diff --git a/tmol/chemical/restypes.py b/tmol/chemical/restypes.py index b9570ed11..2e410c0ba 100644 --- a/tmol/chemical/restypes.py +++ b/tmol/chemical/restypes.py @@ -448,6 +448,12 @@ def _setup_icoors_geom(self): def compute_ideal_coords(self): return build_coords_from_icoors(self.icoors_ancestors, self.icoors_geom) + default_jump_connection_atom_index: int = attr.ib() + + @default_jump_connection_atom_index.default + def get_default_jump_connection_atom_index(self): + return self.atom_to_idx[self.default_jump_connection_atom] + @attr.s(auto_attribs=True) class ResidueTypeSet: diff --git a/tmol/database/chemical/__init__.py b/tmol/database/chemical/__init__.py index 11bd40cff..d9501a938 100644 --- a/tmol/database/chemical/__init__.py +++ b/tmol/database/chemical/__init__.py @@ -136,6 +136,10 @@ class RawResidueType: icoors: Tuple[Icoor, ...] properties: ChemicalProperties chi_samples: Tuple[ChiSamples, ...] + default_jump_connection_atom: str + + def atom_name(self, index): + return self.atoms[index].name @attr.s(auto_attribs=True, frozen=True, slots=True) diff --git a/tmol/database/default/chemical/chemical.yaml b/tmol/database/default/chemical/chemical.yaml index ef45a840f..88ccd5b17 100644 --- a/tmol/database/default/chemical/chemical.yaml +++ b/tmol/database/default/chemical/chemical.yaml @@ -143,6 +143,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: ARG base_name: ARG name3: ARG @@ -277,6 +278,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: ASN base_name: ASN name3: ASN @@ -373,6 +375,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: ASP base_name: ASP name3: ASP @@ -461,6 +464,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: CYS base_name: CYS name3: CYS @@ -549,6 +553,7 @@ residues: - chi_dihedral: chi2 samples: [60, 180, 300] expansions: [] + default_jump_connection_atom: 
CA - name: CYD base_name: CYD name3: CYS @@ -631,6 +636,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: GLN base_name: GLN name3: GLN @@ -739,6 +745,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: GLU base_name: GLU name3: GLU @@ -839,6 +846,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: GLY base_name: GLY name3: GLY @@ -910,6 +918,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: HIS base_name: HIS name3: HIS @@ -1016,6 +1025,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: HIS_D base_name: HIS_D name3: HIS @@ -1121,6 +1131,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: ILE base_name: ILE name3: ILE @@ -1236,6 +1247,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: LEU base_name: LEU name3: LEU @@ -1351,6 +1363,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: LYS base_name: LYS name3: LYS @@ -1480,6 +1493,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: MET base_name: MET name3: MET @@ -1589,6 +1603,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: PHE base_name: PHE name3: PHE @@ -1702,6 +1717,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: PRO base_name: PRO name3: PRO @@ -1802,6 +1818,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: SER base_name: SER name3: SER @@ -1890,6 +1907,7 @@ residues: - chi_dihedral: chi2 samples: [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320, 340] expansions: [] + default_jump_connection_atom: CA - name: THR base_name: THR name3: THR @@ -1988,6 +2006,7 @@ residues: - chi_dihedral: chi2 
samples: [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320, 340] expansions: [] + default_jump_connection_atom: CA - name: TRP base_name: TRP name3: TRP @@ -2114,6 +2133,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: TYR base_name: TYR name3: TYR @@ -2235,6 +2255,7 @@ residues: - chi_dihedral: chi3 samples: [0, 180] expansions: [20] + default_jump_connection_atom: CA - name: VAL base_name: VAL name3: VAL @@ -2339,6 +2360,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: CA - name: HOH base_name: HOH name3: HOH @@ -2374,6 +2396,7 @@ residues: pH: 7 virtual: [] chi_samples: [] + default_jump_connection_atom: O variants: - name: CarboxyTerminus diff --git a/tmol/kinematics/builder.py b/tmol/kinematics/builder.py index 66ada1dee..d304285a7 100644 --- a/tmol/kinematics/builder.py +++ b/tmol/kinematics/builder.py @@ -218,6 +218,8 @@ def bonds_to_forest( kfo_2_to, preds = csgraph.breadth_first_order( bond_graph, roots[0], directed=False, return_predecessors=True ) + print("kfo_2_to", kfo_2_to) + print("preds", preds) to_parents_in_kfo = preds[kfo_2_to] n_target_atoms = numpy.max(kfo_2_to) + 1 diff --git a/tmol/kinematics/check_fold_forest.py b/tmol/kinematics/check_fold_forest.py index 62bda20bc..352c1f12a 100644 --- a/tmol/kinematics/check_fold_forest.py +++ b/tmol/kinematics/check_fold_forest.py @@ -9,6 +9,15 @@ def mark_polymeric_bonds_in_foldforest_edges( n_poses: int, max_n_blocks: int, edges: NDArray[int][:, :, 4] ): + """Make each implicit i-to-i+1 or i-to-(i-1) polymer bond explicit + + Notes + ----- + This code does not ensure that the polymeric bonds between + these two residues are present in the PoseStack; this means + that if there are missing loops, e.g., that we can still + "fold through" them. 
+ """ polymeric_connection_in_edge = numpy.zeros( (n_poses, max_n_blocks, max_n_blocks), dtype=numpy.int64 ) @@ -85,11 +94,7 @@ def validate_fold_forest_jit( # ok, let's get the other edges incorporated for i in range(n_poses): for j in range(max_n_edges): - if edges[i, j, 0] == EdgeType.jump: - r1 = edges[i, j, 1] - r2 = edges[i, j, 2] - connections[i, r1, r2] += 1 - if edges[i, j, 0] == EdgeType.chemical: + if edges[i, j, 0] == EdgeType.jump or edges[i, j, 0] == EdgeType.chemical: r1 = edges[i, j, 1] r2 = edges[i, j, 2] connections[i, r1, r2] += 1 diff --git a/tmol/pose/packed_block_types.py b/tmol/pose/packed_block_types.py index 11aed6ead..3b3cf2629 100644 --- a/tmol/pose/packed_block_types.py +++ b/tmol/pose/packed_block_types.py @@ -97,6 +97,8 @@ class PackedBlockTypes: down_conn_inds: Tensor[torch.int32][:] up_conn_inds: Tensor[torch.int32][:] + default_jump_connection_atom_inds: Tensor[torch.int32][:] + device: torch.device @property @@ -133,6 +135,9 @@ def from_restype_list( down_conn_inds, up_conn_inds = cls.join_polymeric_connections( active_block_types, device ) + def_jumpconn_inds = cls.join_default_jump_connection_atom_inds( + active_block_types, device + ) return cls( chem_db=chem_db, @@ -158,6 +163,7 @@ def from_restype_list( conn_atom=conn_atom, down_conn_inds=down_conn_inds, up_conn_inds=up_conn_inds, + default_jump_connection_atom_inds=def_jumpconn_inds, device=device, ) @@ -293,6 +299,14 @@ def join_polymeric_connections(cls, active_block_types, device): ) return down_conn_inds, up_conn_inds + @classmethod + def join_default_jump_connection_atom_inds(cls, active_block_types, device): + return torch.tensor( + [bt.default_jump_connection_atom_index for bt in active_block_types], + dtype=torch.int32, + device=device, + ) + def inds_for_res(self, residues: Sequence[Residue]): return self.restype_index.get_indexer( [res.residue_type.name for res in residues] @@ -331,6 +345,9 @@ def cpu_equiv(x): conn_atom=cpu_equiv(self.conn_atom), 
down_conn_inds=cpu_equiv(self.down_conn_inds), up_conn_inds=cpu_equiv(self.up_conn_inds), + default_jump_connection_atom_inds=cpu_equiv( + self.default_jump_connection_atom_inds + ), device=cpu_equiv(self.device), ) for self_key in self.__dict__: diff --git a/tmol/pose/pose_kinematics.py b/tmol/pose/pose_kinematics.py index b17a30b83..b4db49fec 100644 --- a/tmol/pose/pose_kinematics.py +++ b/tmol/pose/pose_kinematics.py @@ -1,15 +1,20 @@ import torch import numpy +import numba from tmol.types.array import NDArray from tmol.types.torch import Tensor from tmol.types.functional import validate_args +from tmol.pose.packed_block_types import PackedBlockTypes from tmol.pose.pose_stack import PoseStack from tmol.kinematics.builder import KinematicBuilder from tmol.kinematics.datatypes import KinForest -from tmol.kinematics.fold_forest import FoldForest +from tmol.kinematics.fold_forest import FoldForest, EdgeType from tmol.kinematics.check_fold_forest import mark_polymeric_bonds_in_foldforest_edges +import scipy.sparse as sparse +import scipy.sparse.csgraph as csgraph + def get_bonds_for_named_torsions(pose_stack: PoseStack): pbt = pose_stack.packed_block_types @@ -195,6 +200,16 @@ def get_atom_inds_for_interblock_connections( out whether the polymeric connections in a pose should be included in its fold tree; the logic for handling up-to-down connections (i.e. N->C) is identical to the logic for handling down-to-up connections (i.e. C->N). + + Notes + ----- + This code will not include a connection between residues i and i+1 if + there is not a bond listed between those two residues in the + pose_stack.inter_residue_connections64 tensor, EVEN IF these residues + are listed as connected by the kinematic_connections tensor. + So, whereas "validate_fold_forest" is happy to "fold through" a break + in the chain, this code is not, and the inconsistency is surely + going to be a problem at some point. 
""" pbt = pose_stack.packed_block_types @@ -219,6 +234,9 @@ def get_atom_inds_for_interblock_connections( # on the other side of the connection point and, having found the complete # connections, go back and refine the list of pose-inds and block-inds that # we will work with + # NOTE: it is here that we throw away possibly-desired kinematic connections + # between residues that are not chemically bonded. We need different + # logic to differentiate between incomplete inter-residue connections that src_conn_complete = src_conn_other_block_prelim != -1 src_conn_other_block = src_conn_other_block_prelim[src_conn_complete] @@ -399,6 +417,37 @@ def get_all_bonds(pose_stack: PoseStack): return bonds +def get_jump_bonds_in_fold_forest(pose_stack, fold_forest) -> Tensor[int][:, 2]: + pbt = pose_stack.packed_block_types + t_edges = torch.tensor( + fold_forest.edges, dtype=torch.int64, device=pose_stack.device + ) + is_jump_edge = t_edges[:, :, 0] == EdgeType.jump + jump_pose_ind, jump_edge_ind = torch.nonzero(is_jump_edge, as_tuple=True) + start_block = t_edges[jump_pose_ind, jump_edge_ind, 1] + stop_block = t_edges[jump_pose_ind, jump_edge_ind, 2] + start_block_offset = pose_stack.block_coord_offset64[jump_pose_ind, start_block] + stop_block_offset = pose_stack.block_coord_offset64[jump_pose_ind, stop_block] + start_jump_atom = pbt.default_jump_connection_atom_inds[ + pose_stack.block_type_ind64[jump_pose_ind, start_block] + ].to(torch.int64) + stop_jump_atom = pbt.default_jump_connection_atom_inds[ + pose_stack.block_type_ind64[jump_pose_ind, stop_block] + ].to(torch.int64) + pose_offset = pose_stack.max_n_pose_atoms * jump_pose_ind + + def _u1(x): + return torch.unsqueeze(x, dim=1) + + return torch.cat( + ( + _u1(pose_offset + start_block_offset + start_jump_atom), + _u1(pose_offset + stop_block_offset + stop_jump_atom), + ), + dim=1, + ) + + def get_root_atom_indices( pose_stack: PoseStack, fold_tree_roots: NDArray[int][:] ) -> Tensor[torch.int32][:]: @@ -432,18 
+481,25 @@ def construct_pose_stack_kinforest( # connect to. Logic in R3: take the central "mainchain" atom # which is only ok for polymers, but perverse for anything else. # What's the mainchain of a ligand?! - # jump_atom_pairs = get_jump_bonds_in_fold_forest(pose_stack, fold_forest) + jump_atom_pairs = get_jump_bonds_in_fold_forest(pose_stack, fold_forest) - all_bonds = torch.cat((intra_block_bonds, kin_polymeric_bonds), dim=0).cpu().numpy() - tor_bonds = get_bonds_for_named_torsions(pose_stack).cpu().numpy() + all_bonds = ( + torch.cat((intra_block_bonds, kin_polymeric_bonds, jump_atom_pairs), dim=0) + .cpu() + .numpy() + ) + tor_bonds = get_bonds_for_named_torsions(pose_stack) + prioritized_bonds = torch.cat((tor_bonds, jump_atom_pairs), dim=0).cpu().numpy() root_atoms = get_root_atom_indices(pose_stack, fold_forest.roots).cpu().numpy() return ( KinematicBuilder().append_connected_components( root_atoms, *KinematicBuilder.define_trees_with_prioritized_bonds( - roots=root_atoms, potential_bonds=all_bonds, prioritized_bonds=tor_bonds + roots=root_atoms, + potential_bonds=all_bonds, + prioritized_bonds=prioritized_bonds, ), - # to do: to_jump_nodes=jump_atom_pairs[0,:] + to_jump_nodes=jump_atom_pairs[:, 1], ) ).kinforest diff --git a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py new file mode 100644 index 000000000..988578243 --- /dev/null +++ b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py @@ -0,0 +1,383 @@ +import torch +import numpy +import attrs + +from collections import defaultdict +from numba import jit + +import scipy.sparse as sparse +import scipy.sparse.csgraph as csgraph +from tmol.types.array import NDArray + +from tmol.io.canonical_ordering import ( + default_canonical_ordering, + default_packed_block_types, + canonical_form_from_pdb, +) +from tmol.io.pose_stack_construction import pose_stack_from_canonical_form +from 
tmol.kinematics.fold_forest import EdgeType +from tmol.kinematics.scan_ordering import get_children + + +@jit +def get_branch_depth(parents): + # modeled off get_children + nelts = parents.shape[0] + + n_immediate_children = numpy.full(nelts, 0, dtype=numpy.int32) + for i in range(nelts): + p = parents[i] + assert p <= i + if p == i: + continue + n_immediate_children[p] += 1 + + child_list = numpy.full(nelts, -1, dtype=numpy.int32) + child_list_span = numpy.empty((nelts, 2), dtype=numpy.int32) + + child_list_span[0, 0] = 0 + child_list_span[0, 1] = n_immediate_children[0] + for i in range(1, nelts): + child_list_span[i, 0] = child_list_span[i - 1, 1] + child_list_span[i, 1] = child_list_span[i, 0] + n_immediate_children[i] + + # Pass 3, fill the child list for each parent. + # As we do this, + + +def jump_bt_atom(bt, spanning_tree): + # CA! TEMP!!! Replace with code that connects up conn atom to down conn atom + # in the spanning tree and chooses the midpoint along that path, but for now, + # CA is atom 1. + return 1 + + +@attrs.define +class GenSegScanPaths: + n_gens: NDArray[numpy.int64][:, :] # n-input x n-output + nodes_for_generation: NDArray[numpy.int64][ + :, :, :, : + ] # n-input x n-output x max-n-gen x max-n-ats-per-gen + n_scans: NDArray[numpy.int64][:, :, :] + scan_starts: NDArray[numpy.int64][:, :, :, :] + scan_is_inter_block: NDArray[bool][:, :, :, :] + scan_lengths: NDArray[numpy.int64][:, :, :, :] + + +def test_kin_tree_construction(ubq_pdb): + torch_device = torch.device("cpu") + + co = default_canonical_ordering() + pbt = default_packed_block_types(torch_device) + canonical_form = canonical_form_from_pdb(co, ubq_pdb, torch_device) + pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) + + # okay! + # 1. 
let's create some annotations of the packed block types + bt_list = [bt for bt in pbt.active_block_types if bt.name == "LEU"] + + # for bt in pbt.active_block_types: + for bt in bt_list: + n_conn = len(bt.connections) + + n_input_types = n_conn + 2 # n_conn + jump input + root "input" + n_output_types = n_conn + 1 # n_conn + jump output + + n_gens = numpy.zeros((n_input_types, n_output_types), dtype=numpy.int64) + nodes_for_generation = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + n_scans = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] + scan_starts = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + scan_is_inter_block = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + scan_lengths = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + + def _bonds_to_csgraph( + bonds: NDArray[int][:, 2], edge_weight: float + ) -> sparse.csr_matrix: + weights_array = numpy.full((1,), edge_weight, dtype=numpy.float32) + weights = numpy.broadcast_to(weights_array, bonds[:, 0].shape) + + bonds_csr = sparse.csr_matrix( + (weights, (bonds[:, 0], bonds[:, 1])), + shape=(bt.n_atoms, bt.n_atoms), + ) + return bonds_csr + + # create a bond graph and then we will create the prioritized edges + # and all edges + potential_bonds = _bonds_to_csgraph(bt.bond_indices, -1) + print("potential bonds", potential_bonds) + tor_atoms = [ + (uaids[1][0], uaids[2][0]) + for tor, uaids in bt.torsion_to_uaids.items() + if uaids[1][0] >= 0 and uaids[2][0] >= 0 + ] + if len(tor_atoms) == 0: + tor_atoms = numpy.zeros((0, 2), dtype=numpy.int64) + else: + tor_atoms = numpy.array(tor_atoms) + print("tor atoms:", tor_atoms) + + prioritized_bonds = _bonds_to_csgraph(tor_atoms, -0.125) + print("prioritized bonds", prioritized_bonds) + bond_graph = potential_bonds + prioritized_bonds + bond_graph_spanning_tree = csgraph.minimum_spanning_tree(bond_graph.tocsr()) + + mid_bt_atom = 
jump_bt_atom(bt, bond_graph_spanning_tree) + + is_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) + for i in range(n_conn): + is_conn_atom[bt.ordered_connection_atoms[i]] = True + + for i in range(n_input_types): + + i_conn_atom = bt.ordered_connection_atoms[i] if i < n_conn else mid_bt_atom + bfto_2_orig, preds = csgraph.breadth_first_order( + bond_graph_spanning_tree, + i_conn_atom, + directed=False, + return_predecessors=True, + ) + print(bt.name, i, bfto_2_orig, preds) + print([bt.atom_name(bfto_2_orig[bfs_ind]) for bfs_ind in range(bt.n_atoms)]) + for j in range(n_output_types): + + if i == j and i < n_conn: + # we cannot enter from one inter-residue connection point and then + # leave by that same inter-residue connection point unless we are + # building a jump + continue + + # now we start at the j_conn_atom and work backwards toward the root + # which marks the first scan path for this block type: the "primary exit path" + gen_scan_paths = defaultdict(list) + + j_conn_atom = ( + bt.ordered_connection_atoms[j] if j < n_conn else mid_bt_atom + ) + + first_descendant = numpy.full((bt.n_atoms,), -9999, dtype=numpy.int64) + is_on_primary_exit_path = numpy.zeros((bt.n_atoms,), dtype=bool) + is_on_primary_exit_path[i_conn_atom] = True + + focused_atom = j_conn_atom + primary_exit_scan_path = [] + while focused_atom != i_conn_atom: + print("exit path:", bt.atom_name(focused_atom)) + is_on_primary_exit_path[focused_atom] = True + primary_exit_scan_path.append(focused_atom) + pred = preds[focused_atom] + first_descendant[pred] = focused_atom + focused_atom = pred + primary_exit_scan_path.append(i_conn_atom) + primary_exit_scan_path.reverse() + # we need to prioritize exit paths of all stripes + # in constructing the trees + is_on_exit_path = is_on_primary_exit_path.copy() + for k in range(n_conn): + if k == i or k == j: + continue # truly unnecessary; nothing changes if I remove these two lines + is_on_exit_path[bt.ordered_connection_atoms[k]] = True + + 
print("primary_exit_scan_path:", primary_exit_scan_path) + gen_scan_paths[0].append(primary_exit_scan_path) + + # Create a list of children for each atom. + n_kids = numpy.zeros((bt.n_atoms,), dtype=numpy.int64) + atom_kids = [[] for _ in range(bt.n_atoms)] + for k in range(bt.n_atoms): + if preds[k] < 0: + assert ( + k == i_conn_atom + ), f"bad predecesor for atom {k} in {bt.name}, {preds[k]}" + continue # the root + n_kids[preds[k]] += 1 + atom_kids[preds[k]].append(k) + + # now we label each node with its "generation depth" using a + # leaf-to-root traversal prescribed by the original BFS, taking + # into account the fact that priority must be given to + # exit paths + gen_depth = numpy.ones((bt.n_atoms,), dtype=numpy.int64) + on_path_from_conn_to_i_conn_atom = numpy.zeros( + (bt.n_atoms,), dtype=bool + ) + for k in range(bt.n_atoms - 1, -1, -1): + k_atom_ind = bfto_2_orig[k] + # print("recursing upwards", i, "i_conn atom", i_conn_atom, j, "j_conn_atom", j_conn_atom, k, k_atom_ind) + k_kids = atom_kids[k_atom_ind] + # print("kids:", k_kids) + if len(k_kids) == 0: + continue + # from here forward, we know that k_atom_ind has > 0 children + + def gen_depth_given_first_descendant(): + # first set the first_descendant for k_atom_ind + # then the logic is: we have to add one to the + # gen-depth of every child but the first descendant + # which we get "for free" + # print(f"atom {bt.atom_name(k_atom_ind)} with first descendant {bt.atom_name(first_descendant[k_atom_ind]) if first_descendant[k_atom_ind] >= 0 else 'None'} and depth {gen_depth[first_descendant[k_atom_ind]] if first_descendant[k_atom_ind] >= 0 else -9999}") + return max( + [ + ( + gen_depth[k_kid] + 1 + if k_kid != first_descendant[k_atom_ind] + else gen_depth[k_kid] + ) + for k_kid in k_kids + ] + ) + + if is_on_primary_exit_path[k_atom_ind]: + # in this case, the first_descendant for this atom + # has already been decided + # print("on exit path:", bt.atom_name(k_atom_ind), first_descendant[k_atom_ind], 
is_conn_atom[k_atom_ind]) + if k_atom_ind == j_conn_atom: + # the first descendant is the atom on the next residue to which + # this residue is connected + gen_depth[k_atom_ind] = ( + max([gen_depth[l] for l in k_kids]) + 1 + ) + else: + # first_descendant is already determined for this atom + gen_depth[k_atom_ind] = gen_depth_given_first_descendant() + else: + + if is_conn_atom[k_atom_ind]: + # in this case, "the" connection (there can possibly be more than one!) + # will be the first child and the other descendants will be second children + # we save the gen depth, but when calculating the gen depth of the + # fold-forest, if this residue is at the upstream end of an edge, then + # its depth will have to be calculated as the min gen-depth of the + # intra-residue bits and the gen-depth of the nodes downstream of it. + gen_depth[k_atom_ind] = ( + max([gen_depth[l] for l in k_kids]) + 1 + ) + else: + # most-common case: an atom not on the primary-exit path, and that isn't + # itself a conn atom. + # First we ask: are we on one or more exit paths? + # NOTE: this just chooses the first exit path atom it encounters + # as the first descendant and so I pause and think: if we have + # a block type with 4 inter-residue connections where the fold + # forest branches at this residue, then the algorithm for constructing + # the most number-of-generations-efficient KinForest here + # will fail: we are treating all exit paths out of this residue + # as interchangeable and we might say connection c vs c' should + # be first in a case where c' leads to more generations than c. + # The case I am designing for here is: there's a jump that has + # landed at a beta-amino acid's CA atom and there are exit paths + # through the N- and C-terminal ends of the residue and if the + # primary exit path is the C-term, then the N-term exit path should + # still have priority over the side-chain path. + # + # R + # | + # ... CB C + # \ / \ / \ + # N CA ... 
+ # + # The path starting at CB should go towards N and not towards R. + # If we are only dealing with polymeric residues that have an + # up- and a down connection and that's it (e.g. nucleic acids), + # then this algorithm will still produce optimal KinForests. + + for kid in k_kids: + if is_on_exit_path[kid]: + first_descendant[k_atom_ind] = kid + is_on_exit_path[k_atom_ind] = True + + if not is_on_exit_path[k_atom_ind]: + # which should be the first descendant? the one with the greatest gen depth + first_descendant[k_atom_ind] = k_kids[ + numpy.argmax( + numpy.array([gen_depth[kid] for kid in k_kids]) + ) + ] + gen_depth[k_atom_ind] = gen_depth_given_first_descendant() + # print("gen_depth", bt.atom_name(k_atom_ind), "d:", gen_depth[k_atom_ind]) + # print("gen_depth", gen_depth) + + # OKAY! + # now we have paths rooted at each node up to the root + # we need to turn these paths into scan paths + processed_node_into_scan_path = is_on_primary_exit_path.copy() + gen_to_build_atom = numpy.full((bt.n_atoms,), -1, dtype=numpy.int64) + gen_to_build_atom[processed_node_into_scan_path] = 0 + print("gen depth", gen_depth) + print("starting bfs:", processed_node_into_scan_path) + for k in range(bt.n_atoms): + k_atom_ind = bfto_2_orig[k] + if processed_node_into_scan_path[k_atom_ind]: + continue + + # if we arrive here, that means k_atom_ind is the root of a + # new scan path + path = [] + # we have already processed the first scan path + # from the entrance-point atom to the first exit-point atom + assert k_atom_ind != i_conn_atom + # put the parent of this new root at the beginning of + # the scan path + path.append(preds[k_atom_ind]) + focused_atom = k_atom_ind + + gen_to_build_atom[focused_atom] = ( + gen_to_build_atom[preds[focused_atom]] + 1 + ) + print( + f"gen to build {bt.atom_name(focused_atom)} from {bt.atom_name(preds[focused_atom])}", + f"with gen {gen_to_build_atom[focused_atom]}", + ) + while focused_atom >= 0: + path.append(focused_atom) + 
processed_node_into_scan_path[focused_atom] = True + focused_atom = first_descendant[focused_atom] + if focused_atom >= 0: + gen_to_build_atom[focused_atom] = gen_to_build_atom[ + preds[focused_atom] + ] + if is_on_exit_path[k_atom_ind]: + gen_scan_paths[gen_to_build_atom[k_atom_ind]].insert(0, path) + else: + gen_scan_paths[gen_to_build_atom[k_atom_ind]].append(path) + # Now we need to assemble the scan paths in a compact way: + print("gen scan paths", gen_scan_paths) + + ij_n_gens = gen_depth[i_conn_atom] + print("ij_n_gens", i, j, ij_n_gens) + ij_n_scans = [len(gen_scan_paths[k]) for k in range(ij_n_gens)] + print("ij_n_scans", i, j, ij_n_scans) + ij_scan_starts = [[0] * ij_n_scans[k] for k in range(ij_n_gens)] + print("ij_scan_starts", i, j, ij_scan_starts) + ij_scan_lengths = [ + [len(gen_scan_paths[k][l]) for l in range(len(gen_scan_paths[k]))] + for k in range(ij_n_gens) + ] + print("ij_scan_lengths", i, j, ij_scan_lengths) + # ij_n_nodes_for_gen = + ij_n_nodes_for_gen = [ + sum(len(path) for path in gen_scan_paths[k]) + for k in range(ij_n_gens) + ] + print("ij_n_nodes_for_gen", ij_n_nodes_for_gen) + + +def test_decide_scan_paths_for_foldforest(ubq_pdb): + torch_device = torch.device("cpu") + + co = default_canonical_ordering() + pbt = default_packed_block_types(torch_device) + canonical_form = canonical_form_from_pdb( + co, ubq_pdb, torch_device, residue_start=0, residue_end=10 + ) + pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) + + fold diff --git a/tmol/tests/pose/test_pose_stack_kinematics.py b/tmol/tests/pose/test_pose_stack_kinematics.py index 0c8870397..aa81ec5b8 100644 --- a/tmol/tests/pose/test_pose_stack_kinematics.py +++ b/tmol/tests/pose/test_pose_stack_kinematics.py @@ -8,6 +8,12 @@ get_polymeric_bonds_in_fold_forest, construct_pose_stack_kinforest, ) +from tmol.io.canonical_ordering import ( + default_canonical_ordering, + default_packed_block_types, + canonical_form_from_pdb, +) +from tmol.io.pose_stack_construction 
import pose_stack_from_canonical_form from tmol.kinematics.check_fold_forest import mark_polymeric_bonds_in_foldforest_edges from tmol.kinematics.fold_forest import FoldForest, EdgeType @@ -320,3 +326,51 @@ def test_construct_pose_stack_kinforest(ubq_res, default_database): # TO DO: make sure kinforest is properly constructed assert kinforest is not None + + +def test_decide_scan_paths_for_foldforest(ubq_pdb): + torch_device = torch.device("cpu") + + co = default_canonical_ordering() + pbt = default_packed_block_types(torch_device) + canonical_form = canonical_form_from_pdb( + co, ubq_pdb, torch_device, residue_start=0, residue_end=10 + ) + pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) + + # let's make a FF with a jump: + # rooted at residue 2 + # 0 5 + # ^ ^ + # | | + # 2 - - > 7 + # | | + # v v + # 4 9 + + edges = numpy.full((1, 5, 4), -1, dtype=int) + edges[0, 0, 0] = EdgeType.jump + edges[0, 0, 1] = 2 + edges[0, 0, 2] = 7 + edges[0, 0, 3] = 0 + edges[0, 1, 0] = EdgeType.polymer + edges[0, 1, 1] = 2 + edges[0, 1, 2] = 0 + edges[0, 2, 0] = EdgeType.polymer + edges[0, 2, 1] = 2 + edges[0, 2, 2] = 4 + edges[0, 3, 0] = EdgeType.polymer + edges[0, 3, 1] = 7 + edges[0, 3, 2] = 5 + edges[0, 4, 0] = EdgeType.polymer + edges[0, 4, 1] = 7 + edges[0, 4, 2] = 9 + + ff = FoldForest( + max_n_edges=5, + n_edges=numpy.full((1,), 5, dtype=int), + edges=edges, + roots=numpy.full((1,), 2, dtype=int), + ) + + kf = construct_pose_stack_kinforest(pose_stack, ff) From 515958ed471aca2d5f33a83c50959607cecd07a9 Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Mon, 3 Jun 2024 19:15:33 -0400 Subject: [PATCH 02/52] Save progress --- tmol/pose/pose_kinematics.py | 2 + ...st_create_scan_orering_from_block_types.py | 103 ++++++++++++++++-- tmol/tests/pose/test_pose_stack_kinematics.py | 74 ++++++++++--- 3 files changed, 156 insertions(+), 23 deletions(-) diff --git a/tmol/pose/pose_kinematics.py b/tmol/pose/pose_kinematics.py index b4db49fec..564d2f744 100644 --- 
a/tmol/pose/pose_kinematics.py +++ b/tmol/pose/pose_kinematics.py @@ -491,6 +491,8 @@ def construct_pose_stack_kinforest( tor_bonds = get_bonds_for_named_torsions(pose_stack) prioritized_bonds = torch.cat((tor_bonds, jump_atom_pairs), dim=0).cpu().numpy() root_atoms = get_root_atom_indices(pose_stack, fold_forest.roots).cpu().numpy() + print("root atoms", root_atoms) + print(pose_stack.block_coord_offset) return ( KinematicBuilder().append_connected_components( diff --git a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py index 988578243..55e24c247 100644 --- a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py +++ b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py @@ -55,14 +55,34 @@ def jump_bt_atom(bt, spanning_tree): @attrs.define class GenSegScanPaths: n_gens: NDArray[numpy.int64][:, :] # n-input x n-output + n_nodes_for_gen: NDArray[numpy.int64][:, :, :] nodes_for_generation: NDArray[numpy.int64][ :, :, :, : - ] # n-input x n-output x max-n-gen x max-n-ats-per-gen + ] # n-input x n-output x max-n-gen x max-n-nodes-per-gen n_scans: NDArray[numpy.int64][:, :, :] scan_starts: NDArray[numpy.int64][:, :, :, :] + scan_is_real: NDArray[bool][:, :, :, :] scan_is_inter_block: NDArray[bool][:, :, :, :] scan_lengths: NDArray[numpy.int64][:, :, :, :] + @classmethod + def empty( + cls, n_input_types, n_output_types, max_n_gens, max_n_scans, max_n_nodes_per_gen + ): + io = (n_input_types, n_output_types) + return cls( + n_gens=numpy.zeros(io, dtype=int), + n_nodes_for_gen=numpy.zeros(io + (max_n_gens,), dtype=int), + nodes_for_generation=numpy.zeros( + io + (max_n_gens, max_n_nodes_per_gen), dtype=int + ), + n_scans=numpy.zeros(io + (max_n_gens,), dtype=int), + scan_starts=numpy.full(io + (max_n_gens, max_n_scans), -1, dtype=int), + scan_is_real=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=bool), + scan_is_inter_block=numpy.zeros(io + (max_n_gens, 
max_n_scans), dtype=bool), + scan_lengths=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=int), + ) + def test_kin_tree_construction(ubq_pdb): torch_device = torch.device("cpu") @@ -136,6 +156,7 @@ def _bonds_to_csgraph( for i in range(n_conn): is_conn_atom[bt.ordered_connection_atoms[i]] = True + scan_path_data = {} for i in range(n_input_types): i_conn_atom = bt.ordered_connection_atoms[i] if i < n_conn else mid_bt_atom @@ -148,7 +169,6 @@ def _bonds_to_csgraph( print(bt.name, i, bfto_2_orig, preds) print([bt.atom_name(bfto_2_orig[bfs_ind]) for bfs_ind in range(bt.n_atoms)]) for j in range(n_output_types): - if i == j and i < n_conn: # we cannot enter from one inter-residue connection point and then # leave by that same inter-residue connection point unless we are @@ -268,10 +288,12 @@ def gen_depth_given_first_descendant(): # as the first descendant and so I pause and think: if we have # a block type with 4 inter-residue connections where the fold # forest branches at this residue, then the algorithm for constructing - # the most number-of-generations-efficient KinForest here is going + # the fewest-number-of-generations KinForest here is going # will fail: we are treating all exit paths out of this residue - # as interchangable and we might say connection c vs c' should - # be first in a case where c' leads to more generations than c. + # as interchangable and we might say connection c should be + # ahead of connection c' in a case where c' has a greater gen_depth + # than c. + # # The case I am designing for here is: there's a jump that has # landed at a beta-amino acid's CA atom and there are exit paths # through the N- and C-terminal ends of the residue and if the @@ -288,7 +310,16 @@ def gen_depth_given_first_descendant(): # If we are only dealing with polymeric residues that have an # up- and a down connection that that's it (e.g. nucleic acids), # then this algorithm will still produce optimal KinForests. 
- + # + # A case that this would fail to deliver the optimally-efficient + # (fewest number of generations) KinForest would be if this R group + # also contained an inter-residue connection and there were an + # edge in the FoldForest (a "chemical edge") leaving from that + # connection to some further chain, e.g., it could be a sugar + # group attached to a beta-ASN. Now if the path (CA->CB->N) takes + # precedence over the path (CA->CB->R), then everything down- + # stream of the R would have a generation-delay one greater than + # it would otherwise. for kid in k_kids: if is_on_exit_path[kid]: first_descendant[k_atom_ind] = kid @@ -362,12 +393,70 @@ def gen_depth_given_first_descendant(): for k in range(ij_n_gens) ] print("ij_scan_lengths", i, j, ij_scan_lengths) + for k in range(ij_n_gens): + offset = 0 + for l in range(ij_n_scans[k]): + ij_scan_starts[k][l] = offset + offset += ij_scan_lengths[k][l] + ij_scan_is_inter_block = [ + [False] * ij_n_scans[k] for k in range(ij_n_gens) + ] + for k in range(ij_n_gens): + for l in range(ij_n_scans[k]): + l_first_at = gen_scan_paths[k][l][0 if k == 0 else 1] + ij_scan_is_inter_block[k][l] = is_on_exit_path[l_first_at] + + print("ij_scan_is_inter_block", ij_scan_is_inter_block) # ij_n_nodes_for_gen = ij_n_nodes_for_gen = [ sum(len(path) for path in gen_scan_paths[k]) for k in range(ij_n_gens) ] print("ij_n_nodes_for_gen", ij_n_nodes_for_gen) + scan_path_data[(i, j)] = dict( + n_gens=ij_n_gens, + n_nodes_for_gen=ij_n_nodes_for_gen, + nodes_for_generation=gen_scan_paths, + n_scans=ij_n_scans, + scan_starts=ij_scan_starts, + scan_is_inter_block=is_on_exit_path, + scan_lengths=ij_scan_lengths, + ) + # end for j + # end for i + max_n_gens = max( + scan_path_data[(i, j)]["n_gens"] + for i in range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + max_n_scans = max( + max( + scan_path_data[(i, j)]["n_scans"][k] + for k in range(scan_path_data[(i, j)]["n_gens"]) + ) + for i in 
range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + max_n_nodes_per_gen = max( + max( + scan_path_data[(i, j)]["n_nodes_for_gen"][k] + for k in range(scan_path_data[(i, j)]["n_gens"]) + ) + for i in range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + bt_gen_seg_scan_paths = GenSegScanPaths.empty( + n_input_types, n_output_types, max_n_gens, max_n_scans, max_n_nodes_per_gen + ) + for i in range(n_input_types): + for j in range(n_output_types): + if (i, j) not in scan_path_data: + continue + ij_n_gens = scan_path_data[(i, j)]["n_gens"] + bt_gen_seg_scan_paths.n_gens[i, j] = ij_n_gens def test_decide_scan_paths_for_foldforest(ubq_pdb): @@ -379,5 +468,3 @@ def test_decide_scan_paths_for_foldforest(ubq_pdb): co, ubq_pdb, torch_device, residue_start=0, residue_end=10 ) pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) - - fold diff --git a/tmol/tests/pose/test_pose_stack_kinematics.py b/tmol/tests/pose/test_pose_stack_kinematics.py index aa81ec5b8..276f79f9b 100644 --- a/tmol/tests/pose/test_pose_stack_kinematics.py +++ b/tmol/tests/pose/test_pose_stack_kinematics.py @@ -1,5 +1,6 @@ import torch import numpy +import attrs from tmol.pose.pose_stack_builder import PoseStackBuilder from tmol.pose.pose_kinematics import ( @@ -13,9 +14,11 @@ default_packed_block_types, canonical_form_from_pdb, ) +from tmol.io.write_pose_stack_pdb import write_pose_stack_pdb from tmol.io.pose_stack_construction import pose_stack_from_canonical_form from tmol.kinematics.check_fold_forest import mark_polymeric_bonds_in_foldforest_edges from tmol.kinematics.fold_forest import FoldForest, EdgeType +from tmol.kinematics.operations import inverseKin, forwardKin def test_get_bonds_for_named_torsions(ubq_res, default_database, torch_device): @@ -334,43 +337,84 @@ def test_decide_scan_paths_for_foldforest(ubq_pdb): co = default_canonical_ordering() pbt = default_packed_block_types(torch_device) 
canonical_form = canonical_form_from_pdb( - co, ubq_pdb, torch_device, residue_start=0, residue_end=10 + co, ubq_pdb, torch_device, residue_start=0, residue_end=20 ) pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) + write_pose_stack_pdb(pose_stack, "ubq20_orig.pdb") # let's make a FF with a jump: # rooted at residue 2 - # 0 5 + # 0 10 # ^ ^ # | | - # 2 - - > 7 + # 5 - - > 15 # | | # v v - # 4 9 + # 9 19 edges = numpy.full((1, 5, 4), -1, dtype=int) edges[0, 0, 0] = EdgeType.jump - edges[0, 0, 1] = 2 - edges[0, 0, 2] = 7 + edges[0, 0, 1] = 5 + edges[0, 0, 2] = 15 edges[0, 0, 3] = 0 edges[0, 1, 0] = EdgeType.polymer - edges[0, 1, 1] = 2 + edges[0, 1, 1] = 5 edges[0, 1, 2] = 0 edges[0, 2, 0] = EdgeType.polymer - edges[0, 2, 1] = 2 - edges[0, 2, 2] = 4 + edges[0, 2, 1] = 5 + edges[0, 2, 2] = 9 edges[0, 3, 0] = EdgeType.polymer - edges[0, 3, 1] = 7 - edges[0, 3, 2] = 5 + edges[0, 3, 1] = 15 + edges[0, 3, 2] = 10 edges[0, 4, 0] = EdgeType.polymer - edges[0, 4, 1] = 7 - edges[0, 4, 2] = 9 + edges[0, 4, 1] = 15 + edges[0, 4, 2] = 19 ff = FoldForest( max_n_edges=5, n_edges=numpy.full((1,), 5, dtype=int), edges=edges, - roots=numpy.full((1,), 2, dtype=int), + roots=numpy.full((1,), 5, dtype=int), + ) + + kinforest = construct_pose_stack_kinforest(pose_stack, ff) + print(kinforest) + # nodes, scanStarts, genStarts = get_scans(kf. 
+ + ps_coords_shape = pose_stack.coords.shape + kincoords_shape = ( + (ps_coords_shape[0] * ps_coords_shape[1]) + 1, + ps_coords_shape[2], + ) + print("kincoords_shape", kincoords_shape) + kincoords = torch.zeros( + kincoords_shape, dtype=torch.float64, device=pose_stack.device ) - kf = construct_pose_stack_kinforest(pose_stack, ff) + kincoords[1:] = pose_stack.coords.view(-1, 3).to(torch.float64)[ + kinforest.id[1:].to(torch.int64) + ] + + dofs = inverseKin(kinforest, kincoords) + pcoords = forwardKin(kinforest, dofs) + + rd_dofs = dofs.clone() + + print("dofs", dofs.shape) + print(dofs.jump[5:15]) + rd_dofs.jump.RBx[10] += 5.1 + rd_dofs.jump.RBy[10] += 5.2 + rd_dofs.jump.RBz[10] += 5.3 + print("rd_dofs", rd_dofs.shape) + print(rd_dofs.jump[5:15]) + + pert_coords = forwardKin(kinforest, rd_dofs) + pert_coords_shape = (ps_coords_shape[0] * ps_coords_shape[1], 3) + pert_coords_for_ps = torch.zeros( + pert_coords_shape, dtype=torch.float32, device=pose_stack.device + ) + pert_coords_for_ps[kinforest.id[1:].to(torch.int64)] = pert_coords[1:].to( + torch.float32 + ) + ps2 = attrs.evolve(pose_stack, coords=pert_coords_for_ps.view(ps_coords_shape)) + write_pose_stack_pdb(ps2, "ubq20_w_pert.pdb") From 1f777a9b49881afc7dc9cf685a39267defe4ab8b Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Mon, 12 Aug 2024 15:06:47 -0400 Subject: [PATCH 03/52] Add a working ground-truth kin-forest that will be the target we try to build programmatically --- ...st_create_scan_orering_from_block_types.py | 1153 +++++++++++------ 1 file changed, 769 insertions(+), 384 deletions(-) diff --git a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py index 55e24c247..75af9a331 100644 --- a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py +++ b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py @@ -15,34 +15,35 @@ canonical_form_from_pdb, ) from tmol.io.pose_stack_construction 
import pose_stack_from_canonical_form +from tmol.kinematics.datatypes import NodeType from tmol.kinematics.fold_forest import EdgeType from tmol.kinematics.scan_ordering import get_children +from tmol.kinematics.compiled import inverse_kin, forward_kin_op +# @jit +# def get_branch_depth(parents): +# # modeled off get_children +# nelts = parents.shape[0] -@jit -def get_branch_depth(parents): - # modeled off get_children - nelts = parents.shape[0] +# n_immediate_children = numpy.full(nelts, 0, dtype=numpy.int32) +# for i in range(nelts): +# p = parents[i] +# assert p <= i +# if p == i: +# continue +# n_immediate_children[p] += 1 - n_immediate_children = numpy.full(nelts, 0, dtype=numpy.int32) - for i in range(nelts): - p = parents[i] - assert p <= i - if p == i: - continue - n_immediate_children[p] += 1 +# child_list = numpy.full(nelts, -1, dtype=numpy.int32) +# child_list_span = numpy.empty((nelts, 2), dtype=numpy.int32) - child_list = numpy.full(nelts, -1, dtype=numpy.int32) - child_list_span = numpy.empty((nelts, 2), dtype=numpy.int32) +# child_list_span[0, 0] = 0 +# child_list_span[0, 1] = n_immediate_children[0] +# for i in range(1, nelts): +# child_list_span[i, 0] = child_list_span[i - 1, 1] +# child_list_span[i, 1] = child_list_span[i, 0] + n_immediate_children[i] - child_list_span[0, 0] = 0 - child_list_span[0, 1] = n_immediate_children[0] - for i in range(1, nelts): - child_list_span[i, 0] = child_list_span[i - 1, 1] - child_list_span[i, 1] = child_list_span[i, 0] + n_immediate_children[i] - - # Pass 3, fill the child list for each parent. - # As we do this, +# # Pass 3, fill the child list for each parent. 
+# # As we do this, def jump_bt_atom(bt, spanning_tree): @@ -53,10 +54,12 @@ def jump_bt_atom(bt, spanning_tree): @attrs.define -class GenSegScanPaths: +class GenerationalSegScanPaths: + parents: NDArray[numpy.int64][:, :] # n-input x n-atoms + input_conn_atom: NDArray[numpy.int64][:] # n-input n_gens: NDArray[numpy.int64][:, :] # n-input x n-output n_nodes_for_gen: NDArray[numpy.int64][:, :, :] - nodes_for_generation: NDArray[numpy.int64][ + nodes_for_gen: NDArray[numpy.int64][ :, :, :, : ] # n-input x n-output x max-n-gen x max-n-nodes-per-gen n_scans: NDArray[numpy.int64][:, :, :] @@ -67,14 +70,24 @@ class GenSegScanPaths: @classmethod def empty( - cls, n_input_types, n_output_types, max_n_gens, max_n_scans, max_n_nodes_per_gen + cls, + n_input_types, + n_output_types, + n_atoms, + max_n_gens, + max_n_scans, + max_n_nodes_per_gen, ): io = (n_input_types, n_output_types) return cls( + parents=numpy.full( + (n_input_types, n_atoms), -1, dtype=int + ), # independent of primary output + input_conn_atom=numpy.full(n_input_types, -1, dtype=int), n_gens=numpy.zeros(io, dtype=int), n_nodes_for_gen=numpy.zeros(io + (max_n_gens,), dtype=int), - nodes_for_generation=numpy.zeros( - io + (max_n_gens, max_n_nodes_per_gen), dtype=int + nodes_for_gen=numpy.full( + io + (max_n_gens, max_n_nodes_per_gen), -1, dtype=int ), n_scans=numpy.zeros(io + (max_n_gens,), dtype=int), scan_starts=numpy.full(io + (max_n_gens, max_n_scans), -1, dtype=int), @@ -84,379 +97,751 @@ def empty( ) -def test_kin_tree_construction(ubq_pdb): +def _annotate_block_type_with_gen_scan_paths(bt): + n_conn = len(bt.connections) + + n_input_types = n_conn + 2 # n_conn + jump input + root "input" + n_output_types = n_conn + 1 # n_conn + jump output + + n_gens = numpy.zeros((n_input_types, n_output_types), dtype=numpy.int64) + nodes_for_generation = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + n_scans = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] + 
scan_starts = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] + scan_is_inter_block = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + scan_lengths = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] + + def _bonds_to_csgraph( + bonds: NDArray[int][:, 2], edge_weight: float + ) -> sparse.csr_matrix: + weights_array = numpy.full((1,), edge_weight, dtype=numpy.float32) + weights = numpy.broadcast_to(weights_array, bonds[:, 0].shape) + + bonds_csr = sparse.csr_matrix( + (weights, (bonds[:, 0], bonds[:, 1])), + shape=(bt.n_atoms, bt.n_atoms), + ) + return bonds_csr + + # create a bond graph and then we will create the prioritized edges + # and all edges + potential_bonds = _bonds_to_csgraph(bt.bond_indices, -1) + # print("potential bonds", potential_bonds) + tor_atoms = [ + (uaids[1][0], uaids[2][0]) + for tor, uaids in bt.torsion_to_uaids.items() + if uaids[1][0] >= 0 and uaids[2][0] >= 0 + ] + if len(tor_atoms) == 0: + tor_atoms = numpy.zeros((0, 2), dtype=numpy.int64) + else: + tor_atoms = numpy.array(tor_atoms) + # print("tor atoms:", tor_atoms) + + prioritized_bonds = _bonds_to_csgraph(tor_atoms, -0.125) + # print("prioritized bonds", prioritized_bonds) + bond_graph = potential_bonds + prioritized_bonds + bond_graph_spanning_tree = csgraph.minimum_spanning_tree(bond_graph.tocsr()) + + mid_bt_atom = jump_bt_atom(bt, bond_graph_spanning_tree) + + is_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) + for i in range(n_conn): + is_conn_atom[bt.ordered_connection_atoms[i]] = True + + scan_path_data = {} + parents = numpy.full((n_input_types, bt.n_atoms), -1, dtype=numpy.int64) + input_conn_atom = numpy.zeros((n_input_types,), dtype=numpy.int64) + for i in range(n_input_types): + + i_conn_atom = bt.ordered_connection_atoms[i] if i < n_conn else mid_bt_atom + input_conn_atom[i] = i_conn_atom + bfto_2_orig, preds = csgraph.breadth_first_order( + bond_graph_spanning_tree, + i_conn_atom, + directed=False, 
+ return_predecessors=True, + ) + parents[i, :] = preds + # Now, the parent of the i_conn_atom comes from the previous residue, so we will + # need to fix this atom when we are hooking the blocks together. For now, leave + # it as -9999 (which is what csgraph labels it as) so that we can tell if we have + # not corrected this parent index later on. + # print(bt.name, i, bfto_2_orig, preds) + # print([bt.atom_name(bfto_2_orig[bfs_ind]) for bfs_ind in range(bt.n_atoms)]) + for j in range(n_output_types): + if i == j and i < n_conn: + # we cannot enter from one inter-residue connection point and then + # leave by that same inter-residue connection point unless we are + # building a jump + continue + + # now we start at the j_conn_atom and work backwards toward the root + # which marks the first scan path for this block type: the "primary exit path" + gen_scan_paths = defaultdict(list) + + j_conn_atom = bt.ordered_connection_atoms[j] if j < n_conn else mid_bt_atom + + first_descendant = numpy.full((bt.n_atoms,), -9999, dtype=numpy.int64) + is_on_primary_exit_path = numpy.zeros((bt.n_atoms,), dtype=bool) + is_on_primary_exit_path[i_conn_atom] = True + + focused_atom = j_conn_atom + primary_exit_scan_path = [] + while focused_atom != i_conn_atom: + # print("exit path:", bt.atom_name(focused_atom)) + is_on_primary_exit_path[focused_atom] = True + primary_exit_scan_path.append(focused_atom) + pred = preds[focused_atom] + first_descendant[pred] = focused_atom + focused_atom = pred + primary_exit_scan_path.append(i_conn_atom) + primary_exit_scan_path.reverse() + # we need to prioritize exit paths of all stripes + # in constructing the trees + is_on_exit_path = is_on_primary_exit_path.copy() + for k in range(n_conn): + if k == i or k == j: + continue # truly unnecessary; nothing changes if I remove these two lines + is_on_exit_path[bt.ordered_connection_atoms[k]] = True + + # print("primary_exit_scan_path:", primary_exit_scan_path) + 
gen_scan_paths[0].append(primary_exit_scan_path) + + # Create a list of children for each atom. + n_kids = numpy.zeros((bt.n_atoms,), dtype=numpy.int64) + atom_kids = [[] for _ in range(bt.n_atoms)] + for k in range(bt.n_atoms): + if preds[k] < 0: + assert ( + k == i_conn_atom + ), f"bad predecesor for atom {k} in {bt.name}, {preds[k]}" + continue # the root + n_kids[preds[k]] += 1 + atom_kids[preds[k]].append(k) + + # now we label each node with its "generation depth" using a + # leaf-to-root traversal perscribed by the original DFS, taking + # into account the fact that priority must be given to + # exit paths + gen_depth = numpy.ones((bt.n_atoms,), dtype=numpy.int64) + on_path_from_conn_to_i_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) + for k in range(bt.n_atoms - 1, -1, -1): + k_atom_ind = bfto_2_orig[k] + # print("recursing upwards", i, "i_conn atom", i_conn_atom, j, "j_conn_atom", j_conn_atom, k, k_atom_ind) + k_kids = atom_kids[k_atom_ind] + # print("kids:", k_kids) + if len(k_kids) == 0: + continue + # from here forward, we know that k_atom_ind has > 0 children + + def gen_depth_given_first_descendant(): + # first set the first_descendant for k_atom_ind + # then the logic is: we have to add one to the + # gen-depth of every child but the first descendant + # which we get "for free" + # print(f"atom {bt.atom_name(k_atom_ind)} with first descendant {bt.atom_name(first_descendant[k_atom_ind]) if first_descendant[k_atom_ind] >= 0 else 'None'} and depth {gen_depth[first_descendant[k_atom_ind]] if first_descendant[k_atom_ind] >= 0 else -9999}") + return max( + [ + ( + gen_depth[k_kid] + 1 + if k_kid != first_descendant[k_atom_ind] + else gen_depth[k_kid] + ) + for k_kid in k_kids + ] + ) + + if is_on_primary_exit_path[k_atom_ind]: + # in this case, the first_descendant for this atom + # has already been decided + # print("on exit path:", bt.atom_name(k_atom_ind), first_descendant[k_atom_ind], is_conn_atom[k_atom_ind]) + if k_atom_ind == j_conn_atom: + # 
the first descendent is the atom on the next residue to which + # this residue is connected + gen_depth[k_atom_ind] = max([gen_depth[l] for l in k_kids]) + 1 + else: + # first_descendant is already determined for this atom + gen_depth[k_atom_ind] = gen_depth_given_first_descendant() + else: + + if is_conn_atom[k_atom_ind]: + # in this case, "the" connection (there can possibly be more than one!) + # will be the first child and the other descendants will be second children + # we save the gen depth, but when calculating the gen depth of the + # fold-forest, if this residue is at the upstream end of an edge, then + # its depth will have to be calculated as the min gen-depth of the + # intra-residue bits and the gen-depth of the nodes downstream of it. + gen_depth[k_atom_ind] = max([gen_depth[l] for l in k_kids]) + 1 + else: + # most-common case: an atom not on the primary-exit path, and that isn't + # itself a conn atom. + # First we ask: are we on one or more exit paths? + # NOTE: this just chooses the first exit path atom it encounters + # as the first descendant and so I pause and think: if we have + # a block type with 4 inter-residue connections where the fold + # forest branches at this residue, then the algorithm for constructing + # the fewest-number-of-generations KinForest here is going + # will fail: we are treating all exit paths out of this residue + # as interchangable and we might say connection c should be + # ahead of connection c' in a case where c' has a greater gen_depth + # than c. + # + # The case I am designing for here is: there's a jump that has + # landed at a beta-amino acid's CA atom and there are exit paths + # through the N- and C-terminal ends of the residue and if the + # primary exit path is the C-term, then the N-term exit path should + # still have priority over the side-chain path. + # + # R + # | + # ... CB C + # \ / \ / \ + # N CA ... + # + # The path starting at CB should go towards N and not towards R. 
+ # If we are only dealing with polymeric residues that have an + # up- and a down connection that that's it (e.g. nucleic acids), + # then this algorithm will still produce optimal KinForests. + # + # A case that this would fail to deliver the optimally-efficient + # (fewest number of generations) KinForest would be if this R group + # also contained an inter-residue connection and there were an + # edge in the FoldForest (a "chemical edge") leaving from that + # connection to some further chain, e.g., it could be a sugar + # group attached to a beta-ASN. Now if the path (CA->CB->N) takes + # precedence over the path (CA->CB->R), then everything down- + # stream of the R would have a generation-delay one greater than + # it would otherwise. + for kid in k_kids: + if is_on_exit_path[kid]: + first_descendant[k_atom_ind] = kid + is_on_exit_path[k_atom_ind] = True + + if not is_on_exit_path[k_atom_ind]: + # which should be the first descendant? the one with the greatest gen depth + first_descendant[k_atom_ind] = k_kids[ + numpy.argmax( + numpy.array([gen_depth[kid] for kid in k_kids]) + ) + ] + gen_depth[k_atom_ind] = gen_depth_given_first_descendant() + # print("gen_depth", bt.atom_name(k_atom_ind), "d:", gen_depth[k_atom_ind]) + # print("gen_depth", gen_depth) + + # OKAY! 
+ # now we have paths rooted at each node up to the root + # we need to turn these paths into scan paths + processed_node_into_scan_path = is_on_primary_exit_path.copy() + gen_to_build_atom = numpy.full((bt.n_atoms,), -1, dtype=numpy.int64) + gen_to_build_atom[processed_node_into_scan_path] = 0 + # print("gen depth", gen_depth) + # print("starting bfs:", processed_node_into_scan_path) + for k in range(bt.n_atoms): + k_atom_ind = bfto_2_orig[k] + if processed_node_into_scan_path[k_atom_ind]: + continue + + # if we arrive here, that means k_atom_ind is the root of a + # new scan path + path = [] + # we have already processed the first scan path + # from the entrace-point atom to the first exit-point atom + assert k_atom_ind != i_conn_atom + # put the parent of this new root at the beginning of + # the scan path + path.append(preds[k_atom_ind]) + focused_atom = k_atom_ind + + gen_to_build_atom[focused_atom] = ( + gen_to_build_atom[preds[focused_atom]] + 1 + ) + # print( + # f"gen to build {bt.atom_name(focused_atom)} from {bt.atom_name(preds[focused_atom])}", + # f"with gen {gen_to_build_atom[focused_atom]}", + # ) + while focused_atom >= 0: + path.append(focused_atom) + processed_node_into_scan_path[focused_atom] = True + focused_atom = first_descendant[focused_atom] + if focused_atom >= 0: + gen_to_build_atom[focused_atom] = gen_to_build_atom[ + preds[focused_atom] + ] + if is_on_exit_path[k_atom_ind]: + gen_scan_paths[gen_to_build_atom[k_atom_ind]].insert(0, path) + else: + gen_scan_paths[gen_to_build_atom[k_atom_ind]].append(path) + # Now we need to assemble the scan paths in a compact way: + # print("gen scan paths", gen_scan_paths) + + ij_n_gens = gen_depth[i_conn_atom] + # print("ij_n_gens", i, j, ij_n_gens) + ij_n_scans = numpy.array( + [len(gen_scan_paths[k]) for k in range(ij_n_gens)], dtype=int + ) + # print("ij_n_scans", i, j, ij_n_scans) + ij_scan_starts = [ + numpy.zeros((ij_n_scans[k],), dtype=int) for k in range(ij_n_gens) + ] + ij_scan_lengths = [ + 
numpy.array( + [len(gen_scan_paths[k][l]) for l in range(len(gen_scan_paths[k]))], + dtype=int, + ) + for k in range(ij_n_gens) + ] + # print("ij_scan_lengths", i, j, ij_scan_lengths) + for k in range(ij_n_gens): + offset = 0 + for l in range(ij_n_scans[k]): + ij_scan_starts[k][l] = offset + offset += ij_scan_lengths[k][l] + # print("ij_scan_starts", i, j, ij_scan_starts) + # print("ij_scan_lengths cumsum?", numpy.cumsum(ij_scan_lengths)) + ij_scan_is_inter_block = [ + numpy.zeros((ij_n_scans[k],), dtype=bool) for k in range(ij_n_gens) + ] + + for k in range(ij_n_gens): + for l in range(ij_n_scans[k]): + l_first_at = gen_scan_paths[k][l][0 if k == 0 else 1] + ij_scan_is_inter_block[k][l] = is_on_exit_path[l_first_at] + + # print("ij_scan_is_inter_block", ij_scan_is_inter_block) + # ij_n_nodes_for_gen = + ij_n_nodes_for_gen = numpy.array( + [ + sum(len(path) for path in gen_scan_paths[k]) + for k in range(ij_n_gens) + ], + dtype=int, + ) + # print("ij_n_nodes_for_gen", ij_n_nodes_for_gen) + scan_path_data[(i, j)] = dict( + n_gens=ij_n_gens, + n_nodes_for_gen=ij_n_nodes_for_gen, + nodes_for_generation=gen_scan_paths, + n_scans=ij_n_scans, + scan_starts=ij_scan_starts, + scan_is_inter_block=is_on_exit_path, + scan_lengths=ij_scan_lengths, + ) + # end for j + # end for i + + # Now let's count out the maximum number of generations, scans, and nodes-per-gen + # so we can create the GenerationalSegScanPaths object + max_n_gens = max( + scan_path_data[(i, j)]["n_gens"] + for i in range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + max_n_scans = max( + max( + scan_path_data[(i, j)]["n_scans"][k] + for k in range(scan_path_data[(i, j)]["n_gens"]) + ) + for i in range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + max_n_nodes_per_gen = max( + max( + scan_path_data[(i, j)]["n_nodes_for_gen"][k] + for k in range(scan_path_data[(i, j)]["n_gens"]) + ) + for i in range(n_input_types) + for j in 
range(n_output_types) + if (i, j) in scan_path_data + ) + bt_gen_seg_scan_paths = GenerationalSegScanPaths.empty( + n_input_types, + n_output_types, + bt.n_atoms, + max_n_gens, + max_n_scans, + max_n_nodes_per_gen, + ) + bt_gen_seg_scan_paths.parents = parents + bt_gen_seg_scan_paths.input_conn_atom = input_conn_atom + # Finally, we populate the GenerationalSegScanPaths object + for i in range(n_input_types): + for j in range(n_output_types): + if (i, j) not in scan_path_data: + continue + ij_n_gens = scan_path_data[(i, j)]["n_gens"] + bt_gen_seg_scan_paths.n_gens[i, j] = ij_n_gens + for k in range(ij_n_gens): + bt_gen_seg_scan_paths.n_nodes_for_gen[i, j, k] = scan_path_data[(i, j)][ + "n_nodes_for_gen" + ][k] + bt_gen_seg_scan_paths.n_scans[i, j, k] = scan_path_data[(i, j)][ + "n_scans" + ][k] + bt_gen_seg_scan_paths.scan_is_real[ + i, j, k, : bt_gen_seg_scan_paths.n_scans[i, j, k] + ] = True + + ijk_n_scans = scan_path_data[(i, j)]["n_scans"][k] + bt_gen_seg_scan_paths.scan_starts[i, j, k, :ijk_n_scans] = ( + scan_path_data[(i, j)]["scan_starts"][k] + ) + bt_gen_seg_scan_paths.scan_is_inter_block[i, j, k, :ijk_n_scans] = ( + scan_path_data[(i, j)]["scan_is_inter_block"][k] + ) + bt_gen_seg_scan_paths.scan_lengths[i, j, k, :ijk_n_scans] = ( + scan_path_data[(i, j)]["scan_lengths"][k] + ) + # for l in range(scan_path_data[(i, j)]["n_scans"][k]): + # bt_gen_seg_scan_paths.scan_starts[i, j, k, l] = scan_path_data[(i, j)]["scan_starts"][k][l] + # bt_gen_seg_scan_paths.scan_is_inter_block[i, j, k, l] = scan_path_data[(i, j)]["scan_is_inter_block"][k][l] + # bt_gen_seg_scan_paths.scan_lengths[i, j, k, l] = scan_path_data[(i, j)]["scan_lengths"][k][l] + for l in range(ijk_n_scans): + m_offset = scan_path_data[(i, j)]["scan_starts"][k][l] + for m in range( + len(scan_path_data[(i, j)]["nodes_for_generation"][k][l]) + ): + bt_gen_seg_scan_paths.nodes_for_gen[i, j, k, m_offset + m] = ( + scan_path_data[(i, j)]["nodes_for_generation"][k][l][m] + ) + # print("nodes for gen", 
i, j, k, bt_gen_seg_scan_paths.nodes_for_gen[i, j, k, :]) + + setattr(bt, "gen_seg_scan_paths", bt_gen_seg_scan_paths) + + +def test_gen_seg_scan_paths_block_type_annotation_smoke(fresh_default_restype_set): torch_device = torch.device("cpu") - co = default_canonical_ordering() - pbt = default_packed_block_types(torch_device) - canonical_form = canonical_form_from_pdb(co, ubq_pdb, torch_device) - pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) + # co = default_canonical_ordering() + # pbt = default_packed_block_types(torch_device) + # canonical_form = canonical_form_from_pdb(co, ubq_pdb, torch_device) + # pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) # okay! # 1. let's create some annotations of the packed block types - bt_list = [bt for bt in pbt.active_block_types if bt.name == "LEU"] + bt_list = [bt for bt in fresh_default_restype_set.residue_types if bt.name == "LEU"] # for bt in pbt.active_block_types: for bt in bt_list: - n_conn = len(bt.connections) - - n_input_types = n_conn + 2 # n_conn + jump input + root "input" - n_output_types = n_conn + 1 # n_conn + jump output - - n_gens = numpy.zeros((n_input_types, n_output_types), dtype=numpy.int64) - nodes_for_generation = [ - [[] for _ in range(n_output_types)] for _2 in range(n_input_types) - ] - n_scans = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] - scan_starts = [ - [[] for _ in range(n_output_types)] for _2 in range(n_input_types) - ] - scan_is_inter_block = [ - [[] for _ in range(n_output_types)] for _2 in range(n_input_types) - ] - scan_lengths = [ - [[] for _ in range(n_output_types)] for _2 in range(n_input_types) - ] - - def _bonds_to_csgraph( - bonds: NDArray[int][:, 2], edge_weight: float - ) -> sparse.csr_matrix: - weights_array = numpy.full((1,), edge_weight, dtype=numpy.float32) - weights = numpy.broadcast_to(weights_array, bonds[:, 0].shape) - - bonds_csr = sparse.csr_matrix( - (weights, (bonds[:, 0], bonds[:, 1])), - 
shape=(bt.n_atoms, bt.n_atoms), - ) - return bonds_csr - - # create a bond graph and then we will create the prioritized edges - # and all edges - potential_bonds = _bonds_to_csgraph(bt.bond_indices, -1) - print("potential bonds", potential_bonds) - tor_atoms = [ - (uaids[1][0], uaids[2][0]) - for tor, uaids in bt.torsion_to_uaids.items() - if uaids[1][0] >= 0 and uaids[2][0] >= 0 - ] - if len(tor_atoms) == 0: - tor_atoms = numpy.zeros((0, 2), dtype=numpy.int64) - else: - tor_atoms = numpy.array(tor_atoms) - print("tor atoms:", tor_atoms) - - prioritized_bonds = _bonds_to_csgraph(tor_atoms, -0.125) - print("prioritized bonds", prioritized_bonds) - bond_graph = potential_bonds + prioritized_bonds - bond_graph_spanning_tree = csgraph.minimum_spanning_tree(bond_graph.tocsr()) - - mid_bt_atom = jump_bt_atom(bt, bond_graph_spanning_tree) - - is_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) - for i in range(n_conn): - is_conn_atom[bt.ordered_connection_atoms[i]] = True - - scan_path_data = {} - for i in range(n_input_types): - - i_conn_atom = bt.ordered_connection_atoms[i] if i < n_conn else mid_bt_atom - bfto_2_orig, preds = csgraph.breadth_first_order( - bond_graph_spanning_tree, - i_conn_atom, - directed=False, - return_predecessors=True, - ) - print(bt.name, i, bfto_2_orig, preds) - print([bt.atom_name(bfto_2_orig[bfs_ind]) for bfs_ind in range(bt.n_atoms)]) - for j in range(n_output_types): - if i == j and i < n_conn: - # we cannot enter from one inter-residue connection point and then - # leave by that same inter-residue connection point unless we are - # building a jump - continue + _annotate_block_type_with_gen_scan_paths(bt) - # now we start at the j_conn_atom and work backwards toward the root - # which marks the first scan path for this block type: the "primary exit path" - gen_scan_paths = defaultdict(list) - j_conn_atom = ( - bt.ordered_connection_atoms[j] if j < n_conn else mid_bt_atom - ) +def test_construct_scan_paths_n_to_c_twores(ubq_pdb): + 
torch_device = torch.device("cpu") - first_descendant = numpy.full((bt.n_atoms,), -9999, dtype=numpy.int64) - is_on_primary_exit_path = numpy.zeros((bt.n_atoms,), dtype=bool) - is_on_primary_exit_path[i_conn_atom] = True - - focused_atom = j_conn_atom - primary_exit_scan_path = [] - while focused_atom != i_conn_atom: - print("exit path:", bt.atom_name(focused_atom)) - is_on_primary_exit_path[focused_atom] = True - primary_exit_scan_path.append(focused_atom) - pred = preds[focused_atom] - first_descendant[pred] = focused_atom - focused_atom = pred - primary_exit_scan_path.append(i_conn_atom) - primary_exit_scan_path.reverse() - # we need to prioritize exit paths of all stripes - # in constructing the trees - is_on_exit_path = is_on_primary_exit_path.copy() - for k in range(n_conn): - if k == i or k == j: - continue # truly unnecessary; nothing changes if I remove these two lines - is_on_exit_path[bt.ordered_connection_atoms[k]] = True - - print("primary_exit_scan_path:", primary_exit_scan_path) - gen_scan_paths[0].append(primary_exit_scan_path) - - # Create a list of children for each atom. 
- n_kids = numpy.zeros((bt.n_atoms,), dtype=numpy.int64) - atom_kids = [[] for _ in range(bt.n_atoms)] - for k in range(bt.n_atoms): - if preds[k] < 0: - assert ( - k == i_conn_atom - ), f"bad predecesor for atom {k} in {bt.name}, {preds[k]}" - continue # the root - n_kids[preds[k]] += 1 - atom_kids[preds[k]].append(k) - - # now we label each node with its "generation depth" using a - # leaf-to-root traversal perscribed by the original DFS, taking - # into account the fact that priority must be given to - # exit paths - gen_depth = numpy.ones((bt.n_atoms,), dtype=numpy.int64) - on_path_from_conn_to_i_conn_atom = numpy.zeros( - (bt.n_atoms,), dtype=bool - ) - for k in range(bt.n_atoms - 1, -1, -1): - k_atom_ind = bfto_2_orig[k] - # print("recursing upwards", i, "i_conn atom", i_conn_atom, j, "j_conn_atom", j_conn_atom, k, k_atom_ind) - k_kids = atom_kids[k_atom_ind] - # print("kids:", k_kids) - if len(k_kids) == 0: - continue - # from here forward, we know that k_atom_ind has > 0 children - - def gen_depth_given_first_descendant(): - # first set the first_descendant for k_atom_ind - # then the logic is: we have to add one to the - # gen-depth of every child but the first descendant - # which we get "for free" - # print(f"atom {bt.atom_name(k_atom_ind)} with first descendant {bt.atom_name(first_descendant[k_atom_ind]) if first_descendant[k_atom_ind] >= 0 else 'None'} and depth {gen_depth[first_descendant[k_atom_ind]] if first_descendant[k_atom_ind] >= 0 else -9999}") - return max( - [ - ( - gen_depth[k_kid] + 1 - if k_kid != first_descendant[k_atom_ind] - else gen_depth[k_kid] - ) - for k_kid in k_kids - ] - ) + co = default_canonical_ordering() + pbt = default_packed_block_types(torch_device) + canonical_form = canonical_form_from_pdb( + co, ubq_pdb, torch_device, residue_start=1, residue_end=3 + ) + res_not_connected = torch.zeros((1, 2, 2), dtype=torch.bool, device=torch_device) + res_not_connected[0, 0, 0] = True # simplest test case: not N-term + 
res_not_connected[0, 1, 1] = True # simplest test case: not C-term + pose_stack = pose_stack_from_canonical_form( + co, pbt, **canonical_form, res_not_connected=res_not_connected + ) - if is_on_primary_exit_path[k_atom_ind]: - # in this case, the first_descendant for this atom - # has already been decided - # print("on exit path:", bt.atom_name(k_atom_ind), first_descendant[k_atom_ind], is_conn_atom[k_atom_ind]) - if k_atom_ind == j_conn_atom: - # the first descendent is the atom on the next residue to which - # this residue is connected - gen_depth[k_atom_ind] = ( - max([gen_depth[l] for l in k_kids]) + 1 - ) - else: - # first_descendant is already determined for this atom - gen_depth[k_atom_ind] = gen_depth_given_first_descendant() - else: + for bt in pbt.active_block_types: + _annotate_block_type_with_gen_scan_paths(bt) + + # now lets assume we have everything we need for the final step + # of kintree construction: + + # output will be: + # (the data members of kintree) + # id: Tensor[torch.int32][...] + # # roots: Tensor[torch.int32][...] # not used in current kinforest + # doftype: Tensor[torch.int32][...] + # parent: Tensor[torch.int32][...] + # frame_x: Tensor[torch.int32][...] + # frame_y: Tensor[torch.int32][...] + # frame_z: Tensor[torch.int32][...] + # (and the data members appended in get_scans) + # nodes + # scans + # gens + + # now we figure out: what data do we need to construct these things? + + bt0 = pbt.active_block_types[pose_stack.block_type_ind[0, 0]] + bt1 = pbt.active_block_types[pose_stack.block_type_ind[0, 1]] + print("bt0", bt0.name, bt0.n_atoms) + print("bt1", bt1.name, bt1.n_atoms) + bt0gssp = bt0.gen_seg_scan_paths + bt1gssp = bt1.gen_seg_scan_paths + + print("nodes") + print(bt0gssp.nodes_for_gen[3, 1]) + print(bt1gssp.nodes_for_gen[0, 1]) + + print("scans") + print(bt0gssp.scan_starts[3, 1]) + print(bt1gssp.scan_starts[0, 1]) + + # print("gens") + # print(bt0gssp. 
+ + print("parents") + print(bt0gssp.parents[3]) + print(bt1gssp.parents[3]) + + ij0 = [3, 1] # 3 => root "input"; Q: is this different from jump input? + ij1 = [0, 1] + + nodes = numpy.zeros((bt0.n_atoms + bt1.n_atoms,), dtype=numpy.int32) + scans = numpy.zeros( + (max(bt0gssp.scan_starts.shape[2], bt1gssp.scan_starts.shape[2]),), + dtype=numpy.int32, + ) + # gens = numpy.zeros(()) - if is_conn_atom[k_atom_ind]: - # in this case, "the" connection (there can possibly be more than one!) - # will be the first child and the other descendants will be second children - # we save the gen depth, but when calculating the gen depth of the - # fold-forest, if this residue is at the upstream end of an edge, then - # its depth will have to be calculated as the min gen-depth of the - # intra-residue bits and the gen-depth of the nodes downstream of it. - gen_depth[k_atom_ind] = ( - max([gen_depth[l] for l in k_kids]) + 1 - ) - else: - # most-common case: an atom not on the primary-exit path, and that isn't - # itself a conn atom. - # First we ask: are we on one or more exit paths? - # NOTE: this just chooses the first exit path atom it encounters - # as the first descendant and so I pause and think: if we have - # a block type with 4 inter-residue connections where the fold - # forest branches at this residue, then the algorithm for constructing - # the fewest-number-of-generations KinForest here is going - # will fail: we are treating all exit paths out of this residue - # as interchangable and we might say connection c should be - # ahead of connection c' in a case where c' has a greater gen_depth - # than c. - # - # The case I am designing for here is: there's a jump that has - # landed at a beta-amino acid's CA atom and there are exit paths - # through the N- and C-terminal ends of the residue and if the - # primary exit path is the C-term, then the N-term exit path should - # still have priority over the side-chain path. - # - # R - # | - # ... 
CB C - # \ / \ / \ - # N CA ... - # - # The path starting at CB should go towards N and not towards R. - # If we are only dealing with polymeric residues that have an - # up- and a down connection that that's it (e.g. nucleic acids), - # then this algorithm will still produce optimal KinForests. - # - # A case that this would fail to deliver the optimally-efficient - # (fewest number of generations) KinForest would be if this R group - # also contained an inter-residue connection and there were an - # edge in the FoldForest (a "chemical edge") leaving from that - # connection to some further chain, e.g., it could be a sugar - # group attached to a beta-ASN. Now if the path (CA->CB->N) takes - # precedence over the path (CA->CB->R), then everything down- - # stream of the R would have a generation-delay one greater than - # it would otherwise. - for kid in k_kids: - if is_on_exit_path[kid]: - first_descendant[k_atom_ind] = kid - is_on_exit_path[k_atom_ind] = True - - if not is_on_exit_path[k_atom_ind]: - # which should be the first descendant? the one with the greatest gen depth - first_descendant[k_atom_ind] = k_kids[ - numpy.argmax( - numpy.array([gen_depth[kid] for kid in k_kids]) - ) - ] - gen_depth[k_atom_ind] = gen_depth_given_first_descendant() - # print("gen_depth", bt.atom_name(k_atom_ind), "d:", gen_depth[k_atom_ind]) - # print("gen_depth", gen_depth) - - # OKAY! 
- # now we have paths rooted at each node up to the root - # we need to turn these paths into scan paths - processed_node_into_scan_path = is_on_primary_exit_path.copy() - gen_to_build_atom = numpy.full((bt.n_atoms,), -1, dtype=numpy.int64) - gen_to_build_atom[processed_node_into_scan_path] = 0 - print("gen depth", gen_depth) - print("starting bfs:", processed_node_into_scan_path) - for k in range(bt.n_atoms): - k_atom_ind = bfto_2_orig[k] - if processed_node_into_scan_path[k_atom_ind]: - continue - - # if we arrive here, that means k_atom_ind is the root of a - # new scan path - path = [] - # we have already processed the first scan path - # from the entrace-point atom to the first exit-point atom - assert k_atom_ind != i_conn_atom - # put the parent of this new root at the beginning of - # the scan path - path.append(preds[k_atom_ind]) - focused_atom = k_atom_ind - - gen_to_build_atom[focused_atom] = ( - gen_to_build_atom[preds[focused_atom]] + 1 - ) - print( - f"gen to build {bt.atom_name(focused_atom)} from {bt.atom_name(preds[focused_atom])}", - f"with gen {gen_to_build_atom[focused_atom]}", - ) - while focused_atom >= 0: - path.append(focused_atom) - processed_node_into_scan_path[focused_atom] = True - focused_atom = first_descendant[focused_atom] - if focused_atom >= 0: - gen_to_build_atom[focused_atom] = gen_to_build_atom[ - preds[focused_atom] - ] - if is_on_exit_path[k_atom_ind]: - gen_scan_paths[gen_to_build_atom[k_atom_ind]].insert(0, path) - else: - gen_scan_paths[gen_to_build_atom[k_atom_ind]].append(path) - # Now we need to assemble the scan paths in a compact way: - print("gen scan paths", gen_scan_paths) - - ij_n_gens = gen_depth[i_conn_atom] - print("ij_n_gens", i, j, ij_n_gens) - ij_n_scans = [len(gen_scan_paths[k]) for k in range(ij_n_gens)] - print("ij_n_scans", i, j, ij_n_scans) - ij_scan_starts = [[0] * ij_n_scans[k] for k in range(ij_n_gens)] - print("ij_scan_starts", i, j, ij_scan_starts) - ij_scan_lengths = [ - [len(gen_scan_paths[k][l]) 
for l in range(len(gen_scan_paths[k]))] - for k in range(ij_n_gens) - ] - print("ij_scan_lengths", i, j, ij_scan_lengths) - for k in range(ij_n_gens): - offset = 0 - for l in range(ij_n_scans[k]): - ij_scan_starts[k][l] = offset - offset += ij_scan_lengths[k][l] - ij_scan_is_inter_block = [ - [False] * ij_n_scans[k] for k in range(ij_n_gens) - ] - for k in range(ij_n_gens): - for l in range(ij_n_scans[k]): - l_first_at = gen_scan_paths[k][l][0 if k == 0 else 1] - ij_scan_is_inter_block[k][l] = is_on_exit_path[l_first_at] - - print("ij_scan_is_inter_block", ij_scan_is_inter_block) - # ij_n_nodes_for_gen = - ij_n_nodes_for_gen = [ - sum(len(path) for path in gen_scan_paths[k]) - for k in range(ij_n_gens) - ] - print("ij_n_nodes_for_gen", ij_n_nodes_for_gen) - scan_path_data[(i, j)] = dict( - n_gens=ij_n_gens, - n_nodes_for_gen=ij_n_nodes_for_gen, - nodes_for_generation=gen_scan_paths, - n_scans=ij_n_scans, - scan_starts=ij_scan_starts, - scan_is_inter_block=is_on_exit_path, - scan_lengths=ij_scan_lengths, - ) - # end for j - # end for i - max_n_gens = max( - scan_path_data[(i, j)]["n_gens"] - for i in range(n_input_types) - for j in range(n_output_types) - if (i, j) in scan_path_data - ) - max_n_scans = max( - max( - scan_path_data[(i, j)]["n_scans"][k] - for k in range(scan_path_data[(i, j)]["n_gens"]) - ) - for i in range(n_input_types) - for j in range(n_output_types) - if (i, j) in scan_path_data + ids_gold = numpy.concatenate( + ( + numpy.full((1,), -1, dtype=numpy.int32), + numpy.arange(bt0.n_atoms + bt1.n_atoms, dtype=numpy.int32), ) - max_n_nodes_per_gen = max( - max( - scan_path_data[(i, j)]["n_nodes_for_gen"][k] - for k in range(scan_path_data[(i, j)]["n_gens"]) - ) - for i in range(n_input_types) - for j in range(n_output_types) - if (i, j) in scan_path_data - ) - bt_gen_seg_scan_paths = GenSegScanPaths.empty( - n_input_types, n_output_types, max_n_gens, max_n_scans, max_n_nodes_per_gen + ) + print("ids_gold", ids_gold.shape) + print("ids_gold", ids_gold) 
+ + parents_gold = numpy.array( + [ + 0, + 2, + 0, + 2, + 3, + 2, + 5, + 6, + 7, + 7, + 1, + 2, + 5, + 5, + 6, + 6, + 9, + 9, + 19, + 3, + 19, + 20, + 19, + 22, + 22, + 23, + 18, + 19, + 22, + 23, + 23, + 24, + 24, + 24, + 25, + 25, + 25, + ], + dtype=numpy.int32, + ) + print("parents_gold", parents_gold.shape) + dof_type_gold = numpy.full(1 + bt0.n_atoms + bt1.n_atoms, 2, dtype=numpy.int32) + dof_type_gold[0] = NodeType.root.value + dof_type_gold[2] = NodeType.jump.value + frame_x_gold = numpy.arange(1 + bt0.n_atoms + bt1.n_atoms, dtype=numpy.int32) + frame_y_gold = parents_gold # we will correct the jump atom below + frame_z_gold = parents_gold[parents_gold] # grandparents + frame_x_gold[0] = 2 + frame_y_gold[0] = 0 + frame_z_gold[0] = 3 + frame_x_gold[2] = 2 + frame_y_gold[2] = 0 + frame_z_gold[2] = 3 + + nodes_gold = numpy.array( + [ + 0, + 2, + 3, + 18, + 19, + 20, # gen 1 + 2, + 1, + 2, + 5, + 6, + 7, + 9, + 16, + 2, + 11, + 3, + 4, + 18, + 26, + 19, + 22, + 23, + 25, + 34, + 19, + 27, + 20, + 21, # gen 2 + # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 + 5, + 12, + 5, + 13, + 1, + 10, + 6, + 14, + 6, + 15, + 7, + 8, + 9, + 17, + 22, + 24, + 31, + 22, + 28, + 23, + 29, + 23, + 30, + 25, + 35, + 25, + 36, # gen 3 + 24, + 32, + 24, + 33, # gen 4 + ], + dtype=numpy.int32, + ) + + scans_gold = numpy.array( + [ + 0, # gen 1 + 0, + 2, + 8, + 10, + 12, + 14, + 19, + 21, # gen 2 + 0, + 2, + 4, + 6, + 8, + 10, + 12, + 14, + 17, + 19, + 21, + 23, + 25, # gen 3; + 0, + 2, # gen 4 + ], + dtype=numpy.int32, + ) + + generations_gold = numpy.array( + [ + [0, 0], + [6, 1 + 0], + [23 + 6, 8 + 1 + 0], + [27 + 23 + 6, 13 + 8 + 1 + 0], + [4 + 27 + 23 + 6, 2 + 13 + 8 + 1 + 0], + ], + dtype=numpy.int32, + ) + + print("nodes_gold", nodes_gold.shape) + print("scans_gold", scans_gold.shape) + print("generations_gold", generations_gold.shape) + print("generations_gold", generations_gold) + + def _t(x): + return torch.tensor(x, dtype=torch.int32) + + 
ids_gold_t = _t(ids_gold) + parents_gold_t = _t(parents_gold) + frame_x_gold_t = _t(frame_x_gold) + frame_y_gold_t = _t(frame_y_gold) + frame_z_gold_t = _t(frame_z_gold) + dof_type_gold_t = _t(dof_type_gold) + nodes_gold_t = _t(nodes_gold) + scans_gold_t = _t(scans_gold) + generations_gold_t = _t(generations_gold) + + kincoords = torch.zeros((1 + bt0.n_atoms + bt1.n_atoms, 3), dtype=torch.float32) + kincoords[1:] = pose_stack.coords.view(-1, 3)[ids_gold[1:]] + + # okay, now what? + raw_dofs = inverse_kin( + kincoords, + _t(parents_gold), + _t(frame_x_gold), + _t(frame_y_gold), + _t(frame_z_gold), + _t(dof_type_gold), + ) + print("raw dofs", raw_dofs.shape) + print("raw dofs", raw_dofs[:10]) + + def _p(t): + return torch.nn.Parameter(t, requires_grad=False) + + def _tint(ts): + return tuple(map(lambda t: t.to(torch.int32), ts)) + + kinforest = _p( + torch.stack( + _tint( + [ + ids_gold_t, + dof_type_gold_t, + parents_gold_t, + frame_x_gold_t, + frame_y_gold_t, + frame_z_gold_t, + ] + ), + dim=1, ) - for i in range(n_input_types): - for j in range(n_output_types): - if (i, j) not in scan_path_data: - continue - ij_n_gens = scan_path_data[(i, j)]["n_gens"] - bt_gen_seg_scan_paths.n_gens[i, j] = ij_n_gens + ) + + new_coords = forward_kin_op( + raw_dofs, + nodes_gold_t, + scans_gold_t, + generations_gold_t, + nodes_gold_t, # note: backward version; incorrect to assume same as forward, temp! 
+ scans_gold_t, + generations_gold_t, + kinforest, + ) + + print("starting coords", pose_stack.coords.view(-1, 3)[:10]) + print("kincoords", kincoords[:10]) + print("new coords", new_coords[:10]) def test_decide_scan_paths_for_foldforest(ubq_pdb): From 4d6c7ae6b6e1befadaeb37baffdf8335be44f43c Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Mon, 12 Aug 2024 15:53:01 -0400 Subject: [PATCH 04/52] Fixed parent definition for residue 2 to go along w/ N-conn input --- ...st_create_scan_orering_from_block_types.py | 153 +++--------------- 1 file changed, 26 insertions(+), 127 deletions(-) diff --git a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py index 75af9a331..a2d78ddce 100644 --- a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py +++ b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py @@ -588,7 +588,7 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): print("parents") print(bt0gssp.parents[3]) - print(bt1gssp.parents[3]) + print(bt1gssp.parents[0]) ij0 = [3, 1] # 3 => root "input"; Q: is this different from jump input? 
ij1 = [0, 1] @@ -609,48 +609,16 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): print("ids_gold", ids_gold.shape) print("ids_gold", ids_gold) + # fmt: off parents_gold = numpy.array( [ - 0, - 2, - 0, - 2, - 3, - 2, - 5, - 6, - 7, - 7, - 1, - 2, - 5, - 5, - 6, - 6, - 9, - 9, - 19, - 3, - 19, - 20, - 19, - 22, - 22, - 23, - 18, - 19, - 22, - 23, - 23, - 24, - 24, - 24, - 25, - 25, - 25, + 0, # virtual root "atom" + 2, 0, 2, 3, 2, 5, 6, 7, 7, 1, 2, 5, 5, 6, 6, 9, 9, # res 1 + 3, 18, 19, 20, 19, 22, 22, 23, 18, 19, 22, 23, 23, 24, 24, 24, 25, 25, 25, # res 2 ], dtype=numpy.int32, ) + # fmt: on print("parents_gold", parents_gold.shape) dof_type_gold = numpy.full(1 + bt0.n_atoms + bt1.n_atoms, 2, dtype=numpy.int32) dof_type_gold[0] = NodeType.root.value @@ -665,69 +633,13 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): frame_y_gold[2] = 0 frame_z_gold[2] = 3 + # fmt: off nodes_gold = numpy.array( [ - 0, - 2, - 3, - 18, - 19, - 20, # gen 1 - 2, - 1, - 2, - 5, - 6, - 7, - 9, - 16, - 2, - 11, - 3, - 4, - 18, - 26, - 19, - 22, - 23, - 25, - 34, - 19, - 27, - 20, - 21, # gen 2 - # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 - 5, - 12, - 5, - 13, - 1, - 10, - 6, - 14, - 6, - 15, - 7, - 8, - 9, - 17, - 22, - 24, - 31, - 22, - 28, - 23, - 29, - 23, - 30, - 25, - 35, - 25, - 36, # gen 3 - 24, - 32, - 24, - 33, # gen 4 + 0, 2, 3, 18, 19, 20, # gen 1 + 2, 1, 2, 5, 6, 7, 9, 16, 2, 11, 3, 4, 18, 26, 19, 22, 23, 25, 34, 19, 27, 20, 21, # gen 2 + 5, 12, 5, 13, 1, 10, 6, 14, 6, 15, 7, 8, 9, 17, 22, 24, 31, 22, 28, 23, 29, 23, 30, 25, 35, 25, 36, # gen 3 + 24, 32, 24, 33, # gen 4 ], dtype=numpy.int32, ) @@ -735,29 +647,9 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): scans_gold = numpy.array( [ 0, # gen 1 - 0, - 2, - 8, - 10, - 12, - 14, - 19, - 21, # gen 2 - 0, - 2, - 4, - 6, - 8, - 10, - 12, - 14, - 17, - 19, - 21, - 23, - 25, # gen 3; - 0, - 2, # gen 4 + 0, 2, 8, 10, 12, 14, 19, 21, # gen 2 + 0, 2, 4, 6, 8, 10, 12, 14, 17, 19, 21, 
23, 25, # gen 3; + 0, 2, # gen 4 ], dtype=numpy.int32, ) @@ -772,6 +664,7 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): ], dtype=numpy.int32, ) + # fmt: on print("nodes_gold", nodes_gold.shape) print("scans_gold", scans_gold.shape) @@ -795,6 +688,9 @@ def _t(x): kincoords[1:] = pose_stack.coords.view(-1, 3)[ids_gold[1:]] # okay, now what? + # Let's test that the gold version of the kinforest will actually + # generate the input coordinates given the dofs extracted from + # the input coordinates raw_dofs = inverse_kin( kincoords, _t(parents_gold), @@ -803,8 +699,8 @@ def _t(x): _t(frame_z_gold), _t(dof_type_gold), ) - print("raw dofs", raw_dofs.shape) - print("raw dofs", raw_dofs[:10]) + # print("raw dofs", raw_dofs.shape) + # print("raw dofs", raw_dofs[:10]) def _p(t): return torch.nn.Parameter(t, requires_grad=False) @@ -839,9 +735,12 @@ def _tint(ts): kinforest, ) - print("starting coords", pose_stack.coords.view(-1, 3)[:10]) - print("kincoords", kincoords[:10]) - print("new coords", new_coords[:10]) + print("starting coords", pose_stack.coords.view(-1, 3)[14:19]) + + print("kincoords", kincoords[15:20]) + print("new coords", new_coords[15:20]) + + torch.testing.assert_close(kincoords, new_coords, rtol=1e-5, atol=1e-5) def test_decide_scan_paths_for_foldforest(ubq_pdb): From 4355f01f8990344a88614fa71b020e1cf6a1f670 Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Thu, 15 Aug 2024 10:13:56 -0400 Subject: [PATCH 05/52] Move some code out of the unit test file into tmol/kinematics --- tmol/kinematics/datatypes.py | 113 +++ tmol/kinematics/scan_ordering.py | 519 ++++++++++++- ...st_create_scan_orering_from_block_types.py | 727 ++++++------------ 3 files changed, 874 insertions(+), 485 deletions(-) diff --git a/tmol/kinematics/datatypes.py b/tmol/kinematics/datatypes.py index 5fe9e8e52..cac56a64c 100644 --- a/tmol/kinematics/datatypes.py +++ b/tmol/kinematics/datatypes.py @@ -4,6 +4,7 @@ from tmol.types.torch import Tensor from tmol.types.tensor 
import TensorGroup +from tmol.types.array import NDArray from tmol.types.attrs import ConvertAttrs from tmol.types.functional import convert_args @@ -233,3 +234,115 @@ def RBbeta(self): @property def RBgamma(self): return self.raw[..., JumpDOFTypes.RBgamma] + + +@attrs.define +class BTGenerationalSegScanPaths: + jump_atom: int + parents: NDArray[numpy.int64][:, :] # n-input x n-atoms + input_conn_atom: NDArray[numpy.int64][:] # n-input + n_gens: NDArray[numpy.int64][:, :] # n-input x n-output + n_nodes_for_gen: NDArray[numpy.int64][:, :, :] + nodes_for_gen: NDArray[numpy.int64][ + :, :, :, : + ] # n-input x n-output x max-n-gen x max-n-nodes-per-gen + n_scans: NDArray[numpy.int64][:, :, :] + scan_starts: NDArray[numpy.int64][:, :, :, :] + scan_is_real: NDArray[bool][:, :, :, :] + scan_is_inter_block: NDArray[bool][:, :, :, :] + scan_lengths: NDArray[numpy.int64][:, :, :, :] + + @classmethod + def empty( + cls, + n_input_types, + n_output_types, + n_atoms, + max_n_gens, + max_n_scans, + max_n_nodes_per_gen, + ): + io = (n_input_types, n_output_types) + return cls( + jump_input_atom=-1, + parents=numpy.full( + (n_input_types, n_atoms), -1, dtype=int + ), # independent of primary output + input_conn_atom=numpy.full(n_input_types, -1, dtype=int), + n_gens=numpy.zeros(io, dtype=int), + n_nodes_for_gen=numpy.zeros(io + (max_n_gens,), dtype=int), + nodes_for_gen=numpy.full( + io + (max_n_gens, max_n_nodes_per_gen), -1, dtype=int + ), + n_scans=numpy.zeros(io + (max_n_gens,), dtype=int), + scan_starts=numpy.full(io + (max_n_gens, max_n_scans), -1, dtype=int), + scan_is_real=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=bool), + scan_is_inter_block=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=bool), + scan_lengths=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=int), + ) + + +@attrs.define +class PBTGenerationalSegScanPaths: + jump_atom: NDArray[numpy.int64][:] # n-bt + parents: Tensor[torch.int32][:, :, :] # n-bt x n-input x n-atoms + input_conn_atom: 
Tensor[torch.int32][:, :] # n-bt x n-input + n_gens: Tensor[torch.int32][:, :, :] # n-bt x n-input x n-output + n_nodes_for_gen: Tensor[torch.int32][:, :, :, :] + nodes_for_gen: Tensor[torch.int32][ + :, :, :, :, : + ] # n-input x n-output x max-n-gen x max-n-nodes-per-gen + n_scans: Tensor[torch.int32][:, :, :, :] + scan_starts: Tensor[torch.int32][:, :, :, :, :] + scan_is_real: Tensor[bool][:, :, :, :, :] + scan_is_inter_block: Tensor[bool][:, :, :, :, :] + scan_lengths: Tensor[torch.int32][:, :, :, :, :] + + @classmethod + def empty( + cls, + device, + n_bt, + max_n_input_types, + max_n_output_types, + max_n_atoms, + max_n_gens, + max_n_scans, + max_n_nodes_per_gen, + ): + io = (n_bt, max_n_input_types, max_n_output_types) + return cls( + jump_input_atom=torch.full(n_bt, -1, dtype=torch.int32, device=device), + parents=torch.full( + (n_bt, max_n_input_types, max_n_atoms), + -1, + dtype=torch.int32, + device=device, + ), # independent of primary output + input_conn_atom=torch.full( + (n_bt, max_n_input_types), -1, dtype=torch.int32, device=device + ), + n_gens=torch.zeros(io, dtype=torch.int32, device=device), + n_nodes_for_gen=torch.zeros( + io + (max_n_gens,), dtype=torch.int32, device=device + ), + nodes_for_gen=torch.full( + io + (max_n_gens, max_n_nodes_per_gen), + -1, + dtype=torch.int32, + device=device, + ), + n_scans=torch.zeros(io + (max_n_gens,), dtype=torch.int32, device=device), + scan_starts=torch.full( + io + (max_n_gens, max_n_scans), -1, dtype=torch.int32, device=device + ), + scan_is_real=torch.zeros( + io + (max_n_gens, max_n_scans), dtype=torch.bool, device=device + ), + scan_is_inter_block=torch.zeros( + io + (max_n_gens, max_n_scans), dtype=bool, device=device + ), + scan_lengths=torch.zeros( + io + (max_n_gens, max_n_scans), dtype=torch.int32, device=device + ), + ) diff --git a/tmol/kinematics/scan_ordering.py b/tmol/kinematics/scan_ordering.py index 9873cb00c..7e4df3fd9 100644 --- a/tmol/kinematics/scan_ordering.py +++ 
b/tmol/kinematics/scan_ordering.py @@ -2,7 +2,11 @@ import numpy import torch -from .datatypes import KinForest +from .datatypes import ( + KinForest, + BTGenerationalSegScanPaths, + PBTGenerationalSegScanPaths, +) from numba import jit from tmol.types.torch import Tensor @@ -11,6 +15,26 @@ from tmol.types.functional import validate_args +from collections import defaultdict +from numba import jit + +import scipy.sparse as sparse +import scipy.sparse.csgraph as csgraph +from tmol.types.torch import Tensor + +from tmol.io.canonical_ordering import ( + default_canonical_ordering, + default_packed_block_types, + canonical_form_from_pdb, +) +from tmol.io.pose_stack_construction import pose_stack_from_canonical_form +from tmol.kinematics.datatypes import NodeType +from tmol.kinematics.fold_forest import EdgeType +from tmol.kinematics.scan_ordering import get_children +from tmol.kinematics.compiled import inverse_kin, forward_kin_op + +from tmol.utility.tensor.common_operations import exclusive_cumsum1d + @jit(nopython=True) def get_children(parents): @@ -323,3 +347,496 @@ def calculate_from_kinforest(cls, kinforest: KinForest): forward_scan_paths=forward_scan_paths, backward_scan_paths=backward_scan_paths, ) + + +def jump_atom_for_bt(bt): + """Return the index of the atom that will be jumped to or jumped from""" + # TEMP: CA if CA is present; ow, atom 0 + return bt.atom_to_idx("CA") if "CA" in bt.atom_names else 0 + + +def _annotate_block_type_with_gen_scan_paths(bt): + if hasattr(bt, "gen_seg_scan_paths"): + return + n_conn = len(bt.connections) + + n_input_types = n_conn + 2 # n_conn + jump input + root "input" + n_output_types = n_conn + 1 # n_conn + jump output + + n_gens = numpy.zeros((n_input_types, n_output_types), dtype=numpy.int64) + nodes_for_generation = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + n_scans = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] + scan_starts = [[[] for _ in range(n_output_types)] 
for _2 in range(n_input_types)] + scan_is_inter_block = [ + [[] for _ in range(n_output_types)] for _2 in range(n_input_types) + ] + scan_lengths = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] + + def _bonds_to_csgraph( + bonds: NDArray[int][:, 2], edge_weight: float + ) -> sparse.csr_matrix: + weights_array = numpy.full((1,), edge_weight, dtype=numpy.float32) + weights = numpy.broadcast_to(weights_array, bonds[:, 0].shape) + + bonds_csr = sparse.csr_matrix( + (weights, (bonds[:, 0], bonds[:, 1])), + shape=(bt.n_atoms, bt.n_atoms), + ) + return bonds_csr + + # create a bond graph and then we will create the prioritized edges + # and all edges + potential_bonds = _bonds_to_csgraph(bt.bond_indices, -1) + # print("potential bonds", potential_bonds) + tor_atoms = [ + (uaids[1][0], uaids[2][0]) + for tor, uaids in bt.torsion_to_uaids.items() + if uaids[1][0] >= 0 and uaids[2][0] >= 0 + ] + if len(tor_atoms) == 0: + tor_atoms = numpy.zeros((0, 2), dtype=numpy.int64) + else: + tor_atoms = numpy.array(tor_atoms) + # print("tor atoms:", tor_atoms) + + prioritized_bonds = _bonds_to_csgraph(tor_atoms, -0.125) + # print("prioritized bonds", prioritized_bonds) + bond_graph = potential_bonds + prioritized_bonds + bond_graph_spanning_tree = csgraph.minimum_spanning_tree(bond_graph.tocsr()) + + mid_bt_atom = jump_bt_atom(bt, bond_graph_spanning_tree) + + is_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) + for i in range(n_conn): + is_conn_atom[bt.ordered_connection_atoms[i]] = True + + scan_path_data = {} + parents = numpy.full((n_input_types, bt.n_atoms), -1, dtype=numpy.int64) + input_conn_atom = numpy.zeros((n_input_types,), dtype=numpy.int64) + for i in range(n_input_types): + + i_conn_atom = bt.ordered_connection_atoms[i] if i < n_conn else mid_bt_atom + input_conn_atom[i] = i_conn_atom + bfto_2_orig, preds = csgraph.breadth_first_order( + bond_graph_spanning_tree, + i_conn_atom, + directed=False, + return_predecessors=True, + ) + parents[i, :] = 
preds + # Now, the parent of the i_conn_atom comes from the previous residue, so we will + # need to fix this atom when we are hooking the blocks together. For now, leave + # it as -9999 (which is what csgraph labels it as) so that we can tell if we have + # not corrected this parent index later on. + # print(bt.name, i, bfto_2_orig, preds) + # print([bt.atom_name(bfto_2_orig[bfs_ind]) for bfs_ind in range(bt.n_atoms)]) + for j in range(n_output_types): + if i == j and i < n_conn: + # we cannot enter from one inter-residue connection point and then + # leave by that same inter-residue connection point unless we are + # building a jump + continue + + # now we start at the j_conn_atom and work backwards toward the root + # which marks the first scan path for this block type: the "primary exit path" + gen_scan_paths = defaultdict(list) + + j_conn_atom = bt.ordered_connection_atoms[j] if j < n_conn else mid_bt_atom + + first_descendant = numpy.full((bt.n_atoms,), -9999, dtype=numpy.int64) + is_on_primary_exit_path = numpy.zeros((bt.n_atoms,), dtype=bool) + is_on_primary_exit_path[i_conn_atom] = True + + focused_atom = j_conn_atom + primary_exit_scan_path = [] + while focused_atom != i_conn_atom: + # print("exit path:", bt.atom_name(focused_atom)) + is_on_primary_exit_path[focused_atom] = True + primary_exit_scan_path.append(focused_atom) + pred = preds[focused_atom] + first_descendant[pred] = focused_atom + focused_atom = pred + primary_exit_scan_path.append(i_conn_atom) + primary_exit_scan_path.reverse() + # we need to prioritize exit paths of all stripes + # in constructing the trees + is_on_exit_path = is_on_primary_exit_path.copy() + for k in range(n_conn): + if k == i or k == j: + continue # truly unnecessary; nothing changes if I remove these two lines + is_on_exit_path[bt.ordered_connection_atoms[k]] = True + + # print("primary_exit_scan_path:", primary_exit_scan_path) + gen_scan_paths[0].append(primary_exit_scan_path) + + # Create a list of children for each 
atom. + n_kids = numpy.zeros((bt.n_atoms,), dtype=numpy.int64) + atom_kids = [[] for _ in range(bt.n_atoms)] + for k in range(bt.n_atoms): + if preds[k] < 0: + assert ( + k == i_conn_atom + ), f"bad predecesor for atom {k} in {bt.name}, {preds[k]}" + continue # the root + n_kids[preds[k]] += 1 + atom_kids[preds[k]].append(k) + + # now we label each node with its "generation depth" using a + # leaf-to-root traversal perscribed by the original DFS, taking + # into account the fact that priority must be given to + # exit paths + gen_depth = numpy.ones((bt.n_atoms,), dtype=numpy.int64) + on_path_from_conn_to_i_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) + for k in range(bt.n_atoms - 1, -1, -1): + k_atom_ind = bfto_2_orig[k] + # print("recursing upwards", i, "i_conn atom", i_conn_atom, j, "j_conn_atom", j_conn_atom, k, k_atom_ind) + k_kids = atom_kids[k_atom_ind] + # print("kids:", k_kids) + if len(k_kids) == 0: + continue + # from here forward, we know that k_atom_ind has > 0 children + + def gen_depth_given_first_descendant(): + # first set the first_descendant for k_atom_ind + # then the logic is: we have to add one to the + # gen-depth of every child but the first descendant + # which we get "for free" + # print(f"atom {bt.atom_name(k_atom_ind)} with first descendant {bt.atom_name(first_descendant[k_atom_ind]) if first_descendant[k_atom_ind] >= 0 else 'None'} and depth {gen_depth[first_descendant[k_atom_ind]] if first_descendant[k_atom_ind] >= 0 else -9999}") + return max( + [ + ( + gen_depth[k_kid] + 1 + if k_kid != first_descendant[k_atom_ind] + else gen_depth[k_kid] + ) + for k_kid in k_kids + ] + ) + + if is_on_primary_exit_path[k_atom_ind]: + # in this case, the first_descendant for this atom + # has already been decided + # print("on exit path:", bt.atom_name(k_atom_ind), first_descendant[k_atom_ind], is_conn_atom[k_atom_ind]) + if k_atom_ind == j_conn_atom: + # the first descendent is the atom on the next residue to which + # this residue is connected 
+ gen_depth[k_atom_ind] = max([gen_depth[l] for l in k_kids]) + 1 + else: + # first_descendant is already determined for this atom + gen_depth[k_atom_ind] = gen_depth_given_first_descendant() + else: + + if is_conn_atom[k_atom_ind]: + # in this case, "the" connection (there can possibly be more than one!) + # will be the first child and the other descendants will be second children + # we save the gen depth, but when calculating the gen depth of the + # fold-forest, if this residue is at the upstream end of an edge, then + # its depth will have to be calculated as the min gen-depth of the + # intra-residue bits and the gen-depth of the nodes downstream of it. + gen_depth[k_atom_ind] = max([gen_depth[l] for l in k_kids]) + 1 + else: + # most-common case: an atom not on the primary-exit path, and that isn't + # itself a conn atom. + # First we ask: are we on one or more exit paths? + # NOTE: this just chooses the first exit path atom it encounters + # as the first descendant and so I pause and think: if we have + # a block type with 4 inter-residue connections where the fold + # forest branches at this residue, then the algorithm for constructing + # the fewest-number-of-generations KinForest here is going + # will fail: we are treating all exit paths out of this residue + # as interchangable and we might say connection c should be + # ahead of connection c' in a case where c' has a greater gen_depth + # than c. + # + # The case I am designing for here is: there's a jump that has + # landed at a beta-amino acid's CA atom and there are exit paths + # through the N- and C-terminal ends of the residue and if the + # primary exit path is the C-term, then the N-term exit path should + # still have priority over the side-chain path. + # + # R + # | + # ... CB C + # \ / \ / \ + # N CA ... + # + # The path starting at CB should go towards N and not towards R. + # If we are only dealing with polymeric residues that have an + # up- and a down connection that that's it (e.g. 
nucleic acids), + # then this algorithm will still produce optimal KinForests. + # + # A case that this would fail to deliver the optimally-efficient + # (fewest number of generations) KinForest would be if this R group + # also contained an inter-residue connection and there were an + # edge in the FoldForest (a "chemical edge") leaving from that + # connection to some further chain, e.g., it could be a sugar + # group attached to a beta-ASN. Now if the path (CA->CB->N) takes + # precedence over the path (CA->CB->R), then everything down- + # stream of the R would have a generation-delay one greater than + # it would otherwise. + for kid in k_kids: + if is_on_exit_path[kid]: + first_descendant[k_atom_ind] = kid + is_on_exit_path[k_atom_ind] = True + + if not is_on_exit_path[k_atom_ind]: + # which should be the first descendant? the one with the greatest gen depth + first_descendant[k_atom_ind] = k_kids[ + numpy.argmax( + numpy.array([gen_depth[kid] for kid in k_kids]) + ) + ] + gen_depth[k_atom_ind] = gen_depth_given_first_descendant() + # print("gen_depth", bt.atom_name(k_atom_ind), "d:", gen_depth[k_atom_ind]) + # print("gen_depth", gen_depth) + + # OKAY! 
+ # now we have paths rooted at each node up to the root + # we need to turn these paths into scan paths + processed_node_into_scan_path = is_on_primary_exit_path.copy() + gen_to_build_atom = numpy.full((bt.n_atoms,), -1, dtype=numpy.int64) + gen_to_build_atom[processed_node_into_scan_path] = 0 + # print("gen depth", gen_depth) + # print("starting bfs:", processed_node_into_scan_path) + for k in range(bt.n_atoms): + k_atom_ind = bfto_2_orig[k] + if processed_node_into_scan_path[k_atom_ind]: + continue + + # if we arrive here, that means k_atom_ind is the root of a + # new scan path + path = [] + # we have already processed the first scan path + # from the entrace-point atom to the first exit-point atom + assert k_atom_ind != i_conn_atom + # put the parent of this new root at the beginning of + # the scan path + path.append(preds[k_atom_ind]) + focused_atom = k_atom_ind + + gen_to_build_atom[focused_atom] = ( + gen_to_build_atom[preds[focused_atom]] + 1 + ) + # print( + # f"gen to build {bt.atom_name(focused_atom)} from {bt.atom_name(preds[focused_atom])}", + # f"with gen {gen_to_build_atom[focused_atom]}", + # ) + while focused_atom >= 0: + path.append(focused_atom) + processed_node_into_scan_path[focused_atom] = True + focused_atom = first_descendant[focused_atom] + if focused_atom >= 0: + gen_to_build_atom[focused_atom] = gen_to_build_atom[ + preds[focused_atom] + ] + if is_on_exit_path[k_atom_ind]: + gen_scan_paths[gen_to_build_atom[k_atom_ind]].insert(0, path) + else: + gen_scan_paths[gen_to_build_atom[k_atom_ind]].append(path) + # Now we need to assemble the scan paths in a compact way: + # print("gen scan paths", gen_scan_paths) + + ij_n_gens = gen_depth[i_conn_atom] + # print("ij_n_gens", i, j, ij_n_gens) + ij_n_scans = numpy.array( + [len(gen_scan_paths[k]) for k in range(ij_n_gens)], dtype=int + ) + # print("ij_n_scans", i, j, ij_n_scans) + ij_scan_starts = [ + numpy.zeros((ij_n_scans[k],), dtype=int) for k in range(ij_n_gens) + ] + ij_scan_lengths = [ + 
numpy.array( + [len(gen_scan_paths[k][l]) for l in range(len(gen_scan_paths[k]))], + dtype=int, + ) + for k in range(ij_n_gens) + ] + # print("ij_scan_lengths", i, j, ij_scan_lengths) + for k in range(ij_n_gens): + offset = 0 + for l in range(ij_n_scans[k]): + ij_scan_starts[k][l] = offset + offset += ij_scan_lengths[k][l] + # print("ij_scan_starts", i, j, ij_scan_starts) + # print("ij_scan_lengths cumsum?", numpy.cumsum(ij_scan_lengths)) + ij_scan_is_inter_block = [ + numpy.zeros((ij_n_scans[k],), dtype=bool) for k in range(ij_n_gens) + ] + + for k in range(ij_n_gens): + for l in range(ij_n_scans[k]): + l_first_at = gen_scan_paths[k][l][0 if k == 0 else 1] + ij_scan_is_inter_block[k][l] = is_on_exit_path[l_first_at] + + # print("ij_scan_is_inter_block", ij_scan_is_inter_block) + # ij_n_nodes_for_gen = + ij_n_nodes_for_gen = numpy.array( + [ + sum(len(path) for path in gen_scan_paths[k]) + for k in range(ij_n_gens) + ], + dtype=int, + ) + # print("ij_n_nodes_for_gen", ij_n_nodes_for_gen) + scan_path_data[(i, j)] = dict( + n_gens=ij_n_gens, + n_nodes_for_gen=ij_n_nodes_for_gen, + nodes_for_generation=gen_scan_paths, + n_scans=ij_n_scans, + scan_starts=ij_scan_starts, + scan_is_inter_block=is_on_exit_path, + scan_lengths=ij_scan_lengths, + ) + # end for j + # end for i + + # Now let's count out the maximum number of generations, scans, and nodes-per-gen + # so we can create the BTGenerationalSegScanPaths object + max_n_gens = max( + scan_path_data[(i, j)]["n_gens"] + for i in range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + max_n_scans = max( + max( + scan_path_data[(i, j)]["n_scans"][k] + for k in range(scan_path_data[(i, j)]["n_gens"]) + ) + for i in range(n_input_types) + for j in range(n_output_types) + if (i, j) in scan_path_data + ) + max_n_nodes_per_gen = max( + max( + scan_path_data[(i, j)]["n_nodes_for_gen"][k] + for k in range(scan_path_data[(i, j)]["n_gens"]) + ) + for i in range(n_input_types) + for j in 
range(n_output_types) + if (i, j) in scan_path_data + ) + bt_gen_seg_scan_paths = BTGenerationalSegScanPaths.empty( + n_input_types, + n_output_types, + bt.n_atoms, + max_n_gens, + max_n_scans, + max_n_nodes_per_gen, + ) + bt_gen_seg_scan_paths.jump_atom = jump_atom_for_bt(bt) + bt_gen_seg_scan_paths.parents = parents + bt_gen_seg_scan_paths.input_conn_atom = input_conn_atom + # Finally, we populate the BTGenerationalSegScanPaths object + for i in range(n_input_types): + for j in range(n_output_types): + if (i, j) not in scan_path_data: + continue + ij_n_gens = scan_path_data[(i, j)]["n_gens"] + bt_gen_seg_scan_paths.n_gens[i, j] = ij_n_gens + for k in range(ij_n_gens): + bt_gen_seg_scan_paths.n_nodes_for_gen[i, j, k] = scan_path_data[(i, j)][ + "n_nodes_for_gen" + ][k] + bt_gen_seg_scan_paths.n_scans[i, j, k] = scan_path_data[(i, j)][ + "n_scans" + ][k] + bt_gen_seg_scan_paths.scan_is_real[ + i, j, k, : bt_gen_seg_scan_paths.n_scans[i, j, k] + ] = True + + ijk_n_scans = scan_path_data[(i, j)]["n_scans"][k] + bt_gen_seg_scan_paths.scan_starts[i, j, k, :ijk_n_scans] = ( + scan_path_data[(i, j)]["scan_starts"][k] + ) + bt_gen_seg_scan_paths.scan_is_inter_block[i, j, k, :ijk_n_scans] = ( + scan_path_data[(i, j)]["scan_is_inter_block"][k] + ) + bt_gen_seg_scan_paths.scan_lengths[i, j, k, :ijk_n_scans] = ( + scan_path_data[(i, j)]["scan_lengths"][k] + ) + # for l in range(scan_path_data[(i, j)]["n_scans"][k]): + # bt_gen_seg_scan_paths.scan_starts[i, j, k, l] = scan_path_data[(i, j)]["scan_starts"][k][l] + # bt_gen_seg_scan_paths.scan_is_inter_block[i, j, k, l] = scan_path_data[(i, j)]["scan_is_inter_block"][k][l] + # bt_gen_seg_scan_paths.scan_lengths[i, j, k, l] = scan_path_data[(i, j)]["scan_lengths"][k][l] + for l in range(ijk_n_scans): + m_offset = scan_path_data[(i, j)]["scan_starts"][k][l] + for m in range( + len(scan_path_data[(i, j)]["nodes_for_generation"][k][l]) + ): + bt_gen_seg_scan_paths.nodes_for_gen[i, j, k, m_offset + m] = ( + scan_path_data[(i, 
j)]["nodes_for_generation"][k][l][m] + ) + # print("nodes for gen", i, j, k, bt_gen_seg_scan_paths.nodes_for_gen[i, j, k, :]) + + setattr(bt, "gen_seg_scan_paths", bt_gen_seg_scan_paths) + + +def _annotate_packed_block_type_with_gen_scan_paths(pbt): + for bt in pbt.active_block_types: + _annotate_block_type_with_gen_scan_paths(bt) + max_n_input_types = max( + bt.gen_seg_scan_paths.n_gens.shape[0] for bt in pbt.active_block_types + ) + max_n_output_types = max( + bt.gen_seg_scan_paths.n_gens.shape[1] for bt in pbt.active_block_types + ) + # max_n_atoms : pbt already provides this! + max_n_gens = max( + bt.gen_seg_scan_paths.n_nodes_for_gen.shape[2] for bt in pbt.active_block_types + ) + max_n_scans = max( + bt.gen_seg_scan_paths.scan_starts.shape[3] for bt in pbt.active_block_types + ) + max_n_nodes_per_gen = max( + bt.gen_seg_scan_paths.nodes_for_gen.shape[3] for bt in pbt.active_block_types + ) + + gen_seg_scan_paths = PBTGenerationalSegScanPaths.empty( + pbt.device, + pbt.n_types, + max_n_input_types, + max_n_output_types, + pbt.max_n_atoms, + max_n_gens, + max_n_scans, + max_n_nodes_per_gen, + ) + varnames = [ + "parents", + "input_conn_atom", + "n_gens", + "n_nodes_for_gen", + "nodes_for_gen", + "n_scans", + "scan_starts", + "scan_is_real", + "scan_is_inter_block", + "scan_lengths", + ] + for i, bt in enumerate(pbt.active_block_types): + bt_gssp = bt.gen_seg_scan_paths + for vname in varnames: + dst = getattr(gen_seg_scan_paths, vname) + src = getattr(bt_gssp, vname) + src = torch.tensor( + src, + dtype=(torch.int32 if src.dtype == numpy.int64 else torch.bool), + device=pbt.device, + ) + if len(src.shape) == 1: + dst[i, : src.shape[0]] = src + elif len(src.shape) == 2: + dst[i, : src.shape[0], : src.shape[1]] = src + elif len(src.shape) == 3: + dst[i, : src.shape[0], : src.shape[1], : src.shape[2]] = src + elif len(src.shape) == 4: + dst[ + i, : src.shape[0], : src.shape[1], : src.shape[2], : src.shape[3] + ] = src + else: + raise ValueError("unhandled shape") 
+ setattr(pbt, "gen_seg_scan_paths", gen_seg_scan_paths) diff --git a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py index a2d78ddce..230be9bad 100644 --- a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py +++ b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py @@ -7,7 +7,7 @@ import scipy.sparse as sparse import scipy.sparse.csgraph as csgraph -from tmol.types.array import NDArray +from tmol.types.torch import Tensor from tmol.io.canonical_ordering import ( default_canonical_ordering, @@ -17,9 +17,15 @@ from tmol.io.pose_stack_construction import pose_stack_from_canonical_form from tmol.kinematics.datatypes import NodeType from tmol.kinematics.fold_forest import EdgeType -from tmol.kinematics.scan_ordering import get_children +from tmol.kinematics.scan_ordering import ( + get_children, + _annotate_block_type_with_gen_scan_paths, + _annotate_packed_block_type_with_gen_scan_paths, +) from tmol.kinematics.compiled import inverse_kin, forward_kin_op +from tmol.utility.tensor.common_operations import exclusive_cumsum1d + # @jit # def get_branch_depth(parents): # # modeled off get_children @@ -46,493 +52,17 @@ # # As we do this, -def jump_bt_atom(bt, spanning_tree): - # CA! TEMP!!! Replace with code that connects up conn atom to down conn atom - # in the spanning tree and chooses the midpoing along that path, but for now, - # CA is atom 1. 
- return 1 - - -@attrs.define -class GenerationalSegScanPaths: - parents: NDArray[numpy.int64][:, :] # n-input x n-atoms - input_conn_atom: NDArray[numpy.int64][:] # n-input - n_gens: NDArray[numpy.int64][:, :] # n-input x n-output - n_nodes_for_gen: NDArray[numpy.int64][:, :, :] - nodes_for_gen: NDArray[numpy.int64][ - :, :, :, : - ] # n-input x n-output x max-n-gen x max-n-nodes-per-gen - n_scans: NDArray[numpy.int64][:, :, :] - scan_starts: NDArray[numpy.int64][:, :, :, :] - scan_is_real: NDArray[bool][:, :, :, :] - scan_is_inter_block: NDArray[bool][:, :, :, :] - scan_lengths: NDArray[numpy.int64][:, :, :, :] - - @classmethod - def empty( - cls, - n_input_types, - n_output_types, - n_atoms, - max_n_gens, - max_n_scans, - max_n_nodes_per_gen, - ): - io = (n_input_types, n_output_types) - return cls( - parents=numpy.full( - (n_input_types, n_atoms), -1, dtype=int - ), # independent of primary output - input_conn_atom=numpy.full(n_input_types, -1, dtype=int), - n_gens=numpy.zeros(io, dtype=int), - n_nodes_for_gen=numpy.zeros(io + (max_n_gens,), dtype=int), - nodes_for_gen=numpy.full( - io + (max_n_gens, max_n_nodes_per_gen), -1, dtype=int - ), - n_scans=numpy.zeros(io + (max_n_gens,), dtype=int), - scan_starts=numpy.full(io + (max_n_gens, max_n_scans), -1, dtype=int), - scan_is_real=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=bool), - scan_is_inter_block=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=bool), - scan_lengths=numpy.zeros(io + (max_n_gens, max_n_scans), dtype=int), - ) - - -def _annotate_block_type_with_gen_scan_paths(bt): - n_conn = len(bt.connections) - - n_input_types = n_conn + 2 # n_conn + jump input + root "input" - n_output_types = n_conn + 1 # n_conn + jump output - - n_gens = numpy.zeros((n_input_types, n_output_types), dtype=numpy.int64) - nodes_for_generation = [ - [[] for _ in range(n_output_types)] for _2 in range(n_input_types) - ] - n_scans = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] - scan_starts = 
[[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] - scan_is_inter_block = [ - [[] for _ in range(n_output_types)] for _2 in range(n_input_types) - ] - scan_lengths = [[[] for _ in range(n_output_types)] for _2 in range(n_input_types)] - - def _bonds_to_csgraph( - bonds: NDArray[int][:, 2], edge_weight: float - ) -> sparse.csr_matrix: - weights_array = numpy.full((1,), edge_weight, dtype=numpy.float32) - weights = numpy.broadcast_to(weights_array, bonds[:, 0].shape) - - bonds_csr = sparse.csr_matrix( - (weights, (bonds[:, 0], bonds[:, 1])), - shape=(bt.n_atoms, bt.n_atoms), - ) - return bonds_csr - - # create a bond graph and then we will create the prioritized edges - # and all edges - potential_bonds = _bonds_to_csgraph(bt.bond_indices, -1) - # print("potential bonds", potential_bonds) - tor_atoms = [ - (uaids[1][0], uaids[2][0]) - for tor, uaids in bt.torsion_to_uaids.items() - if uaids[1][0] >= 0 and uaids[2][0] >= 0 - ] - if len(tor_atoms) == 0: - tor_atoms = numpy.zeros((0, 2), dtype=numpy.int64) - else: - tor_atoms = numpy.array(tor_atoms) - # print("tor atoms:", tor_atoms) - - prioritized_bonds = _bonds_to_csgraph(tor_atoms, -0.125) - # print("prioritized bonds", prioritized_bonds) - bond_graph = potential_bonds + prioritized_bonds - bond_graph_spanning_tree = csgraph.minimum_spanning_tree(bond_graph.tocsr()) - - mid_bt_atom = jump_bt_atom(bt, bond_graph_spanning_tree) - - is_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) - for i in range(n_conn): - is_conn_atom[bt.ordered_connection_atoms[i]] = True - - scan_path_data = {} - parents = numpy.full((n_input_types, bt.n_atoms), -1, dtype=numpy.int64) - input_conn_atom = numpy.zeros((n_input_types,), dtype=numpy.int64) - for i in range(n_input_types): - - i_conn_atom = bt.ordered_connection_atoms[i] if i < n_conn else mid_bt_atom - input_conn_atom[i] = i_conn_atom - bfto_2_orig, preds = csgraph.breadth_first_order( - bond_graph_spanning_tree, - i_conn_atom, - directed=False, - 
return_predecessors=True, - ) - parents[i, :] = preds - # Now, the parent of the i_conn_atom comes from the previous residue, so we will - # need to fix this atom when we are hooking the blocks together. For now, leave - # it as -9999 (which is what csgraph labels it as) so that we can tell if we have - # not corrected this parent index later on. - # print(bt.name, i, bfto_2_orig, preds) - # print([bt.atom_name(bfto_2_orig[bfs_ind]) for bfs_ind in range(bt.n_atoms)]) - for j in range(n_output_types): - if i == j and i < n_conn: - # we cannot enter from one inter-residue connection point and then - # leave by that same inter-residue connection point unless we are - # building a jump - continue - - # now we start at the j_conn_atom and work backwards toward the root - # which marks the first scan path for this block type: the "primary exit path" - gen_scan_paths = defaultdict(list) - - j_conn_atom = bt.ordered_connection_atoms[j] if j < n_conn else mid_bt_atom - - first_descendant = numpy.full((bt.n_atoms,), -9999, dtype=numpy.int64) - is_on_primary_exit_path = numpy.zeros((bt.n_atoms,), dtype=bool) - is_on_primary_exit_path[i_conn_atom] = True - - focused_atom = j_conn_atom - primary_exit_scan_path = [] - while focused_atom != i_conn_atom: - # print("exit path:", bt.atom_name(focused_atom)) - is_on_primary_exit_path[focused_atom] = True - primary_exit_scan_path.append(focused_atom) - pred = preds[focused_atom] - first_descendant[pred] = focused_atom - focused_atom = pred - primary_exit_scan_path.append(i_conn_atom) - primary_exit_scan_path.reverse() - # we need to prioritize exit paths of all stripes - # in constructing the trees - is_on_exit_path = is_on_primary_exit_path.copy() - for k in range(n_conn): - if k == i or k == j: - continue # truly unnecessary; nothing changes if I remove these two lines - is_on_exit_path[bt.ordered_connection_atoms[k]] = True - - # print("primary_exit_scan_path:", primary_exit_scan_path) - 
gen_scan_paths[0].append(primary_exit_scan_path) - - # Create a list of children for each atom. - n_kids = numpy.zeros((bt.n_atoms,), dtype=numpy.int64) - atom_kids = [[] for _ in range(bt.n_atoms)] - for k in range(bt.n_atoms): - if preds[k] < 0: - assert ( - k == i_conn_atom - ), f"bad predecesor for atom {k} in {bt.name}, {preds[k]}" - continue # the root - n_kids[preds[k]] += 1 - atom_kids[preds[k]].append(k) - - # now we label each node with its "generation depth" using a - # leaf-to-root traversal perscribed by the original DFS, taking - # into account the fact that priority must be given to - # exit paths - gen_depth = numpy.ones((bt.n_atoms,), dtype=numpy.int64) - on_path_from_conn_to_i_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) - for k in range(bt.n_atoms - 1, -1, -1): - k_atom_ind = bfto_2_orig[k] - # print("recursing upwards", i, "i_conn atom", i_conn_atom, j, "j_conn_atom", j_conn_atom, k, k_atom_ind) - k_kids = atom_kids[k_atom_ind] - # print("kids:", k_kids) - if len(k_kids) == 0: - continue - # from here forward, we know that k_atom_ind has > 0 children - - def gen_depth_given_first_descendant(): - # first set the first_descendant for k_atom_ind - # then the logic is: we have to add one to the - # gen-depth of every child but the first descendant - # which we get "for free" - # print(f"atom {bt.atom_name(k_atom_ind)} with first descendant {bt.atom_name(first_descendant[k_atom_ind]) if first_descendant[k_atom_ind] >= 0 else 'None'} and depth {gen_depth[first_descendant[k_atom_ind]] if first_descendant[k_atom_ind] >= 0 else -9999}") - return max( - [ - ( - gen_depth[k_kid] + 1 - if k_kid != first_descendant[k_atom_ind] - else gen_depth[k_kid] - ) - for k_kid in k_kids - ] - ) - - if is_on_primary_exit_path[k_atom_ind]: - # in this case, the first_descendant for this atom - # has already been decided - # print("on exit path:", bt.atom_name(k_atom_ind), first_descendant[k_atom_ind], is_conn_atom[k_atom_ind]) - if k_atom_ind == j_conn_atom: - # 
the first descendent is the atom on the next residue to which - # this residue is connected - gen_depth[k_atom_ind] = max([gen_depth[l] for l in k_kids]) + 1 - else: - # first_descendant is already determined for this atom - gen_depth[k_atom_ind] = gen_depth_given_first_descendant() - else: - - if is_conn_atom[k_atom_ind]: - # in this case, "the" connection (there can possibly be more than one!) - # will be the first child and the other descendants will be second children - # we save the gen depth, but when calculating the gen depth of the - # fold-forest, if this residue is at the upstream end of an edge, then - # its depth will have to be calculated as the min gen-depth of the - # intra-residue bits and the gen-depth of the nodes downstream of it. - gen_depth[k_atom_ind] = max([gen_depth[l] for l in k_kids]) + 1 - else: - # most-common case: an atom not on the primary-exit path, and that isn't - # itself a conn atom. - # First we ask: are we on one or more exit paths? - # NOTE: this just chooses the first exit path atom it encounters - # as the first descendant and so I pause and think: if we have - # a block type with 4 inter-residue connections where the fold - # forest branches at this residue, then the algorithm for constructing - # the fewest-number-of-generations KinForest here is going - # will fail: we are treating all exit paths out of this residue - # as interchangable and we might say connection c should be - # ahead of connection c' in a case where c' has a greater gen_depth - # than c. - # - # The case I am designing for here is: there's a jump that has - # landed at a beta-amino acid's CA atom and there are exit paths - # through the N- and C-terminal ends of the residue and if the - # primary exit path is the C-term, then the N-term exit path should - # still have priority over the side-chain path. - # - # R - # | - # ... CB C - # \ / \ / \ - # N CA ... - # - # The path starting at CB should go towards N and not towards R. 
- # If we are only dealing with polymeric residues that have an - # up- and a down connection that that's it (e.g. nucleic acids), - # then this algorithm will still produce optimal KinForests. - # - # A case that this would fail to deliver the optimally-efficient - # (fewest number of generations) KinForest would be if this R group - # also contained an inter-residue connection and there were an - # edge in the FoldForest (a "chemical edge") leaving from that - # connection to some further chain, e.g., it could be a sugar - # group attached to a beta-ASN. Now if the path (CA->CB->N) takes - # precedence over the path (CA->CB->R), then everything down- - # stream of the R would have a generation-delay one greater than - # it would otherwise. - for kid in k_kids: - if is_on_exit_path[kid]: - first_descendant[k_atom_ind] = kid - is_on_exit_path[k_atom_ind] = True - - if not is_on_exit_path[k_atom_ind]: - # which should be the first descendant? the one with the greatest gen depth - first_descendant[k_atom_ind] = k_kids[ - numpy.argmax( - numpy.array([gen_depth[kid] for kid in k_kids]) - ) - ] - gen_depth[k_atom_ind] = gen_depth_given_first_descendant() - # print("gen_depth", bt.atom_name(k_atom_ind), "d:", gen_depth[k_atom_ind]) - # print("gen_depth", gen_depth) - - # OKAY! 
- # now we have paths rooted at each node up to the root - # we need to turn these paths into scan paths - processed_node_into_scan_path = is_on_primary_exit_path.copy() - gen_to_build_atom = numpy.full((bt.n_atoms,), -1, dtype=numpy.int64) - gen_to_build_atom[processed_node_into_scan_path] = 0 - # print("gen depth", gen_depth) - # print("starting bfs:", processed_node_into_scan_path) - for k in range(bt.n_atoms): - k_atom_ind = bfto_2_orig[k] - if processed_node_into_scan_path[k_atom_ind]: - continue - - # if we arrive here, that means k_atom_ind is the root of a - # new scan path - path = [] - # we have already processed the first scan path - # from the entrace-point atom to the first exit-point atom - assert k_atom_ind != i_conn_atom - # put the parent of this new root at the beginning of - # the scan path - path.append(preds[k_atom_ind]) - focused_atom = k_atom_ind - - gen_to_build_atom[focused_atom] = ( - gen_to_build_atom[preds[focused_atom]] + 1 - ) - # print( - # f"gen to build {bt.atom_name(focused_atom)} from {bt.atom_name(preds[focused_atom])}", - # f"with gen {gen_to_build_atom[focused_atom]}", - # ) - while focused_atom >= 0: - path.append(focused_atom) - processed_node_into_scan_path[focused_atom] = True - focused_atom = first_descendant[focused_atom] - if focused_atom >= 0: - gen_to_build_atom[focused_atom] = gen_to_build_atom[ - preds[focused_atom] - ] - if is_on_exit_path[k_atom_ind]: - gen_scan_paths[gen_to_build_atom[k_atom_ind]].insert(0, path) - else: - gen_scan_paths[gen_to_build_atom[k_atom_ind]].append(path) - # Now we need to assemble the scan paths in a compact way: - # print("gen scan paths", gen_scan_paths) - - ij_n_gens = gen_depth[i_conn_atom] - # print("ij_n_gens", i, j, ij_n_gens) - ij_n_scans = numpy.array( - [len(gen_scan_paths[k]) for k in range(ij_n_gens)], dtype=int - ) - # print("ij_n_scans", i, j, ij_n_scans) - ij_scan_starts = [ - numpy.zeros((ij_n_scans[k],), dtype=int) for k in range(ij_n_gens) - ] - ij_scan_lengths = [ - 
numpy.array( - [len(gen_scan_paths[k][l]) for l in range(len(gen_scan_paths[k]))], - dtype=int, - ) - for k in range(ij_n_gens) - ] - # print("ij_scan_lengths", i, j, ij_scan_lengths) - for k in range(ij_n_gens): - offset = 0 - for l in range(ij_n_scans[k]): - ij_scan_starts[k][l] = offset - offset += ij_scan_lengths[k][l] - # print("ij_scan_starts", i, j, ij_scan_starts) - # print("ij_scan_lengths cumsum?", numpy.cumsum(ij_scan_lengths)) - ij_scan_is_inter_block = [ - numpy.zeros((ij_n_scans[k],), dtype=bool) for k in range(ij_n_gens) - ] - - for k in range(ij_n_gens): - for l in range(ij_n_scans[k]): - l_first_at = gen_scan_paths[k][l][0 if k == 0 else 1] - ij_scan_is_inter_block[k][l] = is_on_exit_path[l_first_at] - - # print("ij_scan_is_inter_block", ij_scan_is_inter_block) - # ij_n_nodes_for_gen = - ij_n_nodes_for_gen = numpy.array( - [ - sum(len(path) for path in gen_scan_paths[k]) - for k in range(ij_n_gens) - ], - dtype=int, - ) - # print("ij_n_nodes_for_gen", ij_n_nodes_for_gen) - scan_path_data[(i, j)] = dict( - n_gens=ij_n_gens, - n_nodes_for_gen=ij_n_nodes_for_gen, - nodes_for_generation=gen_scan_paths, - n_scans=ij_n_scans, - scan_starts=ij_scan_starts, - scan_is_inter_block=is_on_exit_path, - scan_lengths=ij_scan_lengths, - ) - # end for j - # end for i - - # Now let's count out the maximum number of generations, scans, and nodes-per-gen - # so we can create the GenerationalSegScanPaths object - max_n_gens = max( - scan_path_data[(i, j)]["n_gens"] - for i in range(n_input_types) - for j in range(n_output_types) - if (i, j) in scan_path_data - ) - max_n_scans = max( - max( - scan_path_data[(i, j)]["n_scans"][k] - for k in range(scan_path_data[(i, j)]["n_gens"]) - ) - for i in range(n_input_types) - for j in range(n_output_types) - if (i, j) in scan_path_data - ) - max_n_nodes_per_gen = max( - max( - scan_path_data[(i, j)]["n_nodes_for_gen"][k] - for k in range(scan_path_data[(i, j)]["n_gens"]) - ) - for i in range(n_input_types) - for j in 
range(n_output_types) - if (i, j) in scan_path_data - ) - bt_gen_seg_scan_paths = GenerationalSegScanPaths.empty( - n_input_types, - n_output_types, - bt.n_atoms, - max_n_gens, - max_n_scans, - max_n_nodes_per_gen, - ) - bt_gen_seg_scan_paths.parents = parents - bt_gen_seg_scan_paths.input_conn_atom = input_conn_atom - # Finally, we populate the GenerationalSegScanPaths object - for i in range(n_input_types): - for j in range(n_output_types): - if (i, j) not in scan_path_data: - continue - ij_n_gens = scan_path_data[(i, j)]["n_gens"] - bt_gen_seg_scan_paths.n_gens[i, j] = ij_n_gens - for k in range(ij_n_gens): - bt_gen_seg_scan_paths.n_nodes_for_gen[i, j, k] = scan_path_data[(i, j)][ - "n_nodes_for_gen" - ][k] - bt_gen_seg_scan_paths.n_scans[i, j, k] = scan_path_data[(i, j)][ - "n_scans" - ][k] - bt_gen_seg_scan_paths.scan_is_real[ - i, j, k, : bt_gen_seg_scan_paths.n_scans[i, j, k] - ] = True - - ijk_n_scans = scan_path_data[(i, j)]["n_scans"][k] - bt_gen_seg_scan_paths.scan_starts[i, j, k, :ijk_n_scans] = ( - scan_path_data[(i, j)]["scan_starts"][k] - ) - bt_gen_seg_scan_paths.scan_is_inter_block[i, j, k, :ijk_n_scans] = ( - scan_path_data[(i, j)]["scan_is_inter_block"][k] - ) - bt_gen_seg_scan_paths.scan_lengths[i, j, k, :ijk_n_scans] = ( - scan_path_data[(i, j)]["scan_lengths"][k] - ) - # for l in range(scan_path_data[(i, j)]["n_scans"][k]): - # bt_gen_seg_scan_paths.scan_starts[i, j, k, l] = scan_path_data[(i, j)]["scan_starts"][k][l] - # bt_gen_seg_scan_paths.scan_is_inter_block[i, j, k, l] = scan_path_data[(i, j)]["scan_is_inter_block"][k][l] - # bt_gen_seg_scan_paths.scan_lengths[i, j, k, l] = scan_path_data[(i, j)]["scan_lengths"][k][l] - for l in range(ijk_n_scans): - m_offset = scan_path_data[(i, j)]["scan_starts"][k][l] - for m in range( - len(scan_path_data[(i, j)]["nodes_for_generation"][k][l]) - ): - bt_gen_seg_scan_paths.nodes_for_gen[i, j, k, m_offset + m] = ( - scan_path_data[(i, j)]["nodes_for_generation"][k][l][m] - ) - # print("nodes for gen", 
i, j, k, bt_gen_seg_scan_paths.nodes_for_gen[i, j, k, :]) - - setattr(bt, "gen_seg_scan_paths", bt_gen_seg_scan_paths) - - def test_gen_seg_scan_paths_block_type_annotation_smoke(fresh_default_restype_set): torch_device = torch.device("cpu") - # co = default_canonical_ordering() - # pbt = default_packed_block_types(torch_device) - # canonical_form = canonical_form_from_pdb(co, ubq_pdb, torch_device) - # pose_stack = pose_stack_from_canonical_form(co, pbt, **canonical_form) - - # okay! - # 1. let's create some annotations of the packed block types bt_list = [bt for bt in fresh_default_restype_set.residue_types if bt.name == "LEU"] - - # for bt in pbt.active_block_types: for bt in bt_list: _annotate_block_type_with_gen_scan_paths(bt) def test_construct_scan_paths_n_to_c_twores(ubq_pdb): torch_device = torch.device("cpu") + device = torch_device co = default_canonical_ordering() pbt = default_packed_block_types(torch_device) @@ -545,9 +75,10 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): pose_stack = pose_stack_from_canonical_form( co, pbt, **canonical_form, res_not_connected=res_not_connected ) + _annotate_packed_block_type_with_gen_scan_paths(pbt) - for bt in pbt.active_block_types: - _annotate_block_type_with_gen_scan_paths(bt) + # for bt in pbt.active_block_types: + # _annotate_block_type_with_gen_scan_paths(bt) # now lets assume we have everything we need for the final step # of kintree construction: @@ -589,6 +120,14 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): print("parents") print(bt0gssp.parents[3]) print(bt1gssp.parents[0]) + print( + "parents in pbt, res1", + pbt.gen_seg_scan_paths.parents[pose_stack.block_type_ind[0, 0], 3], + ) + print( + "parents in pbt, res2", + pbt.gen_seg_scan_paths.parents[pose_stack.block_type_ind[0, 1], 0], + ) ij0 = [3, 1] # 3 => root "input"; Q: is this different from jump input? 
ij1 = [0, 1] @@ -735,13 +274,233 @@ def _tint(ts): kinforest, ) - print("starting coords", pose_stack.coords.view(-1, 3)[14:19]) + # print("starting coords", pose_stack.coords.view(-1, 3)[14:19]) - print("kincoords", kincoords[15:20]) - print("new coords", new_coords[15:20]) + # print("kincoords", kincoords[15:20]) + # print("new coords", new_coords[15:20]) torch.testing.assert_close(kincoords, new_coords, rtol=1e-5, atol=1e-5) + # okay: let's construct the components of the kinforest from + # the block types + + # 1. id: Tensor[torch.int32][...] + + is_bt_real = pose_stack.block_type_ind != -1 + nz_is_bt_real = torch.nonzero(is_bt_real, as_tuple=True) + n_atoms = torch.zeros_like(pose_stack.block_type_ind64) + n_atoms[is_bt_real] = pbt.n_atoms[pose_stack.block_type_ind64[is_bt_real]].to( + torch.int64 + ) + n_atoms_real_bt = n_atoms[is_bt_real] + n_atoms_total = n_atoms.sum() + + # let's imagine a variable that says for each residue + # whether it is connected to its parent by a jump, + # an N->C connection, or a C->N connection + ff_conn_to_parent = torch.full( + (pose_stack.n_poses, pose_stack.max_n_blocks), + -1, + dtype=torch.int32, + device=device, + ) + ff_conn_to_parent[0, 0] = 2 # jump + ff_conn_to_parent[0, 1] = 0 # N->C + + block_in_out = torch.full( + (pose_stack.n_poses, pose_stack.max_n_blocks, 2), + -1, + dtype=torch.int64, + device=device, + ) + block_in_out[0, 0, 0] = 3 # input from root + block_in_out[0, 0, 1] = 1 # output through upper connection + block_in_out[0, 1, 0] = 0 # input from lower connection + block_in_out[0, 1, 1] = 1 # output through upper connection + + fold_forest_parent = torch.full( + (pose_stack.n_poses, pose_stack.max_n_blocks), + -1, + dtype=torch.int32, + device=device, + ) + fold_forest_parent[0, 1] = 0 + + id = torch.concatenate( # cat? 
+ ( + torch.full((1,), -1, dtype=torch.int32, device=device), + torch.arange(n_atoms_total, dtype=torch.int32, device=device), + ) + ) + torch.testing.assert_close(id, ids_gold_t) + + # doftype: Tensor[torch.int32][...] + doftype = torch.full_like(id, NodeType.bond.value) + + # 2. parent: Tensor[torch.int32][...] + + parent = torch.full_like(id, -1, dtype=torch.int32, device=device) + + # masked-out residues and residues connected directly to the root + # don't need their parent atoms calculated + ffparent_is_real_block = fold_forest_parent != -1 + real_ffparent = fold_forest_parent[ffparent_is_real_block] + nz_block_w_real_ffparent = torch.nonzero(ffparent_is_real_block, as_tuple=True) + + per_block_type_parent = torch.full( + (pose_stack.n_poses, pose_stack.max_n_blocks, pbt.max_n_atoms), + -1, + dtype=torch.int32, + ) + per_block_type_parent[is_bt_real, :] = pbt.gen_seg_scan_paths.parents[ + pose_stack.block_type_ind64[is_bt_real], + block_in_out[is_bt_real][:, 0], + ] + print("per block type parent", per_block_type_parent) + + # atom_pose_ind = torch.arange( + # pose_stack.n_poses, dtype=torch.int32, device=device + # ).unsqueeze(-1).unsqueeze(-1).expand( + # (pose_stack.n_poses, pose_stack.max_n_blocks, pose_stack.max_n_atoms) + # ) + is_atom_real = torch.zeros( + (pose_stack.n_poses, pose_stack.max_n_blocks, pose_stack.max_n_atoms), + dtype=torch.bool, + ) + is_atom_real[is_bt_real] = pbt.atom_is_real[pose_stack.block_type_ind64[is_bt_real]] + + # atom_block_coord_offset = pose_stack.block_coord_offset.unsqueeze(-1).expand( + # (pose_stack.n_poses, pose_stack.max_n_blocks, pose_stack.max_n_atoms) + # ) + + kfo_block_offset = n_atoms.clone().flatten() + kfo_block_offset[0] += 1 # add in the virtual root + kfo_block_offset = exclusive_cumsum1d(kfo_block_offset) + kfo_block_offset[0] = 1 # adjust for the virtual root + kfo_block_offset = kfo_block_offset.view( + (pose_stack.n_poses, pose_stack.max_n_blocks) + ) + + kfo_block_offset_for_atom = 
kfo_block_offset.unsqueeze(-1).expand( + (pose_stack.n_poses, pose_stack.max_n_blocks, pose_stack.max_n_atoms) + ) + real_bt_ind_for_bt = torch.full_like( + pose_stack.block_type_ind, -1, dtype=torch.int32 + ) + real_bt_ind_for_bt[is_bt_real] = torch.arange( + is_bt_real.to(torch.int32).sum(), dtype=torch.int32, device=device + ) + + # which atom on the parent are we connected to? + # if we are connected by bond, then we can check the pose_stack's + # inter_residue_connections tensor; if we are connected by jump, + # then the parent atom is the jump atom of the parent block type + real_ffparent_block_type = pose_stack.block_type_ind64[ + nz_block_w_real_ffparent[0], real_ffparent + ] + # not so fast, tiger + # real_ffparent_conn_ind = pose_stack.inter_residue_connections[ + # nz_block_w_real_ffparent[0], nz_block_w_real_ffparent[1], block_in_out[] + # ] + is_connected_to_ffparent_w_non_jump = torch.logical_and( + ff_conn_to_parent != -1, ff_conn_to_parent != 2 + ) + nz_conn_to_ffparent_w_non_jump = torch.nonzero( + is_connected_to_ffparent_w_non_jump, as_tuple=True + ) + is_connected_to_root = ff_conn_to_parent == 2 + + is_connected_to_ffparent_w_lower_conn = torch.logical_and( + ff_conn_to_parent != -1, ff_conn_to_parent == 0 + ) + is_connected_to_ffparent_w_upper_conn = torch.logical_and( + ff_conn_to_parent != -1, ff_conn_to_parent == 1 + ) + print( + "is connected to ffparent w lower conn", is_connected_to_ffparent_w_lower_conn + ) + print( + "is connected to ffparent w upper conn", is_connected_to_ffparent_w_upper_conn + ) + + real_nonjump_ffparent = fold_forest_parent[is_connected_to_ffparent_w_non_jump] + real_nonjump_ffparent_p_block_type = pose_stack.block_type_ind64[ + nz_conn_to_ffparent_w_non_jump[0], real_nonjump_ffparent + ] + real_nonjump_ffparent_block_type = pose_stack.block_type_ind64[ + nz_block_w_real_ffparent[0], nz_block_w_real_ffparent[1] + ] + + conn_ind = torch.full_like(ff_conn_to_parent, -1, dtype=torch.int32) + 
conn_ind[is_connected_to_ffparent_w_lower_conn] = pbt.down_conn_inds[ + pose_stack.block_type_ind64[is_connected_to_ffparent_w_lower_conn] + ] + conn_ind[is_connected_to_ffparent_w_upper_conn] = pbt.up_conn_inds[ + pose_stack.block_type_ind64[is_connected_to_ffparent_w_upper_conn] + ] + print("conn ind", conn_ind) + real_nonjump_ffparent_p_conn_ind = pose_stack.inter_residue_connections[ + nz_conn_to_ffparent_w_non_jump[0], + nz_conn_to_ffparent_w_non_jump[1], + conn_ind[is_connected_to_ffparent_w_non_jump], + 1, + ] + real_nonjump_ffparent_p_conn_atom = ( + pbt.conn_atom[ + real_nonjump_ffparent_p_block_type, real_nonjump_ffparent_p_conn_ind + ] + + kfo_block_offset[nz_conn_to_ffparent_w_non_jump[0], real_nonjump_ffparent] + ) + print("real_nonjump_ffparent_p_conn_atom", real_nonjump_ffparent_p_conn_atom) + real_nonjump_ffparent_conn_atom = pbt.conn_atom[ + real_nonjump_ffparent_block_type, conn_ind[is_connected_to_ffparent_w_non_jump] + ] + atoms_connected_by_nonjump = ( + real_nonjump_ffparent_conn_atom + + kfo_block_offset[ + nz_conn_to_ffparent_w_non_jump[0], nz_conn_to_ffparent_w_non_jump[1] + ] + ) + print("atoms connected by nonjump", atoms_connected_by_nonjump) + + real_conn_to_root_conn_atom = pbt.conn_atom[ + pose_stack.block_type_ind64[is_connected_to_root], 0 + ] + + atoms_connected_to_the_root = 2 # TEMP! FIX ME!!!! 
+ print("atoms connected to the root") + + # TO DO: + # Lookup jump conn atom when connected by jump + + parent[1:] = ( + per_block_type_parent[is_atom_real] + kfo_block_offset_for_atom[is_atom_real] + ) + + parent[atoms_connected_by_nonjump] = real_nonjump_ffparent_p_conn_atom.to( + torch.int32 + ) + + # correct the roots + parent[0] = 0 + parent[atoms_connected_to_the_root] = 0 + + # okay, but we have to adjust the parent atoms for the connection + # atoms (with negative parent values) + print("parent", parent) + print("parents_gold_t", parents_gold_t) + + torch.testing.assert_close(parent, parents_gold_t) + + # # roots: Tensor[torch.int32][...] # not used in current kinforest + # frame_x: Tensor[torch.int32][...] + # frame_y: Tensor[torch.int32][...] + # frame_z: Tensor[torch.int32][...] + # (and the data members appended in get_scans) + # nodes + # scans + # gens + def test_decide_scan_paths_for_foldforest(ubq_pdb): torch_device = torch.device("cpu") From 167a555e09e59480992e01c59ad65099cdfb7c7a Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Thu, 15 Aug 2024 11:18:15 -0400 Subject: [PATCH 06/52] Fix unit tests following code shuffle Automated construction of both "id" and "parent" tensors now working properly. 
--- tmol/kinematics/datatypes.py | 15 +++++++------ tmol/kinematics/scan_ordering.py | 13 ++++++++--- ...st_create_scan_orering_from_block_types.py | 22 +++++++++++++------ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/tmol/kinematics/datatypes.py b/tmol/kinematics/datatypes.py index cac56a64c..edb6c7c37 100644 --- a/tmol/kinematics/datatypes.py +++ b/tmol/kinematics/datatypes.py @@ -1,6 +1,7 @@ import enum +import numpy import torch -import attr +import attrs from tmol.types.torch import Tensor from tmol.types.tensor import TensorGroup @@ -18,7 +19,7 @@ class NodeType(enum.IntEnum): bond = enum.auto() -@attr.s(auto_attribs=True, frozen=True) +@attrs.define(auto_attribs=True, frozen=True) class KinForest(TensorGroup, ConvertAttrs): """A collection of atom-level kinematic trees, each of which can be processed in parallel. @@ -122,7 +123,7 @@ def root_node(cls): ) -@attr.s(auto_attribs=True, slots=True, frozen=True) +@attrs.define(auto_attribs=True, slots=True, frozen=True) class KinDOF(TensorGroup, ConvertAttrs): """Internal coordinate data. 
@@ -170,7 +171,7 @@ class JumpDOFTypes(enum.IntEnum): RBgamma = enum.auto() -@attr.s(auto_attribs=True, slots=True, frozen=True) +@attrs.define(auto_attribs=True, slots=True, frozen=True) class BondDOF(TensorGroup, ConvertAttrs): """A bond dof view of KinDOF.""" @@ -193,7 +194,7 @@ def phi_c(self): return self.raw[..., BondDOFTypes.phi_c] -@attr.s(auto_attribs=True, slots=True, frozen=True) +@attrs.define(auto_attribs=True, slots=True, frozen=True) class JumpDOF(TensorGroup, ConvertAttrs): """A jump dof view of KinDOF.""" @@ -264,7 +265,7 @@ def empty( ): io = (n_input_types, n_output_types) return cls( - jump_input_atom=-1, + jump_atom=-1, parents=numpy.full( (n_input_types, n_atoms), -1, dtype=int ), # independent of primary output @@ -312,7 +313,7 @@ def empty( ): io = (n_bt, max_n_input_types, max_n_output_types) return cls( - jump_input_atom=torch.full(n_bt, -1, dtype=torch.int32, device=device), + jump_atom=torch.full((n_bt,), -1, dtype=torch.int32, device=device), parents=torch.full( (n_bt, max_n_input_types, max_n_atoms), -1, diff --git a/tmol/kinematics/scan_ordering.py b/tmol/kinematics/scan_ordering.py index 7e4df3fd9..22b574b9a 100644 --- a/tmol/kinematics/scan_ordering.py +++ b/tmol/kinematics/scan_ordering.py @@ -9,6 +9,7 @@ ) from numba import jit +from tmol.types.array import NDArray from tmol.types.torch import Tensor from tmol.types.tensor import TensorGroup from tmol.types.attrs import ConvertAttrs, ValidateAttrs @@ -30,7 +31,8 @@ from tmol.io.pose_stack_construction import pose_stack_from_canonical_form from tmol.kinematics.datatypes import NodeType from tmol.kinematics.fold_forest import EdgeType -from tmol.kinematics.scan_ordering import get_children + +# from tmol.kinematics.scan_ordering import get_children from tmol.kinematics.compiled import inverse_kin, forward_kin_op from tmol.utility.tensor.common_operations import exclusive_cumsum1d @@ -352,7 +354,7 @@ def calculate_from_kinforest(cls, kinforest: KinForest): def jump_atom_for_bt(bt): 
"""Return the index of the atom that will be jumped to or jumped from""" # TEMP: CA if CA is present; ow, atom 0 - return bt.atom_to_idx("CA") if "CA" in bt.atom_names else 0 + return bt.atom_to_idx["CA"] if "CA" in bt.atom_names_set else 0 def _annotate_block_type_with_gen_scan_paths(bt): @@ -406,7 +408,7 @@ def _bonds_to_csgraph( bond_graph = potential_bonds + prioritized_bonds bond_graph_spanning_tree = csgraph.minimum_spanning_tree(bond_graph.tocsr()) - mid_bt_atom = jump_bt_atom(bt, bond_graph_spanning_tree) + mid_bt_atom = jump_atom_for_bt(bt) is_conn_atom = numpy.zeros((bt.n_atoms,), dtype=bool) for i in range(n_conn): @@ -805,6 +807,11 @@ def _annotate_packed_block_type_with_gen_scan_paths(pbt): max_n_scans, max_n_nodes_per_gen, ) + gen_seg_scan_paths.jump_atom[:] = torch.tensor( + [bt.gen_seg_scan_paths.jump_atom for bt in pbt.active_block_types], + dtype=torch.int32, + device=pbt.device, + ) varnames = [ "parents", "input_conn_atom", diff --git a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py index 230be9bad..e75fbad67 100644 --- a/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py +++ b/tmol/tests/kinematics/test_create_scan_orering_from_block_types.py @@ -58,6 +58,7 @@ def test_gen_seg_scan_paths_block_type_annotation_smoke(fresh_default_restype_se bt_list = [bt for bt in fresh_default_restype_set.residue_types if bt.name == "LEU"] for bt in bt_list: _annotate_block_type_with_gen_scan_paths(bt) + assert hasattr(bt, "gen_seg_scan_paths") def test_construct_scan_paths_n_to_c_twores(ubq_pdb): @@ -77,6 +78,8 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): ) _annotate_packed_block_type_with_gen_scan_paths(pbt) + pbt_gssp = pbt.gen_seg_scan_paths + # for bt in pbt.active_block_types: # _annotate_block_type_with_gen_scan_paths(bt) @@ -122,11 +125,11 @@ def test_construct_scan_paths_n_to_c_twores(ubq_pdb): print(bt1gssp.parents[0]) print( "parents 
in pbt, res1", - pbt.gen_seg_scan_paths.parents[pose_stack.block_type_ind[0, 0], 3], + pbt_gssp.parents[pose_stack.block_type_ind[0, 0], 3], ) print( "parents in pbt, res2", - pbt.gen_seg_scan_paths.parents[pose_stack.block_type_ind[0, 1], 0], + pbt_gssp.parents[pose_stack.block_type_ind[0, 1], 0], ) ij0 = [3, 1] # 3 => root "input"; Q: is this different from jump input? @@ -352,7 +355,7 @@ def _tint(ts): -1, dtype=torch.int32, ) - per_block_type_parent[is_bt_real, :] = pbt.gen_seg_scan_paths.parents[ + per_block_type_parent[is_bt_real, :] = pbt_gssp.parents[ pose_stack.block_type_ind64[is_bt_real], block_in_out[is_bt_real][:, 0], ] @@ -463,11 +466,16 @@ def _tint(ts): ) print("atoms connected by nonjump", atoms_connected_by_nonjump) - real_conn_to_root_conn_atom = pbt.conn_atom[ - pose_stack.block_type_ind64[is_connected_to_root], 0 - ] + # real_conn_to_root_conn_atom = pbt.conn_atom[ + # pose_stack.block_type_ind64[is_connected_to_root], 0 + # ] + real_conn_to_root_bt = pose_stack.block_type_ind64[is_connected_to_root] + real_conn_to_root_atoms = pbt_gssp.jump_atom[real_conn_to_root_bt] + atoms_connected_to_the_root = ( + real_conn_to_root_atoms + kfo_block_offset[is_connected_to_root] + ) - atoms_connected_to_the_root = 2 # TEMP! FIX ME!!!! + # atoms_connected_to_the_root = 2 # TEMP! FIX ME!!!! 
print("atoms connected to the root") # TO DO: From 54bded258cc6ec624070749138eb18aa12c91a03 Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Tue, 10 Sep 2024 08:05:39 -0400 Subject: [PATCH 07/52] Move scan_types to its own, CUDA-independent declaration --- tmol/extern/moderngpu/cta_scan.hxx | 46 +++++++++++------------- tmol/extern/moderngpu/cta_segscan.hxx | 6 ++-- tmol/extern/moderngpu/kernel_scan.hxx | 52 +++++++++++++-------------- tmol/extern/moderngpu/scan_types.hxx | 15 ++++++++ 4 files changed, 65 insertions(+), 54 deletions(-) create mode 100644 tmol/extern/moderngpu/scan_types.hxx diff --git a/tmol/extern/moderngpu/cta_scan.hxx b/tmol/extern/moderngpu/cta_scan.hxx index f690157e7..856253d9e 100644 --- a/tmol/extern/moderngpu/cta_scan.hxx +++ b/tmol/extern/moderngpu/cta_scan.hxx @@ -2,14 +2,10 @@ #pragma once #include "loadstore.hxx" #include "intrinsics.hxx" +#include "scan_types.hxx" BEGIN_MGPU_NAMESPACE -enum scan_type_t { - scan_type_exc, - scan_type_inc -}; - template 0)> struct scan_result_t { type_t scan; @@ -32,7 +28,7 @@ struct cta_scan_t { struct { type_t threads[nt], warps[num_warps]; }; }; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 ////////////////////////////////////////////////////////////////////////////// // Optimized CTA scan code that uses warp shfl intrinsics. @@ -41,7 +37,7 @@ struct cta_scan_t { template > MGPU_DEVICE scan_result_t - scan(int tid, type_t x, storage_t& storage, int count = nt, op_t op = op_t(), + scan(int tid, type_t x, storage_t& storage, int count = nt, op_t op = op_t(), type_t init = type_t(), scan_type_t type = scan_type_exc) const { int warp = tid / warp_size; @@ -61,7 +57,7 @@ struct cta_scan_t { __syncthreads(); // Scan the warp reductions. 
- if(tid < num_warps) { + if(tid < num_warps) { type_t cta_scan = storage.warps[tid]; iterate([&](int pass) { cta_scan = shfl_up_op(cta_scan, 1<< pass, op, num_warps); @@ -78,10 +74,10 @@ struct cta_scan_t { if(warp > 0) scan = op(scan, storage.warps[warp - 1]); type_t reduction = storage.warps[div_up(count, warp_size) - 1]; - - scan_result_t result { - tid < count ? scan : reduction, - reduction + + scan_result_t result { + tid < count ? scan : reduction, + reduction }; __syncthreads(); @@ -91,11 +87,11 @@ struct cta_scan_t { #else ////////////////////////////////////////////////////////////////////////////// - // Standard CTA scan code that does not use shfl intrinsics. + // Standard CTA scan code that does not use shfl intrinsics. template > - MGPU_DEVICE scan_result_t - scan(int tid, type_t x, storage_t& storage, int count = nt, op_t op = op_t(), + MGPU_DEVICE scan_result_t + scan(int tid, type_t x, storage_t& storage, int count = nt, op_t op = op_t(), type_t init = type_t(), scan_type_t type = scan_type_exc) const { int first = 0; @@ -113,7 +109,7 @@ struct cta_scan_t { scan_result_t result; result.reduction = storage.data[first + count - 1]; - result.scan = (tid < count) ? + result.scan = (tid < count) ? (scan_type_inc == type ? x : (tid ? storage.data[first + tid - 1] : init)) : result.reduction; @@ -122,16 +118,16 @@ struct cta_scan_t { return result; } -#endif +#endif ////////////////////////////////////////////////////////////////////////////// - // CTA vectorized scan. Accepts multiple values per thread and adds in + // CTA vectorized scan. Accepts multiple values per thread and adds in // optional global carry-in. 
template > MGPU_DEVICE scan_result_t - scan(int tid, array_t x, storage_t& storage, - type_t carry_in = type_t(), bool use_carry_in = false, + scan(int tid, array_t x, storage_t& storage, + type_t carry_in = type_t(), bool use_carry_in = false, int count = nt, op_t op = op_t(), type_t init = type_t(), scan_type_t type = scan_type_exc) const { @@ -143,14 +139,14 @@ struct cta_scan_t { } else { iterate([&](int i) { int index = vt * tid + i; - x[i] = i ? + x[i] = i ? ((index < count) ? op(x[i], x[i - 1]) : x[i - 1]) : (x[i] = (index < count) ? x[i] : init); }); } // Scan the thread-local reductions for a carry-in for each thread. - scan_result_t result = scan(tid, x[vt - 1], storage, + scan_result_t result = scan(tid, x[vt - 1], storage, div_up(count, vt), op, init, scan_type_exc); // Perform the scan downsweep and add both the global carry-in and the @@ -185,7 +181,7 @@ struct cta_scan_t { int warps[num_warps]; }; - MGPU_DEVICE scan_result_t scan(int tid, bool x, + MGPU_DEVICE scan_result_t scan(int tid, bool x, storage_t& storage) const { // Store the bit totals for each warp. @@ -207,7 +203,7 @@ struct cta_scan_t { } __syncthreads(); #else - + if(0 == tid) { // Inclusive scan of partial reductions.. int scan = 0; @@ -217,7 +213,7 @@ struct cta_scan_t { } __syncthreads(); -#endif +#endif int scan = ((warp > 0) ? 
storage.warps[warp - 1] : 0) + popc(bfe(bits, 0, lane)); diff --git a/tmol/extern/moderngpu/cta_segscan.hxx b/tmol/extern/moderngpu/cta_segscan.hxx index f27c26545..dd960afbe 100644 --- a/tmol/extern/moderngpu/cta_segscan.hxx +++ b/tmol/extern/moderngpu/cta_segscan.hxx @@ -18,11 +18,11 @@ struct cta_segscan_t { enum { num_warps = nt / warp_size }; union storage_t { - int delta[num_warps + nt]; + int delta[num_warps + nt]; struct { type_t values[2 * nt]; int packed[nt]; }; }; - MGPU_DEVICE int find_left_lane(int tid, bool has_head_flag, + MGPU_DEVICE int find_left_lane(int tid, bool has_head_flag, storage_t& storage) const { int warp = tid / warp_size; @@ -93,7 +93,7 @@ struct cta_segscan_t { // the carry-out value as the total. bool has_carry_in = tid ? (0 != (1 & storage.packed[tid - 1])) : false; - segscan_result_t result { + segscan_result_t result { (has_carry_in && tid) ? storage.values[first + tid - 1] : init, storage.values[first + nt - 1], has_carry_in, diff --git a/tmol/extern/moderngpu/kernel_scan.hxx b/tmol/extern/moderngpu/kernel_scan.hxx index b5f308599..988e9bdab 100644 --- a/tmol/extern/moderngpu/kernel_scan.hxx +++ b/tmol/extern/moderngpu/kernel_scan.hxx @@ -8,13 +8,13 @@ BEGIN_MGPU_NAMESPACE -template -void scan_event(input_it input, int count, output_it output, op_t op, +void scan_event(input_it input, int count, output_it output, op_t op, reduction_it reduction, context_t& context, cudaEvent_t event) { - typedef typename conditional_typedef_t, arch_35_cta<128, 7>, @@ -54,7 +54,7 @@ void scan_event(input_it input, int count, output_it output, op_t op, }, tid, tile.count()); // Reduce across all threads. - type_t all_reduce = reduce_t().reduce(tid, scalar, shared.reduce, + type_t all_reduce = reduce_t().reduce(tid, scalar, shared.reduce, tile.count(), op); // Store the final reduction to the partials. 
@@ -69,7 +69,7 @@ void scan_event(input_it input, int count, output_it output, op_t op, scan_event(partials_data, num_ctas, partials_data, op, reduction, context, event); - // Record the event. This lets the caller wait on just the reduction + // Record the event. This lets the caller wait on just the reduction // part of the operation. It's useful when writing the reduction to // host-side paged-locked memory; the caller can read out the value more // quickly to allocate memory and launch the next kernel. @@ -77,7 +77,7 @@ void scan_event(input_it input, int count, output_it output, op_t op, cudaEventRecord(event, context.stream()); //////////////////////////////////////////////////////////////////////////// - // Downsweep phase. Perform an intra-tile scan and add the scan of the + // Downsweep phase. Perform an intra-tile scan and add the scan of the // partials as carry-in. auto downsweep_k = [=] MGPU_DEVICE(int tid, int cta) { @@ -92,20 +92,20 @@ void scan_event(input_it input, int count, output_it output, op_t op, // Load a tile to register in thread order. range_t tile = get_tile(cta, nv, count); - array_t x = mem_to_reg_thread(input + tile.begin, + array_t x = mem_to_reg_thread(input + tile.begin, tid, tile.count(), shared.values); // Scan the array with carry-in from the partials. - array_t y = scan_t().scan(tid, x, shared.scan, - partials_data[cta], cta > 0, tile.count(), op, type_t(), + array_t y = scan_t().scan(tid, x, shared.scan, + partials_data[cta], cta > 0, tile.count(), op, type_t(), scan_type).scan; // Store the scanned values to the output. 
- reg_to_mem_thread(y, tid, tile.count(), output + tile.begin, - shared.values); + reg_to_mem_thread(y, tid, tile.count(), output + tile.begin, + shared.values); }; cta_transform(downsweep_k, count, context); - + } else { //////////////////////////////////////////////////////////////////////////// @@ -113,7 +113,7 @@ void scan_event(input_it input, int count, output_it output, op_t op, typedef launch_params_t<512, 3> spine_params_t; auto spine_k = [=] MGPU_DEVICE(int tid, int cta) { - + enum { nt = spine_params_t::nt, vt = spine_params_t::vt, nv = nt * vt }; typedef cta_scan_t scan_t; @@ -126,16 +126,16 @@ void scan_event(input_it input, int count, output_it output, op_t op, for(int cur = 0; cur < count; cur += nv) { // Cooperatively load values into register. int count2 = min(count - cur, nv); - array_t x = mem_to_reg_thread(input + cur, + array_t x = mem_to_reg_thread(input + cur, tid, count2, shared.values); scan_result_t result = scan_t().scan(tid, x, shared.scan, carry_in, cur > 0, count2, op, type_t(), scan_type); // Store the scanned values back to global memory. - reg_to_mem_thread(result.scan, tid, count2, + reg_to_mem_thread(result.scan, tid, count2, output + cur, shared.values); - + // Roll the reduction into carry_in. carry_in = result.reduction; } @@ -147,7 +147,7 @@ void scan_event(input_it input, int count, output_it output, op_t op, }; cta_launch(spine_k, 1, context); - // Record the event. This lets the caller wait on just the reduction + // Record the event. This lets the caller wait on just the reduction // part of the operation. It's useful when writing the reduction to // host-side paged-locked memory; the caller can read out the value more // quickly to allocate memory and launch the next kernel. 
@@ -156,17 +156,17 @@ void scan_event(input_it input, int count, output_it output, op_t op, } } -template -void scan(input_it input, int count, output_it output, op_t op, +void scan(input_it input, int count, output_it output, op_t op, reduction_it reduction, context_t& context) { - return scan_event(input, count, output, op, + return scan_event(input, count, output, op, reduction, context, 0); } -template void scan(input_it input, int count, output_it output, context_t& context) { @@ -175,7 +175,7 @@ void scan(input_it input, int count, output_it output, context_t& context) { discard_iterator_t(), context); } -template void transform_scan_event(func_t f, int count, output_it output, op_t op, @@ -185,7 +185,7 @@ void transform_scan_event(func_t f, int count, output_it output, op_t op, count, output, op, reduction, context, event); } -template void transform_scan(func_t f, int count, output_it output, op_t op, diff --git a/tmol/extern/moderngpu/scan_types.hxx b/tmol/extern/moderngpu/scan_types.hxx new file mode 100644 index 000000000..85fc31a25 --- /dev/null +++ b/tmol/extern/moderngpu/scan_types.hxx @@ -0,0 +1,15 @@ +#pragma once + +// For mgpu namespace macros +#include "meta.hxx" + +BEGIN_MGPU_NAMESPACE + +// Types for scan operations that are CPU-compatible. 
+ +enum scan_type_t { + scan_type_exc, + scan_type_inc +}; + +END_MGPU_NAMESPACE From 43ec5e4ee55f2ffd0b2783d6d8ae993e0f42c188 Mon Sep 17 00:00:00 2001 From: Andrew Leaver-Fay Date: Tue, 10 Sep 2024 08:06:57 -0400 Subject: [PATCH 08/52] Add C++ implementation of fix-jump-nodes --- tmol/kinematics/compiled/common.hh | 173 ++++++++++++++++++ tmol/kinematics/compiled/common_dispatch.hh | 13 ++ tmol/kinematics/compiled/compiled.cpu.cpp | 5 + tmol/kinematics/compiled/compiled_ops.cpp | 25 +++ tmol/kinematics/compiled/compiled_ops.py | 1 + tmol/score/common/accumulate.hh | 38 ++-- .../common/device_operations.cpu.impl.hh | 14 ++ .../common/device_operations.cuda.impl.cuh | 10 + tmol/score/common/device_operations.hh | 6 + ...st_create_scan_orering_from_block_types.py | 30 ++- tmol/tests/kinematics/test_gpu_operations.py | 88 +++++++++ tmol/tests/kinematics/test_script_modules.py | 7 + 12 files changed, 386 insertions(+), 24 deletions(-) diff --git a/tmol/kinematics/compiled/common.hh b/tmol/kinematics/compiled/common.hh index 32459ba9e..704c9f2bc 100644 --- a/tmol/kinematics/compiled/common.hh +++ b/tmol/kinematics/compiled/common.hh @@ -347,6 +347,179 @@ struct common { } }; +// @numba.jit(nopython=True) +// def get_c1_and_c2_atoms( +// jump_atom: int, +// atom_is_jump: NDArray[int][:], +// child_list_span: NDArray[int][:], +// child_list: NDArray[int][:], +// parents: NDArray[int][:], +// ) -> tuple: +// """Preferably a jump should steal DOFs from its first (nonjump) child +// and its first (nonjump) grandchild, but if the first child does not +// have any children, then it can steal a DOF from its second (nonjump) +// child. If a jump does not have a sufficient number of descendants, then +// we must recurse to its parent. 
+// """ + +// first_nonjump_child = -1 +// second_nonjump_child = -1 +// for child_ind in range( +// child_list_span[jump_atom, 0], child_list_span[jump_atom, 1] +// ): +// child_atom = child_list[child_ind] +// if atom_is_jump[child_atom]: +// continue +// if first_nonjump_child == -1: +// first_nonjump_child = child_atom +// else: +// second_nonjump_child = child_atom +// break + +// if first_nonjump_child == -1: +// jump_parent = parents[jump_atom] +// assert jump_parent != jump_atom +// return get_c1_and_c2_atoms( +// jump_parent, atom_is_jump, child_list_span, child_list, parents +// ) + +// for grandchild_ind in range( +// child_list_span[first_nonjump_child, 0], +// child_list_span[first_nonjump_child, 1] +// ): +// grandchild_atom = child_list[grandchild_ind] +// if not atom_is_jump[grandchild_atom]: +// return first_nonjump_child, grandchild_atom + +// if second_nonjump_child == -1: +// jump_parent = parents[jump_atom] +// assert jump_parent != jump_atom +// return get_c1_and_c2_atoms( +// jump_parent, atom_is_jump, child_list_span, child_list, parents +// ) + +// return first_nonjump_child, second_nonjump_child + +// @numba.jit(nopython=True) +// def fix_jump_nodes( +// parents: NDArray[int][:], +// frame_x: NDArray[int][:], +// frame_y: NDArray[int][:], +// frame_z: NDArray[int][:], +// roots: NDArray[int][:], +// jumps: NDArray[int][:], +// ): +// # nelts = parents.shape[0] +// n_children, child_list_span, child_list = get_children(parents) + +// atom_is_jump = numpy.full(parents.shape, 0, dtype=numpy.int32) +// atom_is_jump[roots] = 1 +// atom_is_jump[jumps] = 1 + +// for root in roots: +// assert stub_defined_for_jump_atom( +// root, atom_is_jump, child_list_span, child_list +// ) + +// root_c1, second_descendent = get_c1_and_c2_atoms( +// root, atom_is_jump, child_list_span, child_list, parents +// ) + +// # set the frame_x, _y, and _z to the same values for both the root +// # and the root's first child + +// frame_x[root] = root_c1 +// 
frame_y[root] = root +// frame_z[root] = second_descendent + +// frame_x[root_c1] = root_c1 +// frame_y[root_c1] = root +// frame_z[root_c1] = second_descendent + +// # all the other children of the root need an updated kinematic +// description for child_ind in range(child_list_span[root, 0] + 1, +// child_list_span[root, 1]): +// child = child_list[child_ind] +// if atom_is_jump[child]: +// continue +// if child == root_c1: +// continue +// frame_x[child] = child +// frame_y[child] = root +// frame_z[child] = root_c1 + +// for jump in jumps: +// if stub_defined_for_jump_atom(jump, atom_is_jump, child_list_span, +// child_list): +// jump_c1, jump_c2 = get_c1_and_c2_atoms( +// jump, atom_is_jump, child_list_span, child_list, parents +// ) + +// # set the frame_x, _y, and _z to the same values for both the +// jump # and the jump's first child + +// frame_x[jump] = jump_c1 +// frame_y[jump] = jump +// frame_z[jump] = jump_c2 + +// frame_x[jump_c1] = jump_c1 +// frame_y[jump_c1] = jump +// frame_z[jump_c1] = jump_c2 + +// # all the other children of the jump need an updated kinematic +// description for child_ind in range( +// child_list_span[jump, 0] + 1, child_list_span[jump, 1] +// ): +// child = child_list[child_ind] +// if atom_is_jump[child]: +// continue +// if child == jump_c1: +// continue +// frame_x[child] = child +// frame_y[child] = jump +// frame_z[child] = jump_c1 +// else: +// # ok, so... I don't understand the atom tree well enough to +// understand this # situation. If the jump has no non-jump +// children, then certainly none # of them need their frame +// definitions updated c1, c2 = get_c1_and_c2_atoms( +// parents[jump], atom_is_jump, child_list_span, child_list, +// parents +// ) + +// frame_x[jump] = c1 +// frame_y[jump] = jump +// frame_z[jump] = c2 + +// # the jump may have one child; it's not entirely clear to me +// # what frame the child should have! 
+// # TO DO: figure this out +// for child_ind in range( +// child_list_span[jump, 0] + 1, child_list_span[jump, 1] +// ): +// child = child_list[child_ind] +// if atom_is_jump[child]: +// continue +// frame_x[child] = c1 +// frame_y[child] = jump +// frame_z[child] = c2 + +template +void get_c1_and_c2_atoms( + int jump_atom, + TView atom_is_jump, + TView child_list_span, + TView child_list, + TView parents) { + // Preferably a jump should steal DOFs from its first (nonjump) child + // and its first (nonjump) grandchild, but if the first child does not + // have any children, then it can steal a DOF from its second (nonjump) + // child. If a jump does not have a sufficient number of descendants, then + // we must recurse to its parent. + + // TO DO! +} + #undef Dofs #undef HomogeneousTransform #undef QuatTranslation diff --git a/tmol/kinematics/compiled/common_dispatch.hh b/tmol/kinematics/compiled/common_dispatch.hh index 607871c4e..da0a569d9 100644 --- a/tmol/kinematics/compiled/common_dispatch.hh +++ b/tmol/kinematics/compiled/common_dispatch.hh @@ -64,6 +64,19 @@ struct KinDerivDispatch { TView, 1, D> kintree) -> TPack; }; +// +// +template