Skip to main content

ci_core/utils/
contingency_test.rs

1use crate::utils::EPS;
2use anyhow::{bail, Result};
3
/// Power-divergence exponent that `contingency_test` routes to
/// `modified_log_likelihood_ratio_test` (matched within `EPS`).
4const MODIFIED_LIKELIHOOD_LAMBDA: f64 = -1.0;
/// Power-divergence exponent that `contingency_test` routes to
/// `freeman_tukey` (matched within `EPS`).
5const FREEMAN_TUKEY_LAMBDA: f64 = -0.5;
6use ndarray::{Array1, Array2, Axis};
7use statrs::distribution::{ChiSquared, ContinuousCDF};
8
9/// Compute a Cressie-Read power-divergence statistic for a contingency table.
10///
11/// # Errors
12/// Returns an error when the table is empty, contains negatives, sums to zero,
13/// or has a zero expected frequency.
14pub fn contingency_test(observed: &Array2<f64>, lambda: f64) -> Result<(f64, f64, usize)> {
15    // Check whether contingency test is applicable
16    if observed.is_empty() {
17        bail!("No data; `observed` has size 0.");
18    }
19    if observed.iter().any(|&x| x < 0.0) {
20        bail!("All values in `observed` must be nonnegative.");
21    }
22
23    let (nrows, ncols) = observed.dim();
24    let row_sums = observed.sum_axis(Axis(1));
25    let col_sums = observed.sum_axis(Axis(0));
26    let total: f64 = row_sums.sum();
27    if total == 0.0 {
28        bail!("Total sum of observed frequencies must be > 0.");
29    }
30    let inverse_total = 1.0 / total;
31
32    let col_times_total = &col_sums * inverse_total;
33    let ln_total = total.ln();
34    let ln_row_sums = row_sums.mapv(f64::ln);
35    let ln_col_sums = col_sums.mapv(f64::ln);
36
37    let statistic: f64 = if lambda.abs() < EPS {
38        g_test(
39            observed,
40            &row_sums,
41            &col_times_total,
42            ln_total,
43            &ln_row_sums,
44            &ln_col_sums,
45        )?
46    } else if (lambda - MODIFIED_LIKELIHOOD_LAMBDA).abs() < EPS {
47        modified_log_likelihood_ratio_test(
48            observed,
49            &row_sums,
50            &col_times_total,
51            ln_total,
52            &ln_row_sums,
53            &ln_col_sums,
54        )?
55    } else if (lambda - FREEMAN_TUKEY_LAMBDA).abs() < EPS {
56        freeman_tukey(lambda, observed, &row_sums, &col_times_total)?
57    } else {
58        cressie_read(lambda, observed, &row_sums, &col_times_total)?
59    };
60
61    let degrees_of_freedom = if nrows < 2 || ncols < 2 {
62        0
63    } else {
64        (nrows - 1) * (ncols - 1)
65    };
66
67    let p_value = if degrees_of_freedom == 0 {
68        1.0
69    } else {
70        #[allow(clippy::cast_precision_loss)]
71        ChiSquared::new(degrees_of_freedom as f64)?.sf(statistic)
72    };
73
74    Ok((statistic, p_value, degrees_of_freedom))
75}
76
77/// # Errors
78///
79/// Returns an error if an expected frequency evaluates to zero, or if the
80/// calculated statistic results in a mathematically impossible value.
81pub fn g_test(
82    observed: &Array2<f64>,
83    row_sums: &Array1<f64>,
84    col_times_total: &Array1<f64>,
85    ln_total: f64,
86    ln_row_sums: &Array1<f64>,
87    ln_col_sums: &Array1<f64>,
88) -> anyhow::Result<f64> {
89    // G-test: 2 * sum(O * ln(O / E))
90    let (nrows, ncols) = observed.dim();
91    let mut temp_stat: f64 = 0.0;
92    for i in 0..nrows {
93        for j in 0..ncols {
94            let temp_expected: f64 = row_sums[i] * col_times_total[j]; // division is worse than multiplication in rust
95            let temp_observed = observed[[i, j]];
96            if temp_expected == 0. {
97                bail!("Expected frequency is zero at position [{i}, {j}]");
98            }
99            if temp_observed == 0. {
100                continue;
101            }
102            temp_stat +=
103                temp_observed * (temp_observed.ln() + ln_total - ln_row_sums[i] - ln_col_sums[j]);
104            //used logarithmic rules to get rid of division and mulitplication
105        }
106    }
107    let final_stat = 2.0 * temp_stat;
108    if final_stat < -EPS {
109        bail!("Statistic evaluated to {final_stat}, which should be impossible.");
110        //make sure it bails when the negative number is not negligibly small
111    }
112    Ok(final_stat.max(0.0))
113}
114
115/// # Errors
116///
117/// Returns an error if an expected frequency evaluates to zero, or if the
118/// calculated statistic results in a mathematically impossible value.
119pub fn modified_log_likelihood_ratio_test(
120    observed: &Array2<f64>,
121    row_sums: &Array1<f64>,
122    col_times_total: &Array1<f64>,
123    ln_total: f64,
124    ln_row_sums: &Array1<f64>,
125    ln_col_sums: &Array1<f64>,
126) -> anyhow::Result<f64> {
127    let (nrows, ncols) = observed.dim();
128    let mut temp_stat: f64 = 0.0;
129    for i in 0..nrows {
130        for j in 0..ncols {
131            let temp_expected: f64 = row_sums[i] * col_times_total[j]; //multiplication instead of division
132            let temp_observed = observed[[i, j]];
133            if temp_expected == 0. {
134                continue;
135            }
136            if temp_observed == 0. {
137                // log of temp_observed will be -infinity, causing temp_stat to become infinity.
138                // However the math still works, so this is not an error.
139                return Ok(f64::INFINITY);
140            }
141            temp_stat +=
142                temp_expected * (ln_row_sums[i] + ln_col_sums[j] - temp_observed.ln() - ln_total);
143        }
144    }
145    let final_stat = 2.0 * temp_stat;
146    if final_stat < -EPS {
147        bail!("Statistic evaluated to {final_stat}, which should be impossible.");
148    }
149    Ok(final_stat.max(0.0))
150}
151
152/// # Errors
153///
154/// Returns an error if an expected frequency evaluates to zero, or if the
155/// calculated statistic results in a mathematically impossible value.
156pub fn freeman_tukey(
157    lambda: f64,
158    observed: &Array2<f64>,
159    row_sums: &Array1<f64>,
160    col_times_total: &Array1<f64>,
161) -> anyhow::Result<f64> {
162    let (nrows, ncols) = observed.dim();
163    let mut temp_stat: f64 = 0.0;
164    for i in 0..nrows {
165        for j in 0..ncols {
166            let temp_expected: f64 = row_sums[i] * col_times_total[j]; //again the multiplication instead of division
167            let temp_observed = observed[[i, j]];
168            if temp_expected == 0.0 {
169                bail!("Expected frequency is zero at position [{i}, {j}]");
170            }
171            temp_stat += temp_observed.sqrt() * temp_expected.sqrt() - temp_observed;
172        }
173    }
174    let final_stat = (2.0 * temp_stat) / (lambda * (lambda + 1.0));
175    if final_stat < -EPS {
176        bail!("Statistic evaluated to {final_stat}, which should be impossible.");
177    }
178    Ok(final_stat.max(0.0))
179}
180
181/// # Errors
182///
183/// Returns an error if an expected frequency evaluates to zero, or if the
184/// calculated statistic results in a mathematically impossible value.
185pub fn cressie_read(
186    lambda: f64,
187    observed: &Array2<f64>,
188    row_sums: &Array1<f64>,
189    col_times_total: &Array1<f64>,
190) -> anyhow::Result<f64> {
191    let (nrows, ncols) = observed.dim();
192    let mut temp_stat: f64 = 0.0;
193    for i in 0..nrows {
194        for j in 0..ncols {
195            let temp_expected: f64 = row_sums[i] * col_times_total[j]; //again the multiplication instead of division
196            let temp_observed = observed[[i, j]];
197            if temp_expected == 0.0 {
198                bail!("Expected frequency is zero at position [{i}, {j}]");
199            }
200            temp_stat += temp_observed * ((temp_observed / temp_expected).powf(lambda) - 1.0);
201        }
202    }
203    let final_stat = (2.0 * temp_stat) / (lambda * (lambda + 1.0));
204    if final_stat < -EPS {
205        bail!("Statistic evaluated to {final_stat}, which should be impossible.");
206    }
207    Ok(final_stat.max(0.0))
208}
209
/// Unit tests for `contingency_test`: input validation, each statistic's
/// dispatch branch, and the degrees-of-freedom computation.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_empty_table_error() {
        let observed = Array2::<f64>::zeros((0, 0));
        let result = contingency_test(&observed, 0.0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("size 0"));
    }

    #[test]
    fn test_negative_values_error() {
        let observed = array![[1.0, -1.0], [2.0, 3.0]];
        let result = contingency_test(&observed, 0.0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("nonnegative"));
    }

    #[test]
    fn test_zero_total_error() {
        let observed = array![[0.0, 0.0], [0.0, 0.0]];
        let result = contingency_test(&observed, 0.0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("> 0"));
    }

    #[test]
    fn test_zero_expected_frequency_error() {
        // lambda = 1.0 takes the generic Cressie-Read branch.
        let observed = array![[10.0, 0.0], [0.0, 0.0]];
        let result = contingency_test(&observed, 1.0);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expected frequency is zero"));
    }

    #[test]
    fn test_g_test_zero_expected_frequency() {
        let observed = array![[5.0, 0.0], [0.0, 0.0]];
        let result = contingency_test(&observed, 0.0);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expected frequency is zero"));
    }

    #[test]
    fn test_modified_log_likelihood_zero_observed() {
        let observed = array![[5.0, 1.0], [2.0, 0.0]]; // contains zero observed
        let (statistic, p_value, _dof) = contingency_test(&observed, -1.0).unwrap();

        assert!(statistic.is_infinite() && p_value < EPS);
    }

    #[test]
    fn test_g_test_valid() {
        let observed = array![[10.0, 20.0], [20.0, 40.0]];
        let result = contingency_test(&observed, 0.0).unwrap();

        let (stat, p, dof) = result;
        assert!(stat >= 0.0);
        assert!((0.0..=1.0).contains(&p));
        assert_eq!(dof, 1);
    }

    #[test]
    fn test_modified_log_likelihood_valid() {
        let observed = array![[10.0, 20.0], [20.0, 40.0]];
        let result = contingency_test(&observed, -1.0).unwrap();

        let (stat, p, dof) = result;
        assert!(stat >= 0.0);
        assert!((0.0..=1.0).contains(&p));
        assert_eq!(dof, 1);
    }

    #[test]
    fn test_freeman_tukey_valid() {
        let observed = array![[5.0, 1.0], [1.0, 5.0]];
        let result = contingency_test(&observed, -0.5).unwrap();

        let (stat, p, dof) = result;
        assert!(stat >= 0.0);
        assert!((stat - 6.319_453_539_579_289).abs() < EPS); // Validated against scipy
        assert!((0.0..=1.0).contains(&p));
        assert_eq!(dof, 1);
    }

    #[test]
    fn test_freeman_tukey_zero_observed() {
        let observed = array![[2.0, 1.0], [0.0, 3.0]]; // contains zero observed
        let result = contingency_test(&observed, -0.5);

        assert!(result.is_ok());
        let (stat, p, dof) = result.unwrap();

        assert!(stat >= 0.0);
        assert!((0.0..=1.0).contains(&p));
        assert_eq!(dof, 1);

        // Manual calculation: 4 * sum((sqrt(O) - sqrt(E))^2) with expected
        // frequencies [[1, 2], [1, 2]].
        let expected_stat = 4.0
            * ((2.0f64.sqrt() - 1.0).powi(2) +
            (1.0 - 2.0f64.sqrt()).powi(2) +
            1.0 + // (0.0 - sqrt(1.0))^2
            (3.0f64.sqrt() - 2.0f64.sqrt()).powi(2));
        assert!((stat - expected_stat).abs() < EPS);
    }

    #[test]
    fn test_cressie_read_valid() {
        let observed = array![[10.0, 20.0], [20.0, 40.0]];
        let result = contingency_test(&observed, 0.5).unwrap();

        let (stat, p, dof) = result;
        assert!(stat >= 0.0);
        assert!((0.0..=1.0).contains(&p));
        assert_eq!(dof, 1);
    }

    #[test]
    fn test_degrees_of_freedom_2x2() {
        let observed = array![[1.0, 2.0], [3.0, 4.0]];
        let (_, _, dof) = contingency_test(&observed, 0.0).unwrap();
        assert_eq!(dof, 1); // (2-1)*(2-1)
    }

    #[test]
    fn test_degrees_of_freedom_3x4() {
        let observed = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];
        let (_, _, dof) = contingency_test(&observed, 0.0).unwrap();
        assert_eq!(dof, (3 - 1) * (4 - 1)); // 2 * 3 = 6
    }

    #[test]
    fn test_degrees_of_freedom_degenerate() {
        let observed = array![[1.0, 2.0, 3.0]]; // 1x3
        let (_, p, dof) = contingency_test(&observed, 0.0).unwrap();

        assert_eq!(dof, 0);
        // With zero degrees of freedom the p-value is defined as exactly 1.
        // The tolerance must be 1e-12; the previous 1e12 made this check vacuous.
        assert!((p - 1.0).abs() < 1e-12);
    }
}