1use crate::strategy::{CITest, CITestDataType, TestResult};
2use crate::utils::power_divergence::power_divergence;
3
4use ndarray::{Array1, Array2};
5
/// Cressie–Read `lambda` handed to `power_divergence`.
///
/// Under the usual Cressie–Read convention (e.g. SciPy's `power_divergence`,
/// where `lambda_ = 1` is named `"pearson"`), `1.0` selects Pearson's
/// chi-squared statistic. Presumably `crate::utils::power_divergence` follows
/// the same convention — the expected statistics in this file's unit tests
/// match Pearson's `sum((O - E)^2 / E)`; confirm against that helper.
const CHI_SQUARED_LAMBDA: f64 = 1.0_f64;
7
/// Pearson chi-squared test of (conditional) independence for discrete data.
///
/// Both fields are public, so the test can be built either with
/// [`ChiSquared::new`] or with a struct literal.
#[derive(Clone, Debug, PartialEq)]
pub struct ChiSquared {
    /// Output mode.
    ///
    /// When `true`, `run_test` yields a `TestResult::Boolean` verdict. The unit
    /// tests expect `true` for independent data and `false` for dependent data.
    /// When `false`, it yields `TestResult::Statistic` carrying
    /// `(p_value, statistic, degrees_of_freedom)`.
    pub boolean: bool,
    /// Significance level (alpha) forwarded to `power_divergence`.
    ///
    /// NOTE(review): presumably the boolean verdict is "independent" when the
    /// p-value is at least this value — confirm in `power_divergence`.
    pub significance_level: f64,
}

impl ChiSquared {
    /// Creates a test with the given output mode and significance level.
    ///
    /// No validation is performed: `significance_level` is stored exactly as
    /// given.
    #[must_use]
    pub fn new(boolean: bool, significance_level: f64) -> Self {
        ChiSquared {
            boolean,
            significance_level,
        }
    }
}
27
28impl CITest for ChiSquared {
29 fn run_test(
30 &self,
31 x_values: Array1<f64>,
32 y_values: Array1<f64>,
33 z: Array2<f64>,
34 ) -> anyhow::Result<TestResult> {
35 power_divergence(
36 &x_values,
37 &y_values,
38 &z,
39 self.boolean,
40 self.significance_level,
41 CHI_SQUARED_LAMBDA,
42 )
43 }
44
45 fn data_types(&self) -> &'static [CITestDataType] {
46 &[CITestDataType::Discrete]
47 }
48}
49
#[cfg(test)]
// Short names (t, x, y, z, p) mirror the usual statistical notation.
#[allow(clippy::many_single_char_names)]
mod tests {
    use super::*;
    use crate::utils::EPS;
    use ndarray::{array, Array2};

    /// Significance level shared by every test in this module.
    const ALPHA: f64 = 0.05;

    /// Builds the test under test with the shared significance level.
    fn chi_squared(boolean: bool) -> ChiSquared {
        ChiSquared::new(boolean, ALPHA)
    }

    /// Splits a non-boolean result into `(p_value, statistic, dof)`.
    ///
    /// # Panics
    /// Panics if `r` is any variant other than `TestResult::Statistic`.
    fn unwrap_statistic(r: &TestResult) -> (f64, f64, usize) {
        match r {
            TestResult::Statistic(p, stat, dof) => (*p, *stat, *dof),
            _ => panic!("expected TestResult::Statistic"),
        }
    }

    /// Every cell of the 2x2 table holds 2 of the 8 samples, so observed counts
    /// equal expected counts. The statistic is 0, with dof = (2-1)(2-1) = 1.
    #[test]
    fn uncond_independent_data_accepted() {
        let t = chi_squared(false);
        let x = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let y = array![1., 2., 1., 2., 1., 2., 1., 2.];
        let empty = Array2::<f64>::zeros((0, 0));

        let (p, stat, dof) = unwrap_statistic(&t.run_test(x, y, empty).unwrap());
        assert!(stat.abs() < EPS, "stat should be ~0, got {stat}");
        assert!(p > 0.99, "p should be ~1, got {p}");
        assert_eq!(dof, 1);
    }

    /// Within each of the two `z` strata, every cell of the 2x2 table holds one
    /// sample. The statistic is 0, with dof = 1 per stratum x 2 strata = 2.
    #[test]
    fn cond_independent_data_accepted() {
        let t = chi_squared(false);
        let x = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let y = array![1., 2., 1., 2., 1., 2., 1., 2.];
        let z = array![[1.], [1.], [1.], [1.], [2.], [2.], [2.], [2.]];

        let (p, stat, dof) = unwrap_statistic(&t.run_test(x, y, z).unwrap());
        assert!(stat.abs() < EPS, "stat should be ~0, got {stat}");
        assert!(p > 0.99, "p should be ~1, got {p}");
        assert_eq!(dof, 2);
    }

    /// `x == y` gives the table [[4, 0], [0, 4]] with expected count 2 per cell.
    /// The Pearson statistic is 4 * (2^2 / 2) = 8 with dof 1, so
    /// p = P(chi2_1 > 8) ~= 0.00468.
    #[test]
    fn uncond_dependent_data_rejected() {
        let t = chi_squared(false);
        let x = array![1., 1., 1., 1., 2., 2., 2., 2.];
        let y = array![1., 1., 1., 1., 2., 2., 2., 2.];
        let empty = Array2::<f64>::zeros((0, 0));

        let (p, stat, dof) = unwrap_statistic(&t.run_test(x, y, empty).unwrap());
        assert!((stat - 8.0).abs() < EPS, "expected statistic 8.0, got {stat}");
        assert!(
            (p - 0.004_677_734_981_047_276).abs() < EPS,
            "expected p ~= 0.00468, got {p}"
        );
        assert!(
            p < ALPHA,
            "independence should be rejected at alpha = {ALPHA}, got p = {p}"
        );
        assert_eq!(dof, 1);
    }

    /// Each stratum gives [[2, 0], [0, 2]] with expected count 1 per cell, a
    /// contribution of 4 per stratum. That yields a statistic of 8 with dof 2,
    /// so p = exp(-8 / 2) = exp(-4) ~= 0.0183.
    #[test]
    fn cond_dependent_data_rejected() {
        let t = chi_squared(false);
        let x = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let y = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let z = array![[1.], [1.], [1.], [1.], [2.], [2.], [2.], [2.]];

        let (p, stat, dof) = unwrap_statistic(&t.run_test(x, y, z).unwrap());
        assert!((stat - 8.0).abs() < EPS, "expected statistic 8.0, got {stat}");
        assert!(
            (p - 0.018_315_638_888_734_193).abs() < EPS,
            "expected p ~= exp(-4) ~= 0.0183, got {p}"
        );
        assert!(
            p < ALPHA,
            "independence should be rejected at alpha = {ALPHA}, got p = {p}"
        );
        assert_eq!(dof, 2);
    }

    /// Boolean mode: `true` for the independent data, `false` for the dependent
    /// data (same datasets as the unconditional tests above).
    #[test]
    fn uncond_boolean_mode() {
        let t = chi_squared(true);
        let empty = Array2::<f64>::zeros((0, 0));
        let x = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let y = array![1., 2., 1., 2., 1., 2., 1., 2.];
        let r = t.run_test(x, y, empty.clone()).unwrap();
        assert!(matches!(r, TestResult::Boolean(true)));

        let x = array![1., 1., 1., 1., 2., 2., 2., 2.];
        let y = array![1., 1., 1., 1., 2., 2., 2., 2.];
        let r = t.run_test(x, y, empty).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    /// Boolean mode with a conditioning variable: `true` for the independent
    /// data, `false` for the dependent data (same datasets as the conditional
    /// tests above).
    #[test]
    fn cond_boolean_mode() {
        let t = chi_squared(true);
        let z = array![[1.], [1.], [1.], [1.], [2.], [2.], [2.], [2.]];
        let x = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let y = array![1., 2., 1., 2., 1., 2., 1., 2.];
        let r = t.run_test(x, y, z).unwrap();
        assert!(matches!(r, TestResult::Boolean(true)));

        let x = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let y = array![1., 1., 2., 2., 1., 1., 2., 2.];
        let z = array![[1.], [1.], [1.], [1.], [2.], [2.], [2.], [2.]];
        let r = t.run_test(x, y, z).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }
}