1use crate::{
2 strategy::{CITest, CITestDataType, TestResult},
3 utils::EPS,
4};
5use anyhow::ensure;
6use nalgebra::{DMatrix, DVector};
7use ndarray::{Array1, Array2, ArrayView1};
8use statrs::distribution::{ContinuousCDF, StudentsT};
9
10const SVD_TOLERANCE: f64 = 1e-10;
11const MIN_SAMPLE_SIZE: usize = 3;
12
/// Conditional-independence test based on the Pearson correlation coefficient.
///
/// The two fields configure how the outcome of [`CITest::run_test`] is reported.
#[derive(Debug, Clone, PartialEq)]
pub struct PearsonCorrelation {
    /// When `true`, the test reports `TestResult::Boolean(p_value >= significance_level)`;
    /// when `false`, it reports `TestResult::PValue(p_value, coefficient)`.
    pub boolean: bool,
    /// Threshold the p-value is compared against in boolean mode.
    pub significance_level: f64,
}
28
impl PearsonCorrelation {
    /// Creates a test from its two settings.
    ///
    /// Both arguments are stored unchanged; `significance_level` is not validated here.
    #[must_use]
    pub fn new(boolean: bool, significance_level: f64) -> Self {
        Self {
            boolean,
            significance_level,
        }
    }
}
38
39impl CITest for PearsonCorrelation {
40 fn run_test(
41 &self,
42 x_values: Array1<f64>,
43 y_values: Array1<f64>,
44 z: Array2<f64>,
45 ) -> anyhow::Result<TestResult> {
46 if z.is_empty() {
47 let (coefficient, p_value) = pearsonr(&x_values.view(), &y_values.view())?;
48 Ok(wrap_result(
49 self.boolean,
50 p_value,
51 coefficient,
52 self.significance_level,
53 ))
54 } else {
55 let z_na = DMatrix::from_row_iterator(z.nrows(), z.ncols(), z.iter().copied());
56 let x_na = DVector::from_iterator(x_values.len(), x_values.iter().copied());
57 let y_na = DVector::from_iterator(y_values.len(), y_values.iter().copied());
58
59 let svd = z_na.svd(true, true);
60 let x_coefficient = svd
61 .solve(&x_na, SVD_TOLERANCE)
62 .map_err(|e| anyhow::anyhow!("least squares failed for x: {e}"))?;
63 let y_coefficient = svd
64 .solve(&y_na, SVD_TOLERANCE)
65 .map_err(|e| anyhow::anyhow!("least squares failed for y: {e}"))?;
66
67 let x_coef_nd = Array1::from_vec(x_coefficient.iter().copied().collect());
68 let y_coef_nd = Array1::from_vec(y_coefficient.iter().copied().collect());
69
70 let residual_x = x_values - z.dot(&x_coef_nd);
71 let residual_y = y_values - z.dot(&y_coef_nd);
72
73 let (coefficient, p_value) = pearsonr(&residual_x.view(), &residual_y.view())?;
74 Ok(wrap_result(
75 self.boolean,
76 p_value,
77 coefficient,
78 self.significance_level,
79 ))
80 }
81 }
82
83 fn data_types(&self) -> &'static [CITestDataType] {
84 &[CITestDataType::Continuous]
85 }
86}
87
88#[must_use]
89pub fn wrap_result(
90 boolean: bool,
91 p_value: f64,
92 coefficient: f64,
93 significance_level: f64,
94) -> TestResult {
95 if boolean {
96 return TestResult::Boolean(p_value >= significance_level);
97 }
98 TestResult::PValue(p_value, coefficient)
99}
100
101fn pearsonr(x_values: &ArrayView1<f64>, y_values: &ArrayView1<f64>) -> anyhow::Result<(f64, f64)> {
110 let n = x_values.len();
111 ensure!(
112 x_values.len() == y_values.len() && x_values.len() >= MIN_SAMPLE_SIZE,
113 "pearsonr requires equal-length inputs with n >= 3"
114 );
115
116 #[allow(
117 clippy::cast_precision_loss,
118 reason = "array length most likely won't exceed 2^53"
119 )]
120 let number_of_elements = n as f64;
121
122 let x_mean = x_values.sum() / number_of_elements;
123 let y_mean = y_values.sum() / number_of_elements;
124
125 let mut sum_sq_x = 0.0;
126 let mut sum_sq_y = 0.0;
127 let mut sum_coproduct = 0.0;
128
129 for (&x, &y) in x_values.iter().zip(y_values.iter()) {
130 let dx = x - x_mean;
131 let dy = y - y_mean;
132
133 sum_sq_x += dx * dx;
134 sum_sq_y += dy * dy;
135 sum_coproduct += dx * dy;
136 }
137
138 if sum_sq_x == 0.0 || sum_sq_y == 0.0 {
140 let array_name = if sum_sq_x == 0.0 { "x" } else { "y" };
141 panic!("Array {array_name} is constant, so the pearson coëfficient is undefined.");
142 }
143
144 let mut coefficient = sum_coproduct / (sum_sq_x * sum_sq_y).sqrt();
146
147 coefficient = coefficient.clamp(-1.0 + EPS, 1.0 - EPS);
150
151 let t_statistic =
152 coefficient * (number_of_elements - 2.0).sqrt() / (1.0 - coefficient.powi(2)).sqrt();
153
154 let t_distribution = StudentsT::new(0.0, 1.0, number_of_elements - 2.0)?;
155 let p_value = 2.0 * t_distribution.sf(t_statistic.abs());
156
157 Ok((coefficient, p_value))
158}
159
/// Unit tests for [`PearsonCorrelation`] and [`pearsonr`].
#[cfg(test)]
#[allow(
    clippy::many_single_char_names,
    reason = "x, y, z, p and r follow the usual statistical notation"
)]
mod tests {
    use super::*;
    use crate::utils::EPS;
    use ndarray::{array, Array1, Array2};

    const SIGNIFICANCE_LEVEL: f64 = 0.05;

    /// Number of samples in the eight-point conditional fixtures.
    const COND_SAMPLES: usize = 8;

    /// Test instance that reports `TestResult::PValue`.
    fn p_value_test() -> PearsonCorrelation {
        PearsonCorrelation::new(false, SIGNIFICANCE_LEVEL)
    }

    /// Test instance that reports `TestResult::Boolean`.
    fn boolean_test() -> PearsonCorrelation {
        PearsonCorrelation::new(true, SIGNIFICANCE_LEVEL)
    }

    /// Empty conditioning set, which selects the unconditional code path.
    fn no_conditioning() -> Array2<f64> {
        Array2::zeros((0, 0))
    }

    /// Single conditioning column holding `value` in every row.
    fn constant_column(value: f64) -> Array2<f64> {
        Array2::from_elem((COND_SAMPLES, 1), value)
    }

    /// Ten-point pair whose correlation the tests expect to be close to zero.
    fn weakly_related_pair() -> (Array1<f64>, Array1<f64>) {
        (
            array![2., 4., 1., 5., 3., 8., 6., 7., 9., 10.],
            array![5., 3., 7., 2., 8., 1., 9., 4., 6., 10.],
        )
    }

    /// Five-point pair with `y = 2 * x`.
    fn linear_pair() -> (Array1<f64>, Array1<f64>) {
        (array![1., 2., 3., 4., 5.], array![2., 4., 6., 8., 10.])
    }

    /// Eight-point pair where `y` is `x` reversed.
    fn reversed_pair() -> (Array1<f64>, Array1<f64>) {
        (
            array![1., 2., 3., 4., 5., 6., 7., 8.],
            array![8., 7., 6., 5., 4., 3., 2., 1.],
        )
    }

    /// `(x, y, z)` fixture whose residual correlation, after regressing on `z`, the tests
    /// expect to be close to zero.
    fn confounded_triple() -> (Array1<f64>, Array1<f64>, Array2<f64>) {
        let x = array![
            2.019_608, 1.039_216, 4.058_824, 3.078_431, 6.098_039, 5.117_647, 8.137_255, 7.156_863
        ];
        let y = array![
            2.059_406, 3.059_406, 2.138_614, 3.138_614, 6.217_822, 7.217_822, 6.297_030, 7.297_030
        ];
        let z = array![[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]];
        (x, y, z)
    }

    /// Extracts `(p_value, coefficient)` from a result, panicking on any other variant.
    fn unwrap_correlated(r: &TestResult) -> (f64, f64) {
        let TestResult::PValue(p, coef) = r else {
            panic!("expected TestResult::PValue");
        };
        (*p, *coef)
    }

    #[test]
    fn uncond_independent_data_accepted() {
        let (x, y) = weakly_related_pair();

        let result = p_value_test().run_test(x, y, no_conditioning()).unwrap();
        let (p, coef) = unwrap_correlated(&result);
        assert!(p > SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() < 0.1,
            "coef={coef} should be near 0 for uncorrelated data"
        );
    }

    #[test]
    fn uncond_boolean_mode() {
        let t = boolean_test();

        let (x, y) = weakly_related_pair();
        let r = t.run_test(x, y, no_conditioning()).unwrap();
        assert!(matches!(r, TestResult::Boolean(true)));

        let (x, y) = linear_pair();
        let r = t.run_test(x, y, no_conditioning()).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    #[test]
    fn uncond_dependent_data_rejected() {
        let (x, y) = linear_pair();

        let result = p_value_test().run_test(x, y, no_conditioning()).unwrap();
        let (p, coef) = unwrap_correlated(&result);
        assert!(p < SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() > 0.9,
            "coef={coef} should be high for perfectly correlated data"
        );
    }

    #[test]
    fn cond_independent_data_accepted() {
        let (x, y, z) = confounded_triple();

        let result = p_value_test().run_test(x, y, z).unwrap();
        let (p, coef) = unwrap_correlated(&result);
        assert!(p > SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() < 0.1,
            "coef={coef} should be near 0 after conditioning"
        );
    }

    #[test]
    fn cond_boolean_mode() {
        let t = boolean_test();

        let (x, y, z) = confounded_triple();
        let r = t.run_test(x, y, z).unwrap();
        assert!(matches!(r, TestResult::Boolean(true)));

        let (x, y) = reversed_pair();
        let r = t.run_test(x, y, constant_column(4.5)).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    #[test]
    fn cond_dependent_data_rejected() {
        let (x, y) = reversed_pair();

        let result = p_value_test()
            .run_test(x, y, constant_column(9.))
            .unwrap();
        let (p, coef) = unwrap_correlated(&result);
        assert!(p < SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() > 0.9,
            "coef={coef} should be high for collider structure"
        );
    }

    #[test]
    fn cond_bool_rejects_dependent() {
        let (x, y) = reversed_pair();

        let r = boolean_test()
            .run_test(x, y, constant_column(9.))
            .unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    #[test]
    fn cond_multiple_vars_independent_accepted() {
        let x = array![2.5, 2.5, 2.5, 4.0, 4.0, 4.0, 4.0, 5.5];
        let y = array![2.4, 2.4, 0.8, 2.8, 5.2, 5.2, 3.6, 5.6];
        let z = array![
            [1., 0., 1.],
            [1., 1., 0.],
            [2., 0., 0.],
            [2., 1., 1.],
            [3., 0., 1.],
            [3., 1., 0.],
            [4., 0., 0.],
            [4., 1., 1.],
        ];

        let result = p_value_test().run_test(x, y, z).unwrap();
        let (p, coef) = unwrap_correlated(&result);
        assert!(p > SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() < 0.1,
            "coef={coef} should be near 0 after conditioning on all confounders"
        );
    }

    #[test]
    fn pearsonr_errors_on_empty_input() {
        let x: Array1<f64> = Array1::zeros(0);
        let y: Array1<f64> = Array1::zeros(0);
        assert!(pearsonr(&x.view(), &y.view()).is_err());
    }

    #[test]
    fn pearsonr_errors_on_too_few_elements() {
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 4.0]);
        assert!(pearsonr(&x.view(), &y.view()).is_err());
    }

    #[test]
    fn pearsonr_errors_on_mismatched_lengths() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![1.0, 2.0]);
        assert!(pearsonr(&x.view(), &y.view()).is_err());
    }

    #[test]
    fn pearsonr_succeeds_with_minimum_input() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let (coef, p_value) = pearsonr(&x.view(), &y.view()).unwrap();
        // `pearsonr` clamps a perfect correlation to `1.0 - EPS`, so the distance to 1 sits
        // on `EPS` itself; a strict `< EPS` would hinge on floating-point rounding.
        assert!(
            (coef - 1.0).abs() <= 2.0 * EPS,
            "perfect positive correlation, got {coef}"
        );
        assert!(p_value < SIGNIFICANCE_LEVEL, "got {p_value}");
    }
}