Skip to content

Commit

Permalink
feat(edges): update_edge_lengths and update_all_edge_lengths added (T…
Browse files Browse the repository at this point in the history
…hanks @FinnOD!)
  • Loading branch information
aaronmussig committed Jul 16, 2024
2 parents aaea7e5 + 2925321 commit bed49be
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,7 @@ wheelhouse/

# Node modules
node_modules/

# venv
bin/*
pyvenv.cfg
6 changes: 6 additions & 0 deletions python/phylodm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def add_edge(self, parent_id: int, child_id: int, length: float):
"""
return self._rs.add_edge(parent_id=parent_id, child_id=child_id, length=length)

def update_edge_lengths(self, child_nodes: np.ndarray, new_edge_lengths: np.ndarray):
    """Set the length of the edge above each of the given child nodes.

    ``child_nodes[i]`` receives ``new_edge_lengths[i]``.

    Args:
        child_nodes: Integer node indexes whose parent edge is updated.
        new_edge_lengths: The new edge lengths, one per entry of ``child_nodes``.

    Raises:
        ValueError: If ``child_nodes`` holds non-integer or negative values,
            or if the Rust extension rejects the update.
    """
    nodes = np.asarray(child_nodes)
    if nodes.size > 0 and nodes.dtype.kind not in 'iu':
        raise ValueError('child_nodes must contain integer node indexes.')
    if nodes.dtype.kind == 'i' and (nodes < 0).any():
        raise ValueError('child_nodes must not contain negative node indexes.')

    # The extension only accepts contiguous arrays of exactly usize / f64, so
    # lists, default int64 arrays and strided views are normalised here.
    nodes = np.ascontiguousarray(nodes, dtype=np.uintp)
    lengths = np.ascontiguousarray(new_edge_lengths, dtype=np.float64)
    return self._rs.update_edge_lengths(child_nodes=nodes, lengths=lengths)

def update_all_edge_lengths(self, length: float):
    """Set every edge in the tree to the same length.

    Args:
        length: The value forwarded to the Rust extension as the new length.
    """
    rs_tree = self._rs
    return rs_tree.update_all_edge_lengths(length=length)

def get_nodes(self) -> List[int]:
"""Return all node indexes in the tree."""
return self._rs.get_nodes()
Expand Down
69 changes: 64 additions & 5 deletions src/pdm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ impl PDM {
}
Ok(out)
}

/// Collect the id of every node held by the tree, in the order the nodes are stored.
#[must_use]
pub fn node_ids(&self) -> Vec<NodeId> {
    self.nodes.iter().map(|node| node.id).collect()
}

/// Return the sum of all branches in the tree.
#[must_use]
Expand Down Expand Up @@ -160,8 +170,51 @@ impl PDM {
self.get_node_mut(child).set_parent(parent, length);
}

/// Set the depth of each node in the tree.
pub fn assign_node_depth(&mut self) -> Result<(), PhyloErr> {
/// Update the parent-edge length of selected nodes.
///
/// `child_nodes[i]` is given the length `lengths[i]`. All inputs are validated
/// before any node is modified, so a rejected call leaves the tree untouched.
///
/// # Arguments
///
/// * `child_nodes`: - Slice of `NodeId`s whose edge to their parent is updated
/// * `lengths`: - Slice of `Edge`s, one per entry of `child_nodes`
///
/// # Errors
///
/// Returns a `PhyloErr` if the root node cannot be determined, if the two
/// slices differ in length, if `child_nodes` contains the root node, or if
/// recomputing the row vector fails.
///
/// NOTE(review): ids are passed straight to `get_node_mut`; how it behaves for
/// an id that is not in the tree is not visible here.
pub fn update_edge_lengths(&mut self, child_nodes: &[NodeId], lengths: &[Edge]) -> Result<(), PhyloErr> {
    // The root has no parent, so it can never be the child end of an edge.
    let root_node_id = self.root_node()?;

    // Check the slices are the same length
    if child_nodes.len() != lengths.len() {
        return Err(PhyloErr("Lengths vector does not match the number of nodes!".to_string()));
    }

    // Reject the root up front so an error never leaves a half-updated tree.
    if child_nodes.contains(&root_node_id) {
        return Err(PhyloErr("Root node cannot have an edge length!".to_string()));
    }

    // Update the values
    for (&child_node_id, &length) in child_nodes.iter().zip(lengths) {
        self.get_node_mut(child_node_id).set_parent_distance(length);
    }

    self.compute_row_vec()?; // For distance matrix calculation
    Ok(())
}

/// Give every edge in the tree the same length.
///
/// Every node except the root has its parent distance set to `length`, then
/// the row vector used for the distance matrix is recomputed.
pub fn update_all_edge_lengths(&mut self, length: Edge) -> Result<(), PhyloErr> {
    let root = self.root_node()?;

    // `node_ids` hands back an owned list, so nodes can be mutated while walking it.
    for id in self.node_ids().into_iter().filter(|&id| id != root) {
        self.get_node_mut(id).set_parent_distance(length);
    }

    self.compute_row_vec()?; // For distance matrix calculation
    Ok(())
}

/// Return the root node of the tree.
pub fn root_node(&self) -> Result<NodeId, PhyloErr> {
// Iterate over each node to make sure there is only one root node.
let mut root = None;
for node in &self.nodes {
Expand All @@ -172,13 +225,19 @@ impl PDM {
root = Some(node.id);
}
}

if root.is_none() {
return Err(PhyloErr("No root node detected!".to_string()));
}
Ok(root.unwrap())
}

/// Set the depth of each node in the tree.
///
/// # Errors
///
/// Returns a `PhyloErr` if `root_node` cannot determine the root, or if the
/// depth-first traversal fails.
pub fn assign_node_depth(&mut self) -> Result<(), PhyloErr> {
    // `root_node` already yields a plain `NodeId`, so there is nothing to unwrap.
    let root = self.root_node()?;

    // Set the depth of all nodes
    self.set_node_depth_dfs(root)?;
    Ok(())
}

Expand Down Expand Up @@ -237,7 +296,7 @@ impl PDM {

/// Wrapper method to calculate the pairwise distances at a given depth.
pub fn calculate_distances_at_depth(&mut self, depth: NodeDepth, row_vec: &mut [f64]) -> Result<(), PhyloErr> {
// Iterate over all nodes a this depth
// Iterate over all nodes at this depth
for &node_id in &self.get_node_idxs_at_depth(depth)? {
let node = self.get_node(node_id);

Expand Down
27 changes: 26 additions & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use numpy::{PyArray2, ToPyArray};
use numpy::{PyArray1, PyArray2, PyArrayMethods, ToPyArray};
use pyo3::{Py, pyclass, pymethods, pymodule, PyResult, Python, types::PyModule, Bound};
use pyo3::exceptions::PyValueError;

Expand Down Expand Up @@ -46,6 +46,31 @@ impl PhyloDM {
);
}

pub fn update_edge_lengths(&mut self, child_nodes: &Bound<'_, PyArray1<usize>>, lengths: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {

let binding = lengths.to_vec().unwrap();
let new_lengths_vec: Vec<Edge> = binding.iter().map(|x| Edge(*x)).collect();

let child_nodes_binding = child_nodes.to_vec().unwrap();
let child_nodes_vec: Vec<NodeId> = child_nodes_binding.iter().map(|x| NodeId(*x)).collect();

let result = self.tree.update_edge_lengths(&child_nodes_vec, &new_lengths_vec);

if result.is_err() {
return Err(PyValueError::new_err("Unable to update edge lengths."));
}

Ok(())
}

/// Set every edge in the tree to `length`.
///
/// Raises `ValueError` if the tree rejects the update.
pub fn update_all_edge_lengths(&mut self, length: f64) -> PyResult<()> {
    self.tree
        .update_all_edge_lengths(Edge(length))
        .map_err(|_| PyValueError::new_err("Unable to update all edge lengths."))
}

pub fn get_nodes(&self) -> Vec<usize> {
let mut out = Vec::new();
for node in &self.tree.nodes {
Expand Down
6 changes: 6 additions & 0 deletions src/tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ impl Node {
self.parent_distance = Some(length);
}

/// Set the distance from this node to its parent.
///
/// Stores `length` as `Some(length)` in `parent_distance`; no other field is
/// touched.
///
/// NOTE(review): nothing here checks whether this node is the root. The open
/// question "can the root have a distance?" is not enforced at this level, so
/// callers must decide.
pub fn set_parent_distance(&mut self, length: Edge) {
    self.parent_distance = Some(length);
}

/// Set the depth of this node.
pub fn set_depth(&mut self, depth: NodeDepth) {
self.depth = Some(depth);
Expand Down
26 changes: 26 additions & 0 deletions test/test_phylodm.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,32 @@ def test_load_from_dendropy_with_trifurication(self):
self.assertTrue(test_tree['taxa'] == tuple(pdm.taxa()))
return

def test_tree_set_branch_lengths(self):
    """Overwriting every edge length is reflected in the distance matrix."""
    test_tree = get_test_tree(10, trifurication=True)
    pdm = PhyloDM.load_from_dendropy(test_tree['tree'])
    dm = pdm.dm(norm=False)

    # With unit edges each distance is a count of edges, so every entry is a
    # whole number and two different taxa are never zero apart.
    pdm.update_all_edge_lengths(1.0)
    dm_ones = pdm.dm(norm=False)
    self.assertEqual(dm.shape, dm_ones.shape)
    off_diagonal = ~np.eye(dm_ones.shape[0], dtype=bool)
    self.assertTrue(np.all(dm_ones[off_diagonal] > 0))
    self.assertTrue(np.array_equal(dm_ones, np.round(dm_ones)))

    # With zero-length edges every path has length zero.
    pdm.update_all_edge_lengths(0.0)
    dm_zeros = pdm.dm(norm=False)
    self.assertEqual(dm.shape, dm_zeros.shape)
    self.assertEqual(np.count_nonzero(dm_zeros), 0)


# def test_tree_with_bootstraps_from_newick(self):
# with tempfile.TemporaryDirectory() as tmp_dir:
# tmp_dir = Path(tmp_dir)
Expand Down
1 change: 1 addition & 0 deletions tests/test_bl_7.tree
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
((T8:7,T2:7):7,((T9:7,(T10:7,T4:7):7):7,((T6:7,((T3:7,T1:7):7,T7:7):7):7,T5:7):7):7):7;
81 changes: 80 additions & 1 deletion tests/tree.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(test)]
mod tests {
use phylodm::PDM;
use phylodm::tree::Taxon;
use phylodm::tree::{Taxon, Edge};

#[test]
fn test_tree_dm_twice() {
Expand Down Expand Up @@ -169,4 +169,83 @@ mod tests {
]
);
}

#[test]
fn test_set_lengths() {
    // Load the default tree and compute its matrix once, so the later
    // recomputation happens on a tree that already holds cached results.
    let mut tree_normal = PDM::default();
    tree_normal
        .load_from_newick_path("tests/test.tree")
        .expect("tests/test.tree should load");
    let (_taxon, _arr_normal) = tree_normal.matrix(false).unwrap();

    // Reference: the same tree written to file with every branch length = 7.
    let mut tree_bl7 = PDM::default();
    tree_bl7
        .load_from_newick_path("tests/test_bl_7.tree")
        .expect("tests/test_bl_7.tree should load");
    let (_taxon, arr_bl7) = tree_bl7.matrix(false).unwrap();
    let tree_bl7_length = tree_bl7.length();

    // Overwrite every branch length of the default tree with 7.
    tree_normal
        .update_all_edge_lengths(Edge(7.0))
        .expect("updating all edge lengths should succeed");
    let (_taxon, arr_modified_7) = tree_normal.matrix(false).unwrap();
    let tree_modified_7_length: Edge = tree_normal.length();

    // The modified tree must match the one read from file, both in its
    // pairwise distances and in its total branch length.
    assert_eq!(arr_bl7, arr_modified_7);
    assert_eq!(tree_bl7_length.0, tree_modified_7_length.0);
}


#[test]
fn test_update_edge_lengths() {
    let mut tree = PDM::default();

    let taxon_b = Taxon("B".to_string());
    let taxon_c = Taxon("C".to_string());
    let taxon_d = Taxon("D".to_string());

    // root -> A (1.0), root -> B (2.0), A -> C (3.0), A -> D (7.0)
    let root_node = tree.add_node(None).unwrap();
    let node_a = tree.add_node(None).unwrap();
    let node_b = tree.add_node(Some(&taxon_b)).unwrap();
    let node_c = tree.add_node(Some(&taxon_c)).unwrap();
    let node_d = tree.add_node(Some(&taxon_d)).unwrap();

    tree.add_edge(root_node, node_a, Edge(1.0));
    tree.add_edge(root_node, node_b, Edge(2.0));
    tree.add_edge(node_a, node_c, Edge(3.0));
    tree.add_edge(node_a, node_d, Edge(7.0));

    tree.compute_row_vec().unwrap();

    // B -> root -> A -> C = 2 + 1 + 3
    let dist_b_to_c_before = tree.distance(&taxon_b, &taxon_c, false);
    assert_eq!(dist_b_to_c_before, 6.0);

    // Re-weight root -> A and A -> C; the update must not fail silently.
    let node_ids = vec![node_a, node_c];
    let lengths = vec![Edge(11.0), Edge(12.0)];
    tree.update_edge_lengths(&node_ids, &lengths).unwrap();

    // B -> root -> A -> C = 2 + 11 + 12
    let dist_b_to_c_after = tree.distance(&taxon_b, &taxon_c, false);
    assert_eq!(dist_b_to_c_after, 25.0);

    // The root has no parent edge, and the two slices must line up.
    assert!(tree.update_edge_lengths(&[root_node], &[Edge(1.0)]).is_err());
    assert!(tree.update_edge_lengths(&[node_a], &[]).is_err());
}


}

0 comments on commit bed49be

Please sign in to comment.