1use crate::ci_tests::PearsonCorrelation;
2use crate::strategy::{CITest, CITestDataType, TestResult};
3use crate::utils::EPS;
4
/// Constant subtracted, together with the number of conditioning variables, from the
/// sample size before taking the square root that scales the Fisher z-transformed
/// correlation into a z-score.
const FISHER_Z_DOF_OFFSET: usize = 3;
6use anyhow::bail;
7use ndarray::{Array1, Array2, Axis};
8use statrs::distribution::{ContinuousCDF, Normal};
9
/// Configuration of an equivalence test on the (partial) Pearson correlation.
///
/// The test compares the Fisher z-transformed correlation against the margin
/// `atanh(delta_threshold)` and reports either a p-value or a boolean decision.
#[derive(Debug, Clone, PartialEq)]
pub struct PearsonEquivalence {
    /// When `true`, the test reports `p_value < significance_level` as a boolean;
    /// when `false`, it reports the p-value together with the test coefficient.
    pub boolean: bool,
    /// Threshold the p-value is compared against in boolean mode.
    pub significance_level: f64,
    /// Equivalence margin on the correlation scale; it is passed through `atanh`
    /// before being compared with the transformed correlation.
    pub delta_threshold: f64,
}

impl PearsonEquivalence {
    /// Creates a test configuration from its three settings.
    ///
    /// The constructor is `const` so configurations can be declared as constants
    /// or statics. No validation is performed on the arguments.
    #[must_use]
    pub const fn new(boolean: bool, significance_level: f64, delta_threshold: f64) -> Self {
        Self {
            boolean,
            significance_level,
            delta_threshold,
        }
    }
}
36
37impl CITest for PearsonEquivalence {
38 fn run_test(
39 &self,
40 x_values: Array1<f64>,
41 y_values: Array1<f64>,
42 z: Array2<f64>,
43 ) -> anyhow::Result<TestResult> {
44 let n = x_values.len();
45 let s = z.axis_iter(Axis(1)).len();
46
47 let pearsonr = PearsonCorrelation {
48 boolean: false,
49 significance_level: self.significance_level,
50 }
51 .run_test(x_values, y_values, z);
52 let statistic = match pearsonr {
53 Ok(TestResult::PValue(_, statistic)) => statistic,
54 Ok(_) => 0.0,
55 Err(e) => return Err(e),
56 };
57 let rho = if statistic <= -1.0 {
58 -1.0 + EPS
59 } else if statistic >= 1.0 {
60 1.0 - EPS
61 } else {
62 statistic
63 };
64
65 let coefficient = rho.atanh();
66 let z_delta = self.delta_threshold.atanh();
67
68 #[allow(
69 clippy::cast_precision_loss,
70 reason = "array length and number of variables most likely won't exceed 2^53"
71 )]
72 let argument = (n - s - FISHER_Z_DOF_OFFSET) as f64;
73 let std_error_factor = if argument >= 0.0 {
74 argument.sqrt()
75 } else {
76 bail!("The length of the data should be at least 3 greater than the number of conditional variables");
77 };
78
79 let normal = Normal::new(0.0, 1.0).unwrap();
80
81 let z_score_lower = std_error_factor * (coefficient + z_delta);
82 let z_score_upper = std_error_factor * (coefficient - z_delta);
83
84 let p_value_lower = 1.0 - normal.cdf(z_score_lower);
85 let p_value_upper = normal.cdf(z_score_upper);
86
87 let p_value = if p_value_lower > p_value_upper {
88 p_value_lower
89 } else {
90 p_value_upper
91 };
92
93 Ok(wrap_result(
94 self.boolean,
95 p_value,
96 coefficient,
97 self.significance_level,
98 ))
99 }
100
101 fn data_types(&self) -> &'static [CITestDataType] {
102 &[CITestDataType::Continuous]
103 }
104}
105
106#[must_use]
107pub fn wrap_result(
108 boolean: bool,
109 p_value: f64,
110 coefficient: f64,
111 significance_level: f64,
112) -> TestResult {
113 if boolean {
114 return TestResult::Boolean(p_value < significance_level);
115 }
116 TestResult::PValue(p_value, coefficient)
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, stack, Array1, Array2, Axis};
    use rand::rngs::SmallRng;
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};

    /// Significance level used by every test configuration and p-value assertion.
    const SIGNIFICANCE_LEVEL: f64 = 0.05;
    /// Equivalence margin used by every test configuration.
    const DELTA_THRESHOLD: f64 = 0.1;
    /// Tolerance for comparisons against hard-coded reference values
    /// (presumably produced with pgmpy — confirm against the reference script).
    const PGMPY_EPS: f64 = 1e-8;

    /// Number of samples drawn for the randomized tests.
    const N: usize = 1000;

    /// Returns a deterministically seeded generator so test data is reproducible.
    fn seeded_rng() -> SmallRng {
        SmallRng::seed_from_u64(40)
    }

    /// Draws `n` samples from a normal distribution with the given parameters.
    fn gen_normal(n: usize, mean: f64, std_dev: f64, rng: &mut SmallRng) -> Array1<f64> {
        let dist = Normal::new(mean, std_dev).unwrap();
        Array1::from_vec((0..n).map(|_| dist.sample(rng)).collect())
    }

    /// Conditioning set with no variables.
    fn empty_array() -> Array2<f64> {
        Array2::zeros((0, 0))
    }

    /// Test configuration that reports p-values.
    fn pearson() -> PearsonEquivalence {
        PearsonEquivalence::new(false, SIGNIFICANCE_LEVEL, DELTA_THRESHOLD)
    }

    /// Test configuration that reports boolean decisions.
    fn pearson_boolean() -> PearsonEquivalence {
        PearsonEquivalence::new(true, SIGNIFICANCE_LEVEL, DELTA_THRESHOLD)
    }

    /// Two independent standard-normal samples.
    fn independent_pair() -> (Array1<f64>, Array1<f64>) {
        let mut rng = seeded_rng();
        let x = gen_normal(N, 0.0, 1.0, &mut rng);
        let y = gen_normal(N, 0.0, 1.0, &mut rng);
        (x, y)
    }

    /// `y` is a noisy linear function of `x`.
    fn correlated_pair() -> (Array1<f64>, Array1<f64>) {
        let mut rng = seeded_rng();
        let x = gen_normal(N, 0.0, 1.0, &mut rng);
        let noise = gen_normal(N, 0.0, 0.1, &mut rng);
        let y = &x * 3.0 + &noise;
        (x, y)
    }

    /// `x` and `y` both depend on a common cause `z`, returned as a one-column matrix.
    fn common_cause_triple() -> (Array1<f64>, Array1<f64>, Array2<f64>) {
        let mut rng = seeded_rng();
        let z = gen_normal(N, 0.0, 1.0, &mut rng);
        let noise_x = gen_normal(N, 0.0, 0.1, &mut rng);
        let noise_y = gen_normal(N, 0.0, 0.1, &mut rng);
        let x = &z * 3.0 + &noise_x;
        let y = &z * 2.0 + &noise_y;
        (x, y, z.insert_axis(Axis(1)))
    }

    /// Independent `x` and `y` with a collider `z` built from both, returned as a
    /// one-column matrix.
    fn collider_triple() -> (Array1<f64>, Array1<f64>, Array2<f64>) {
        let mut rng = seeded_rng();
        let x = gen_normal(N, 0.0, 1.0, &mut rng);
        let y = gen_normal(N, 0.0, 1.0, &mut rng);
        let noise = gen_normal(N, 0.0, 0.1, &mut rng);
        let z = &x * 2.0 + &y * 2.0 + &noise;
        (x, y, z.insert_axis(Axis(1)))
    }

    /// Extracts `(p_value, coefficient)` or fails the test on any other variant.
    fn expect_p_value(result: TestResult) -> (f64, f64) {
        match result {
            TestResult::PValue(p_value, coefficient) => (p_value, coefficient),
            _ => panic!("Expected TestResult::PValue"),
        }
    }

    /// Extracts the boolean decision or fails the test on any other variant.
    fn expect_boolean(result: TestResult) -> bool {
        match result {
            TestResult::Boolean(independent) => independent,
            _ => panic!("Expected TestResult::Boolean"),
        }
    }

    #[test]
    fn basic_test() {
        let x_vals = array![1.0, 2.0, 3.0, 4.0];
        let y_vals = array![1.0, 1.0, 2.0, 2.0];
        let empty_z = array![[]];

        // Fail loudly on an error or unexpected variant instead of comparing
        // placeholder zeros against the reference values.
        let result = pearson()
            .run_test(x_vals, y_vals, empty_z)
            .expect("run_test should succeed on the reference data");
        let (p_value, statistic) = expect_p_value(result);

        assert!((p_value - 0.910_412_594_569_001_1).abs() < PGMPY_EPS);
        assert!((statistic - 1.443_635_475_178_810_7).abs() < PGMPY_EPS);
    }

    #[test]
    fn unconditional_independent_data_is_not_rejected() {
        let (x, y) = independent_pair();

        let result = pearson().run_test(x, y, empty_array()).unwrap();
        let (p_value, coefficient) = expect_p_value(result);
        assert!(
            p_value <= SIGNIFICANCE_LEVEL,
            "p_value {p_value} should be <= {SIGNIFICANCE_LEVEL} for independent data"
        );
        assert!(
            coefficient.abs() < DELTA_THRESHOLD,
            "coefficient {coefficient} should be near 0 for independent data"
        );
    }

    #[test]
    fn unconditional_boolean_accepts_independent() {
        let (x, y) = independent_pair();

        let result = pearson_boolean().run_test(x, y, empty_array()).unwrap();
        assert!(expect_boolean(result), "Independent data should return true");
    }

    #[test]
    fn unconditional_dependent_data_is_rejected() {
        let (x, y) = correlated_pair();

        let result = pearson().run_test(x, y, empty_array()).unwrap();
        let (p_value, coefficient) = expect_p_value(result);
        assert!(
            p_value >= SIGNIFICANCE_LEVEL,
            "p_value {p_value} should be >= {SIGNIFICANCE_LEVEL} for correlated data"
        );
        assert!(
            coefficient.abs() > 0.9,
            "coefficient {coefficient} should be high for correlated data"
        );
    }

    #[test]
    fn unconditional_boolean_rejects_dependent() {
        let (x, y) = correlated_pair();

        let result = pearson_boolean().run_test(x, y, empty_array()).unwrap();
        assert!(!expect_boolean(result), "Correlated data should return false");
    }

    #[test]
    fn conditional_independent_data_is_not_rejected() {
        let (x, y, array) = common_cause_triple();

        let result = pearson().run_test(x, y, array).unwrap();
        let (p_value, coefficient) = expect_p_value(result);
        assert!(
            p_value <= SIGNIFICANCE_LEVEL,
            "p_value {p_value} should be <= {SIGNIFICANCE_LEVEL} after conditioning"
        );
        assert!(
            coefficient.abs() < DELTA_THRESHOLD,
            "coefficient {coefficient} should be near 0 after conditioning"
        );
    }

    #[test]
    fn conditional_boolean_accepts_independent() {
        let (x, y, array) = common_cause_triple();

        let result = pearson_boolean().run_test(x, y, array).unwrap();
        assert!(
            expect_boolean(result),
            "Conditionally independent data should return true"
        );
    }

    #[test]
    fn conditional_dependent_data_is_rejected() {
        let (x, y, array) = collider_triple();

        let result = pearson().run_test(x, y, array).unwrap();
        let (p_value, coefficient) = expect_p_value(result);
        assert!(
            p_value >= SIGNIFICANCE_LEVEL,
            "p_value {p_value} should be >= {SIGNIFICANCE_LEVEL} for v-structure"
        );
        assert!(
            coefficient.abs() > 0.9,
            "coefficient {coefficient} should be high for v-structure"
        );
    }

    #[test]
    fn conditional_boolean_rejects_dependent() {
        let (x, y, array) = collider_triple();

        let result = pearson_boolean().run_test(x, y, array).unwrap();
        assert!(
            !expect_boolean(result),
            "V-structure conditioned on collider should return false"
        );
    }

    #[test]
    fn conditional_multiple_vars_independent_is_not_rejected() {
        let mut rng = seeded_rng();
        let z_1 = gen_normal(N, 0.0, 1.0, &mut rng);
        let z_2 = gen_normal(N, 0.0, 1.0, &mut rng);
        let z_3 = gen_normal(N, 0.0, 1.0, &mut rng);
        let noise_x = gen_normal(N, 0.0, 0.1, &mut rng);
        let noise_y = gen_normal(N, 0.0, 0.1, &mut rng);
        let x = 0.5 * &z_1 + 0.5 * &z_2 + 0.5 * &z_3 + &noise_x;
        let y = 0.5 * &z_1 + 0.5 * &z_2 + 0.5 * &z_3 + &noise_y;

        let array = stack(Axis(1), &[z_1.view(), z_2.view(), z_3.view()]).unwrap();

        let result = pearson().run_test(x, y, array).unwrap();
        let (p_value, coefficient) = expect_p_value(result);
        assert!(
            p_value < SIGNIFICANCE_LEVEL,
            "p_value {p_value} should be < {SIGNIFICANCE_LEVEL} after conditioning on all confounders"
        );
        assert!(
            coefficient.abs() <= DELTA_THRESHOLD,
            "coefficient {coefficient} should be near 0 after conditioning on all confounders"
        );
    }
}