Skip to main content

ci_core/utils/
power_divergence.rs

1use crate::strategy::TestResult;
2use crate::utils::contingency_table::{
3    build_global_category_map, contingency_table, contingency_table_from_indices,
4};
5use crate::utils::contingency_test::contingency_test;
6use crate::utils::partition_indices::partition_indices;
7use anyhow::ensure;
8use ndarray::{Array1, Array2};
9use statrs::distribution::{ChiSquared, ContinuousCDF};
10
11/// Run a power-divergence based conditional independence test.
12///
13/// # Errors
14/// Returns an error if the inputs are invalid or if the underlying contingency
15/// test or chi-squared distribution construction fails.
16pub fn power_divergence(
17    x_values: &Array1<f64>,
18    y_values: &Array1<f64>,
19    z: &Array2<f64>,
20    boolean: bool,
21    significance_level: f64,
22    lambda: f64,
23) -> anyhow::Result<TestResult> {
24    ensure!(
25        x_values.len() == y_values.len(),
26        "x and y must have the same length, got {} and {}",
27        x_values.len(),
28        y_values.len(),
29    );
30    ensure!(
31        z.ncols() == 0 || z.nrows() == x_values.len(),
32        "z must have the same number of rows as x and y ({}), got {}",
33        x_values.len(),
34        z.nrows(),
35    );
36    if z.ncols() == 0 {
37        let table = contingency_table(x_values, y_values);
38        let (statistic, p_value, degrees_of_freedom) = contingency_test(&table, lambda)?;
39        return Ok(wrap_result(
40            boolean,
41            p_value,
42            statistic,
43            degrees_of_freedom,
44            significance_level,
45        ));
46    }
47
48    let x_categories = build_global_category_map(x_values);
49    let y_categories = build_global_category_map(y_values);
50
51    let mut statistic = 0.0;
52    let mut degrees_of_freedom = 0;
53
54    for indices in partition_indices(z) {
55        let table = contingency_table_from_indices(
56            &indices,
57            x_values,
58            y_values,
59            &x_categories,
60            &y_categories,
61        );
62        let Ok((stat, _p, dof)) = contingency_test(&table, lambda) else {
63            continue;
64        };
65
66        if dof == 0 {
67            continue;
68        }
69        statistic += stat;
70        degrees_of_freedom += dof;
71    }
72    let p_value = if degrees_of_freedom == 0 {
73        1.0
74    } else {
75        #[allow(clippy::cast_precision_loss)]
76        ChiSquared::new(degrees_of_freedom as f64)?.sf(statistic)
77    };
78    Ok(wrap_result(
79        boolean,
80        p_value,
81        statistic,
82        degrees_of_freedom,
83        significance_level,
84    ))
85}
86
87fn wrap_result(
88    boolean: bool,
89    p_value: f64,
90    coefficient: f64,
91    degrees_of_freedom: usize,
92    significance_level: f64,
93) -> TestResult {
94    if boolean {
95        return TestResult::Boolean(p_value >= significance_level);
96    }
97    TestResult::Statistic(p_value, coefficient, degrees_of_freedom)
98}
99
// Unit tests for both the unconditional and the stratified (conditional)
// paths of `power_divergence`, plus the boolean output mode.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::EPS;
    use ndarray::{array, Array2};

    const LAMBDA: f64 = 1.0;
    const SIGNIFICANCE_LEVEL: f64 = 0.05;

    fn empty_z() -> Array2<f64> {
        Array2::zeros((0, 0))
    }

    fn unwrap_statistic(r: &TestResult) -> (f64, f64, usize) {
        match r {
            TestResult::Statistic(a, b, c) => (*a, *b, *c),
            _ => panic!("expected Statistic"),
        }
    }

    #[test]
    fn unconditional_mismatched_lengths_returns_error() {
        let x = array![1., 2., 3.];
        let y = array![1., 2.];
        let result = power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_err());
    }

    // A single observation produces a 1x1 table with dof=0.
    #[test]
    fn unconditional_single_element_has_zero_dof() {
        let x = array![1.];
        let y = array![1.];
        let result =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
        assert!(stat.abs() < EPS);
    }

    // When X has only one distinct value the table has one row → dof=0.
    #[test]
    fn unconditional_single_category_in_x_has_zero_dof() {
        let x = array![1., 1., 1., 1.];
        let y = array![1., 2., 1., 2.];
        let result =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    // When Y has only one distinct value the table has one column → dof=0.
    #[test]
    fn unconditional_single_category_in_y_has_zero_dof() {
        let x = array![1., 2., 1., 2.];
        let y = array![1., 1., 1., 1.];
        let result =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    // NaN in x_values becomes its own category via OrderedFloat.
    // The result is a valid (though likely meaningless) statistic, not a panic.
    #[test]
    fn unconditional_nan_in_x_does_not_panic() {
        let x = array![1., f64::NAN, 2., 1.];
        let y = array![1., 2., 1., 2.];
        let result = power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_ok());
    }

    // NaN in the conditioning set creates its own partition group.
    // Global category maps include both X categories, so a singleton
    // group gets a table with a zero-expected cell. That group is
    // skipped, so the overall result is valid (p=1 when all groups are skipped).
    #[test]
    fn conditional_nan_in_z_skips_bad_groups() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::from_shape_vec((4, 1), vec![0., f64::NAN, 0., 1.]).unwrap();
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    // Each Z-group has only one X value. Global category maps force a 2x2
    // table per group with a zero row → zero expected frequency. All
    // groups are skipped, yielding dof=0 and p=1.
    #[test]
    fn conditional_all_groups_single_category_skips_all() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::from_shape_vec((4, 1), vec![0., 0., 1., 1.]).unwrap();
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    // A conditioning set that has columns must have one row per observation.
    #[test]
    fn conditional_mismatched_z_rows_returns_error() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::<f64>::zeros((3, 1));
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_err());
    }

    // A z with zero columns means "no conditioning set": its row count is not
    // validated and the result must equal the unconditional one.
    #[test]
    fn zero_column_z_matches_unconditional() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::<f64>::zeros((7, 0));
        let with_z = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let without_z =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p_a, stat_a, dof_a) = unwrap_statistic(&with_z);
        let (p_b, stat_b, dof_b) = unwrap_statistic(&without_z);
        assert_eq!(dof_a, dof_b);
        assert!((p_a - p_b).abs() < EPS);
        assert!((stat_a - stat_b).abs() < EPS);
    }

    // In boolean mode the p-value is compared with `>=`. When every stratum is
    // skipped the p-value is exactly 1.0, so the result is `true` even at a
    // significance level of 1.0.
    #[test]
    fn boolean_mode_reports_not_rejected_at_or_above_level() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::from_shape_vec((4, 1), vec![0., 0., 1., 1.]).unwrap();

        let at_default = power_divergence(&x, &y, &z, true, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        assert!(matches!(at_default, TestResult::Boolean(true)));

        let at_boundary = power_divergence(&x, &y, &z, true, 1.0, LAMBDA).unwrap();
        assert!(matches!(at_boundary, TestResult::Boolean(true)));
    }
}
203}