Skip to main content

ci_core/ci_tests/
pearson_correlation.rs

1use crate::{
2    strategy::{CITest, CITestDataType, TestResult},
3    utils::EPS,
4};
5use anyhow::ensure;
6use nalgebra::{DMatrix, DVector};
7use ndarray::{Array1, Array2, ArrayView1};
8use statrs::distribution::{ContinuousCDF, StudentsT};
9
/// Tolerance passed as `eps` to the SVD least-squares solve of the conditioning regression.
const SVD_TOLERANCE: f64 = 1.0e-10;
/// Smallest sample count accepted by `pearsonr`; the t-test then has n − 2 ≥ 1 degrees of freedom.
const MIN_SAMPLE_SIZE: usize = 3;
12
/// Conditional independence test based on Pearson correlation.
///
/// Intended for continuous data only. With an empty conditioning set this is the
/// ordinary Pearson correlation test. Otherwise both variables are regressed on the
/// conditioning set and the Pearson correlation of the residuals (the partial
/// correlation) is tested.
///
/// # References
///
/// - [Pearson correlation coefficient](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
/// - [Partial correlation using linear regression](https://en.wikipedia.org/wiki/Partial_correlation#Using_linear_regression)
#[derive(Debug, Clone, PartialEq)]
pub struct PearsonCorrelation {
    /// When `true`, report a yes/no verdict (`p_value >= significance_level`)
    /// instead of the raw p-value and coefficient.
    pub boolean: bool,
    /// Threshold the p-value is compared against in boolean mode.
    pub significance_level: f64,
}

impl PearsonCorrelation {
    /// Create a test with the given output mode and significance level.
    ///
    /// `boolean` selects the yes/no verdict instead of the raw p-value;
    /// `significance_level` is the threshold used for that verdict.
    #[must_use]
    pub fn new(boolean: bool, significance_level: f64) -> Self {
        Self {
            significance_level,
            boolean,
        }
    }
}
38
39impl CITest for PearsonCorrelation {
40    fn run_test(
41        &self,
42        x_values: Array1<f64>,
43        y_values: Array1<f64>,
44        z: Array2<f64>,
45    ) -> anyhow::Result<TestResult> {
46        if z.is_empty() {
47            let (coefficient, p_value) = pearsonr(&x_values.view(), &y_values.view())?;
48            Ok(wrap_result(
49                self.boolean,
50                p_value,
51                coefficient,
52                self.significance_level,
53            ))
54        } else {
55            let z_na = DMatrix::from_row_iterator(z.nrows(), z.ncols(), z.iter().copied());
56            let x_na = DVector::from_iterator(x_values.len(), x_values.iter().copied());
57            let y_na = DVector::from_iterator(y_values.len(), y_values.iter().copied());
58
59            let svd = z_na.svd(true, true);
60            let x_coefficient = svd
61                .solve(&x_na, SVD_TOLERANCE)
62                .map_err(|e| anyhow::anyhow!("least squares failed for x: {e}"))?;
63            let y_coefficient = svd
64                .solve(&y_na, SVD_TOLERANCE)
65                .map_err(|e| anyhow::anyhow!("least squares failed for y: {e}"))?;
66
67            let x_coef_nd = Array1::from_vec(x_coefficient.iter().copied().collect());
68            let y_coef_nd = Array1::from_vec(y_coefficient.iter().copied().collect());
69
70            let residual_x = x_values - z.dot(&x_coef_nd);
71            let residual_y = y_values - z.dot(&y_coef_nd);
72
73            let (coefficient, p_value) = pearsonr(&residual_x.view(), &residual_y.view())?;
74            Ok(wrap_result(
75                self.boolean,
76                p_value,
77                coefficient,
78                self.significance_level,
79            ))
80        }
81    }
82
83    fn data_types(&self) -> &'static [CITestDataType] {
84        &[CITestDataType::Continuous]
85    }
86}
87
88#[must_use]
89pub fn wrap_result(
90    boolean: bool,
91    p_value: f64,
92    coefficient: f64,
93    significance_level: f64,
94) -> TestResult {
95    if boolean {
96        return TestResult::Boolean(p_value >= significance_level);
97    }
98    TestResult::PValue(p_value, coefficient)
99}
100
101/// Compute the Pearson correlation coefficient and its two-tailed p-value.
102///
103/// Tests H₀: ρ = 0 using the t-distribution with n − 2 degrees of freedom.
104/// Returns `(coefficient, p_value)`.
105///
106/// # Errors
107///
108/// Returns an error if the input has fewer than 3 elements (degrees of freedom < 1).
109fn pearsonr(x_values: &ArrayView1<f64>, y_values: &ArrayView1<f64>) -> anyhow::Result<(f64, f64)> {
110    let n = x_values.len();
111    ensure!(
112        x_values.len() == y_values.len() && x_values.len() >= MIN_SAMPLE_SIZE,
113        "pearsonr requires equal-length inputs with n >= 3"
114    );
115
116    #[allow(
117        clippy::cast_precision_loss,
118        reason = "array length most likely won't exceed 2^53"
119    )]
120    let number_of_elements = n as f64;
121
122    let x_mean = x_values.sum() / number_of_elements;
123    let y_mean = y_values.sum() / number_of_elements;
124
125    let mut sum_sq_x = 0.0;
126    let mut sum_sq_y = 0.0;
127    let mut sum_coproduct = 0.0;
128
129    for (&x, &y) in x_values.iter().zip(y_values.iter()) {
130        let dx = x - x_mean;
131        let dy = y - y_mean;
132
133        sum_sq_x += dx * dx;
134        sum_sq_y += dy * dy;
135        sum_coproduct += dx * dy;
136    }
137
138    // If one of the datasets is constant, pearson coefficient is undefined.
139    if sum_sq_x == 0.0 || sum_sq_y == 0.0 {
140        let array_name = if sum_sq_x == 0.0 { "x" } else { "y" };
141        panic!("Array {array_name} is constant, so the pearson coëfficient is undefined.");
142    }
143
144    // Calculate correlation directly
145    let mut coefficient = sum_coproduct / (sum_sq_x * sum_sq_y).sqrt();
146
147    // Floating-point math can sometimes drift slightly outside (-1.0, 1.0), so then we clamp.
148    // By adding/subtracting EPS we prevent divide by 0 errors.
149    coefficient = coefficient.clamp(-1.0 + EPS, 1.0 - EPS);
150
151    let t_statistic =
152        coefficient * (number_of_elements - 2.0).sqrt() / (1.0 - coefficient.powi(2)).sqrt();
153
154    let t_distribution = StudentsT::new(0.0, 1.0, number_of_elements - 2.0)?;
155    let p_value = 2.0 * t_distribution.sf(t_statistic.abs());
156
157    Ok((coefficient, p_value))
158}
159
#[cfg(test)]
#[allow(clippy::many_single_char_names)]
mod tests {
    use super::*;
    use crate::utils::EPS;
    use ndarray::{array, Array1, Array2};

    const SIGNIFICANCE_LEVEL: f64 = 0.05;

    /// Extract `(p_value, coefficient)` from a non-boolean test result.
    fn unwrap_correlated(r: &TestResult) -> (f64, f64) {
        match r {
            TestResult::PValue(p, coef) => (*p, *coef),
            _ => panic!("expected TestResult::PValue"),
        }
    }

    #[test]
    fn uncond_independent_data_accepted() {
        let t = PearsonCorrelation {
            boolean: false,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![2., 4., 1., 5., 3., 8., 6., 7., 9., 10.];
        let y = array![5., 3., 7., 2., 8., 1., 9., 4., 6., 10.];

        let (p, coef) = unwrap_correlated(&t.run_test(x, y, Array2::zeros((0, 0))).unwrap());
        assert!(p > SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() < 0.1,
            "coef={coef} should be near 0 for uncorrelated data"
        );
    }

    #[test]
    fn uncond_boolean_mode() {
        let t = PearsonCorrelation {
            boolean: true,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        // independent -> true
        let x = array![2., 4., 1., 5., 3., 8., 6., 7., 9., 10.];
        let y = array![5., 3., 7., 2., 8., 1., 9., 4., 6., 10.];
        let r = t.run_test(x, y, Array2::zeros((0, 0))).unwrap();
        assert!(matches!(r, TestResult::Boolean(true)));

        // dependent -> false
        let x = array![1., 2., 3., 4., 5.];
        let y = array![2., 4., 6., 8., 10.];
        let r = t.run_test(x, y, Array2::zeros((0, 0))).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    #[test]
    fn uncond_dependent_data_rejected() {
        let t = PearsonCorrelation {
            boolean: false,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![1., 2., 3., 4., 5.];
        let y = array![2., 4., 6., 8., 10.];

        let (p, coef) = unwrap_correlated(&t.run_test(x, y, Array2::zeros((0, 0))).unwrap());
        assert!(p < SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() > 0.9,
            "coef={coef} should be high for perfectly correlated data"
        );
    }

    // Z is a common cause: X and Y both increase with Z. After regressing each on Z,
    // the residuals are (approximately) uncorrelated.
    #[test]
    fn cond_independent_data_accepted() {
        let t = PearsonCorrelation {
            boolean: false,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![
            2.019_608, 1.039_216, 4.058_824, 3.078_431, 6.098_039, 5.117_647, 8.137_255, 7.156_863
        ];
        let y = array![
            2.059_406, 3.059_406, 2.138_614, 3.138_614, 6.217_822, 7.217_822, 6.297_030, 7.297_030
        ];
        let z = array![[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]];
        let (p, coef) = unwrap_correlated(&t.run_test(x, y, z).unwrap());
        assert!(p > SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() < 0.1,
            "coef={coef} should be near 0 after conditioning"
        );
    }

    #[test]
    fn cond_boolean_mode() {
        // accepted
        let t = PearsonCorrelation {
            boolean: true,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![
            2.019_608, 1.039_216, 4.058_824, 3.078_431, 6.098_039, 5.117_647, 8.137_255, 7.156_863
        ];
        let y = array![
            2.059_406, 3.059_406, 2.138_614, 3.138_614, 6.217_822, 7.217_822, 6.297_030, 7.297_030
        ];
        let z = array![[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]];
        let r = t.run_test(x, y, z).unwrap();
        assert!(matches!(r, TestResult::Boolean(true)));

        // rejected: X and Y are perfectly anti-correlated, and regressing on a constant
        // column only removes their means.
        let x = array![1., 2., 3., 4., 5., 6., 7., 8.];
        let y = array![8., 7., 6., 5., 4., 3., 2., 1.];
        let z = array![[4.5], [4.5], [4.5], [4.5], [4.5], [4.5], [4.5], [4.5]];
        let r = t.run_test(x, y, z).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    // X + Y == 9 on every row, so X and Y are perfectly anti-correlated. Z is that
    // constant column; regressing on it only subtracts the means, so the residuals stay
    // perfectly anti-correlated and independence is rejected.
    #[test]
    fn cond_dependent_data_rejected() {
        let t = PearsonCorrelation {
            boolean: false,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![1., 2., 3., 4., 5., 6., 7., 8.];
        let y = array![8., 7., 6., 5., 4., 3., 2., 1.];
        let z = array![[9.], [9.], [9.], [9.], [9.], [9.], [9.], [9.]];

        let (p, coef) = unwrap_correlated(&t.run_test(x, y, z).unwrap());
        assert!(p < SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() > 0.9,
            "coef={coef} should be high for collider structure"
        );
    }

    #[test]
    fn cond_bool_rejects_dependent() {
        let t = PearsonCorrelation {
            boolean: true,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![1., 2., 3., 4., 5., 6., 7., 8.];
        let y = array![8., 7., 6., 5., 4., 3., 2., 1.];
        let z = array![[9.], [9.], [9.], [9.], [9.], [9.], [9.], [9.]];

        let r = t.run_test(x, y, z).unwrap();
        assert!(matches!(r, TestResult::Boolean(false)));
    }

    #[test]
    fn cond_multiple_vars_independent_accepted() {
        let t = PearsonCorrelation {
            boolean: false,
            significance_level: SIGNIFICANCE_LEVEL,
        };
        let x = array![2.5, 2.5, 2.5, 4.0, 4.0, 4.0, 4.0, 5.5];
        let y = array![2.4, 2.4, 0.8, 2.8, 5.2, 5.2, 3.6, 5.6];
        let z = array![
            [1., 0., 1.],
            [1., 1., 0.],
            [2., 0., 0.],
            [2., 1., 1.],
            [3., 0., 1.],
            [3., 1., 0.],
            [4., 0., 0.],
            [4., 1., 1.],
        ];

        let (p, coef) = unwrap_correlated(&t.run_test(x, y, z).unwrap());
        assert!(p > SIGNIFICANCE_LEVEL, "got {p}");
        assert!(
            coef.abs() < 0.1,
            "coef={coef} should be near 0 after conditioning on all confounders"
        );
    }

    #[test]
    fn pearsonr_errors_on_empty_input() {
        let x: Array1<f64> = Array1::zeros(0);
        let y: Array1<f64> = Array1::zeros(0);
        assert!(pearsonr(&x.view(), &y.view()).is_err());
    }

    #[test]
    fn pearsonr_errors_on_too_few_elements() {
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 4.0]);
        assert!(pearsonr(&x.view(), &y.view()).is_err());
    }

    #[test]
    fn pearsonr_errors_on_mismatched_lengths() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![1.0, 2.0]);
        assert!(pearsonr(&x.view(), &y.view()).is_err());
    }

    #[test]
    fn pearsonr_succeeds_with_minimum_input() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let (coef, p_value) = pearsonr(&x.view(), &y.view()).unwrap();
        // `pearsonr` clamps the coefficient to `1 - EPS`, so its distance from 1 is EPS up
        // to rounding; allow slack instead of comparing exactly at that boundary.
        assert!((coef - 1.0).abs() <= 2.0 * EPS, "perfect positive correlation");
        assert!(p_value < SIGNIFICANCE_LEVEL, "got {p_value}");
    }
}