# Source code for bandhic._test_utils
# -*- coding: utf-8 -*-
# _test_utils.py
"""
_test_utils: Utility functions for testing the BandHiC package.
Author: Weibing Wang
Date: 2025-06-11
Email: wangweibing@xidian.edu.cn
This module provides functions to compare two band_hic_matrix objects and assert their equality.
"""
from .bandhic import band_hic_matrix
import numpy as np
# Public names of this module: the two comparison helpers defined below.
# This list is what ``from <module> import *`` exports.
__all__ = [
"matrix_equal",
"assert_band_matrix_equal",
]
# [docs]
def matrix_equal(a: band_hic_matrix, b: band_hic_matrix) -> bool:
    """
    Check if two band_hic_matrix objects are equal.

    The matrices are equal when ``shape``, ``diag_num``, ``dtype``,
    ``default_value``, ``mask``, ``mask_row_col`` and ``data`` all match.
    Every check is performed in sequence, so matching masks no longer
    short-circuit the remaining comparisons.

    Parameters
    ----------
    a : band_hic_matrix
        First band_hic_matrix object.
    b : band_hic_matrix
        Second band_hic_matrix object.

    Returns
    -------
    bool
        True if the matrices are equal, False otherwise. False is also
        returned when either argument is not a band_hic_matrix.
    """
    if not (isinstance(a, band_hic_matrix) and isinstance(b, band_hic_matrix)):
        return False

    # Cheap scalar metadata first, so large arrays are only touched when
    # the matrices could actually be equal.
    if a.shape != b.shape:
        return False
    if a.diag_num != b.diag_num:
        return False
    if a.dtype != b.dtype:
        return False
    if a.default_value != b.default_value:
        return False

    # Optional arrays: both absent is a match, exactly one absent is a
    # mismatch, both present must be element-wise equal.
    optional_pairs = (
        (a.mask, b.mask),
        (a.mask_row_col, b.mask_row_col),
    )
    for left, right in optional_pairs:
        if (left is None) != (right is None):
            return False
        if left is not None and not np.array_equal(left, right):
            return False

    # The full stored arrays are compared. The previous ``data[mask]``
    # indexing was only ever reached with ``mask is None``, where it is a
    # whole-array comparison as well.
    # NOTE(review): if values at masked positions should be ignored, select
    # the unmasked entries here once the mask polarity is confirmed.
    return bool(np.array_equal(a.data, b.data))
# [docs]
def assert_band_matrix_equal(a: band_hic_matrix, b: band_hic_matrix) -> bool:
    """
    Assert that two band_hic_matrix objects are equal.

    The checks run in order: type, ``shape``, ``diag_num``, ``dtype``,
    ``default_value``, ``mask``, ``mask_row_col`` and finally ``data``.
    Every check is always performed; matching masks do not skip the
    row/column-mask and data comparisons.

    Parameters
    ----------
    a : band_hic_matrix
        First band_hic_matrix object.
    b : band_hic_matrix
        Second band_hic_matrix object.

    Returns
    -------
    bool
        Always True when no mismatch was found.

    Raises
    ------
    AssertionError
        If either input is not a band_hic_matrix, or if the two matrices
        are not equal. The message names the first attribute that differs.
    """
    if not (isinstance(a, band_hic_matrix) and isinstance(b, band_hic_matrix)):
        raise AssertionError(
            "Both inputs must be band_hic_matrix objects, got types: {} vs {}".format(
                type(a), type(b)
            )
        )

    if a.shape != b.shape:
        raise AssertionError(
            "Shapes do not match: {} vs {}".format(a.shape, b.shape)
        )
    if a.diag_num != b.diag_num:
        raise AssertionError(
            "Diagonal numbers do not match: {} vs {}".format(
                a.diag_num, b.diag_num
            )
        )
    if a.dtype != b.dtype:
        raise AssertionError(
            "Data types do not match: {} vs {}".format(a.dtype, b.dtype)
        )
    if a.default_value != b.default_value:
        raise AssertionError(
            "Default values do not match: {} vs {}".format(
                a.default_value, b.default_value
            )
        )

    # Element mask: exactly one side missing is a mismatch; when both are
    # present they must be element-wise equal.
    if (a.mask is None) != (b.mask is None):
        raise AssertionError(
            "One of the matrices has a mask while the other does not."
        )
    if a.mask is not None and not np.array_equal(a.mask, b.mask):
        raise AssertionError(
            "Masks do not match: {} vs {}".format(a.mask, b.mask)
        )

    # Row/column mask: same rule as the element mask.
    if (a.mask_row_col is None) != (b.mask_row_col is None):
        raise AssertionError(
            "One of the matrices has a row/column mask while the other does not."
        )
    if a.mask_row_col is not None and not np.array_equal(
        a.mask_row_col, b.mask_row_col
    ):
        raise AssertionError(
            "Row/column masks do not match: {} vs {}".format(
                a.mask_row_col, b.mask_row_col
            )
        )

    # Data is compared last and unconditionally, so matrices with
    # identical masks but different contents are reported.
    if not np.array_equal(a.data, b.data):
        raise AssertionError(
            "Data in band_hic_matrix objects do not match, matrices: {} vs {}".format(
                a, b
            )
        )

    return True