1use crate::strategy::TestResult;
2use crate::utils::contingency_table::{
3 build_global_category_map, contingency_table, contingency_table_from_indices,
4};
5use crate::utils::contingency_test::contingency_test;
6use crate::utils::partition_indices::partition_indices;
7use anyhow::ensure;
8use ndarray::{Array1, Array2};
9use statrs::distribution::{ChiSquared, ContinuousCDF};
10
11pub fn power_divergence(
17 x_values: &Array1<f64>,
18 y_values: &Array1<f64>,
19 z: &Array2<f64>,
20 boolean: bool,
21 significance_level: f64,
22 lambda: f64,
23) -> anyhow::Result<TestResult> {
24 ensure!(
25 x_values.len() == y_values.len(),
26 "x and y must have the same length, got {} and {}",
27 x_values.len(),
28 y_values.len(),
29 );
30 ensure!(
31 z.ncols() == 0 || z.nrows() == x_values.len(),
32 "z must have the same number of rows as x and y ({}), got {}",
33 x_values.len(),
34 z.nrows(),
35 );
36 if z.ncols() == 0 {
37 let table = contingency_table(x_values, y_values);
38 let (statistic, p_value, degrees_of_freedom) = contingency_test(&table, lambda)?;
39 return Ok(wrap_result(
40 boolean,
41 p_value,
42 statistic,
43 degrees_of_freedom,
44 significance_level,
45 ));
46 }
47
48 let x_categories = build_global_category_map(x_values);
49 let y_categories = build_global_category_map(y_values);
50
51 let mut statistic = 0.0;
52 let mut degrees_of_freedom = 0;
53
54 for indices in partition_indices(z) {
55 let table = contingency_table_from_indices(
56 &indices,
57 x_values,
58 y_values,
59 &x_categories,
60 &y_categories,
61 );
62 let Ok((stat, _p, dof)) = contingency_test(&table, lambda) else {
63 continue;
64 };
65
66 if dof == 0 {
67 continue;
68 }
69 statistic += stat;
70 degrees_of_freedom += dof;
71 }
72 let p_value = if degrees_of_freedom == 0 {
73 1.0
74 } else {
75 #[allow(clippy::cast_precision_loss)]
76 ChiSquared::new(degrees_of_freedom as f64)?.sf(statistic)
77 };
78 Ok(wrap_result(
79 boolean,
80 p_value,
81 statistic,
82 degrees_of_freedom,
83 significance_level,
84 ))
85}
86
87fn wrap_result(
88 boolean: bool,
89 p_value: f64,
90 coefficient: f64,
91 degrees_of_freedom: usize,
92 significance_level: f64,
93) -> TestResult {
94 if boolean {
95 return TestResult::Boolean(p_value >= significance_level);
96 }
97 TestResult::Statistic(p_value, coefficient, degrees_of_freedom)
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::EPS;
    use ndarray::{array, Array2};

    const LAMBDA: f64 = 1.0;
    const SIGNIFICANCE_LEVEL: f64 = 0.05;

    /// A conditioning matrix with no columns, which selects the unconditional test.
    fn empty_z() -> Array2<f64> {
        Array2::zeros((0, 0))
    }

    /// Extracts `(p_value, statistic, degrees_of_freedom)`; panics on any other variant.
    fn unwrap_statistic(r: &TestResult) -> (f64, f64, usize) {
        match r {
            TestResult::Statistic(a, b, c) => (*a, *b, *c),
            _ => panic!("expected Statistic"),
        }
    }

    /// Extracts the accept/reject decision; panics on any other variant.
    fn unwrap_boolean(r: &TestResult) -> bool {
        match r {
            TestResult::Boolean(b) => *b,
            _ => panic!("expected Boolean"),
        }
    }

    #[test]
    fn unconditional_mismatched_lengths_returns_error() {
        let x = array![1., 2., 3.];
        let y = array![1., 2.];
        let result = power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_err());
    }

    /// A single observation gives a 1x1 table: no degrees of freedom, p-value 1.
    #[test]
    fn unconditional_single_element_has_zero_dof() {
        let x = array![1.];
        let y = array![1.];
        let result =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
        assert!(stat.abs() < EPS);
    }

    /// A constant `x` column leaves nothing to test against `y`.
    #[test]
    fn unconditional_single_category_in_x_has_zero_dof() {
        let x = array![1., 1., 1., 1.];
        let y = array![1., 2., 1., 2.];
        let result =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    /// A constant `y` column leaves nothing to test against `x`.
    #[test]
    fn unconditional_single_category_in_y_has_zero_dof() {
        let x = array![1., 2., 1., 2.];
        let y = array![1., 1., 1., 1.];
        let result =
            power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    /// NaN values in `x` must not cause a panic or an error.
    #[test]
    fn unconditional_nan_in_x_does_not_panic() {
        let x = array![1., f64::NAN, 2., 1.];
        let y = array![1., 2., 1., 2.];
        let result = power_divergence(&x, &y, &empty_z(), false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_ok());
    }

    /// When `z` has no columns its row count is never checked.
    #[test]
    fn unconditional_zero_column_z_ignores_row_count() {
        let x = array![1.];
        let y = array![1.];
        let z: Array2<f64> = Array2::zeros((5, 0));
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_ok());
    }

    /// A conditioning matrix with columns must match the sample count.
    #[test]
    fn conditional_z_row_mismatch_returns_error() {
        let x = array![1., 2., 1.];
        let y = array![1., 1., 2.];
        let z = Array2::from_shape_vec((2, 1), vec![0., 1.]).unwrap();
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA);
        assert!(result.is_err());
    }

    /// Rows 0 and 2 share z = 0 but have a single `y` category. The NaN row
    /// and the z = 1 row are singletons, however `partition_indices` treats NaN.
    /// So no group contributes degrees of freedom.
    #[test]
    fn conditional_nan_in_z_skips_bad_groups() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::from_shape_vec((4, 1), vec![0., f64::NAN, 0., 1.]).unwrap();
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    /// Within each z group `x` is constant, so every group is skipped.
    #[test]
    fn conditional_all_groups_single_category_skips_all() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::from_shape_vec((4, 1), vec![0., 0., 1., 1.]).unwrap();
        let result = power_divergence(&x, &y, &z, false, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        let (p, _stat, dof) = unwrap_statistic(&result);
        assert_eq!(dof, 0);
        assert!((p - 1.0).abs() < EPS);
    }

    /// With zero summed degrees of freedom the p-value is exactly 1.0. That is
    /// at least the significance level, so boolean mode reports `true`.
    #[test]
    fn conditional_boolean_mode_without_dof_is_true() {
        let x = array![1., 1., 2., 2.];
        let y = array![1., 2., 1., 2.];
        let z = Array2::from_shape_vec((4, 1), vec![0., 0., 1., 1.]).unwrap();
        let result = power_divergence(&x, &y, &z, true, SIGNIFICANCE_LEVEL, LAMBDA).unwrap();
        assert!(unwrap_boolean(&result));
    }
}