Skip to main content

ci_core/utils/
partition_indices.rs

1use ndarray::{Array2, Axis};
2use ordered_float::OrderedFloat;
3use std::collections::HashMap;
4
5/// Partition the rows of `data` (indexed as `data[row][col]`) into groups
6/// that share the same combination of column values.
7///
8/// Returns `indices[partition][i]`, where each inner `Vec` holds the row
9/// indices that belong to that partition. The order of partitions and the
10/// order of indices within a partition are both unspecified.
11#[must_use]
12pub fn partition_indices(data: &Array2<f64>) -> Vec<Vec<usize>> {
13    let mut groups: HashMap<Vec<OrderedFloat<f64>>, Vec<usize>> = HashMap::new();
14    for (i, row) in data.axis_iter(Axis(0)).enumerate() {
15        let key: Vec<OrderedFloat<f64>> = row.iter().map(|&v| OrderedFloat(v)).collect();
16        groups.entry(key).or_default().push(i);
17    }
18    groups.into_values().collect()
19}
20
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Runs `partition_indices` and returns the result in canonical form:
    /// indices sorted inside every partition, and partitions sorted among
    /// themselves. This lets each test assert the exact membership of every
    /// partition without depending on the order in which partitions or
    /// indices are returned.
    fn sorted_partitions(data: &Array2<f64>) -> Vec<Vec<usize>> {
        let mut partitions = partition_indices(data);
        for group in &mut partitions {
            group.sort_unstable();
        }
        partitions.sort_unstable();
        partitions
    }

    #[test]
    fn simple_grouping() {
        let data = array![[1.0, 10.0], [1.0, 10.0], [2.0, 30.0],];

        assert_eq!(sorted_partitions(&data), vec![vec![0, 1], vec![2]]);
    }

    #[test]
    fn test_empty_array() {
        // No rows at all: there is nothing to partition.
        let data: Array2<f64> = Array2::zeros((0, 2));

        assert!(partition_indices(&data).is_empty());
    }

    #[test]
    fn test_singleton_groups() {
        let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0],];

        assert_eq!(sorted_partitions(&data), vec![vec![0], vec![1], vec![2]]);
    }

    #[test]
    fn test_single_partition() {
        let data = array![[1.0, 10.0], [1.0, 10.0], [1.0, 10.0],];

        assert_eq!(sorted_partitions(&data), vec![vec![0, 1, 2]]);
    }

    #[test]
    fn float_rounding() {
        // Values are compared exactly: a tiny difference must not be rounded
        // away, so these two rows stay in separate partitions.
        let data = array![[1.0, 10.0], [1.000_000_000_1, 10.0],];

        assert_eq!(sorted_partitions(&data), vec![vec![0], vec![1]]);
    }

    #[test]
    fn multiple_columns() {
        let data = array![
            [1.0, 2.0, 10.0],
            [1.0, 2.0, 10.0],
            [1.0, 3.0, 30.0],
            [2.0, 2.0, 40.0],
        ];

        // Only the two identical `[1.0, 2.0, 10.0]` rows share a partition.
        assert_eq!(
            sorted_partitions(&data),
            vec![vec![0, 1], vec![2], vec![3]]
        );
    }

    #[test]
    fn order_dependence() {
        // The same values in a different column order are different keys.
        let data = array![[1.0, 2.0], [2.0, 1.0], [2.0, 1.0], [1.0, 2.0]];

        assert_eq!(sorted_partitions(&data), vec![vec![0, 3], vec![1, 2]]);
    }

    #[test]
    fn zero_column_rows() {
        // Rows without columns all share the same (empty) key.
        let data: Array2<f64> = Array2::zeros((3, 0));

        assert_eq!(sorted_partitions(&data), vec![vec![0, 1, 2]]);
    }

    #[test]
    fn single_column_rows() {
        let data = array![[1.0], [2.0], [1.0]];

        assert_eq!(sorted_partitions(&data), vec![vec![0, 2], vec![1]]);
    }

    #[test]
    fn negative_values() {
        let data = array![
            [-1.0, 2.0],
            [0.0, -0.0],
            [0.0, 0.0],
            [-1.0, 2.0],
            [-1.0, -2.0],
        ];

        // `-0.0` and `0.0` are expected to fall into the same partition.
        assert_eq!(
            sorted_partitions(&data),
            vec![vec![0, 3], vec![1, 2], vec![4]]
        );
    }

    #[test]
    fn nan_values_group_together() {
        // `OrderedFloat` treats NaN as equal to NaN, so rows with NaN in the
        // same column are grouped; the column position still matters.
        let data = array![[f64::NAN, 1.0], [f64::NAN, 1.0], [1.0, f64::NAN]];

        assert_eq!(sorted_partitions(&data), vec![vec![0, 1], vec![2]]);
    }
}