Skip to content

Commit

Permalink
make centre work with nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
hanbin973 committed Oct 11, 2024
1 parent 45ac61e commit 2cdb9dd
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -8666,26 +8666,13 @@ def _rand_svd(
U = Q @ U_hat
return U[:,:rank], D[:rank], V[:rank]

def _genetic_relatedness_vector(
def _genetic_relatedness_vector_individual(
arr: np.ndarray,
rows: np.ndarray,
cols: np.ndarray,
centre: bool = True,
windows = None,
) -> np.ndarray:
"""
Wrapper around `tskit.TreeSequence.genetic_relatedness_vector` to support centering in respect to individuals.
Multiplies an array to the genetic relatedness matrix of :class:`tskit.TreeSequence`.
:param numpy.ndarray arr: The array to multiply. Either a vector or a matrix.
:param numpy.ndarray rows: Index of rows of the genetic relatedness matrix to be selected.
:param numpy.ndarray cols: Index of cols of the genetic relatedness matrix to be selected. The size should match the row length of `arr`.
:param bool centre: Centre the genetic relatedness matrix. Centering happens respect to the `rows` and `cols`.
:param windows: An increasing list of breakpoints between the windows to compute the genetic relatedness matrix in.
:return: An array that is the matrix-array product of the genetic relatedness matrix and the array.
:rtype: `np.ndarray`
"""

assert cols.size == arr.shape[0], "Dimension mismatch"
ij = np.vstack([[n,k] for k, i in enumerate(individuals) for n in self.individual(i).nodes])
samples, sample_individuals = ij[:,0], ij[:,1] # sample node index, individual of those nodes
Expand All @@ -8697,6 +8684,20 @@ def _genetic_relatedness_vector(

return x

def _genetic_relatedness_vector_node(
arr: np.ndarray,
rows: np.ndarray,
cols: np.ndarray,
centre: bool = True,
windows = None,
) -> np.ndarray:
assert cols.size == arr.shape[0], "Dimension mismatch"
x = arr - arr.mean(axis=0) if centre else arr
x = self.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False, nodes=cols)[0]
x = x - x.mean(axis=0) if centre else x

return x

random_state = np.random.default_rng(random_seed)
if samples is None and individuals is None: samples = self.samples()

Expand All @@ -8717,10 +8718,10 @@ def _genetic_relatedness_vector(
D = np.empty((len(windows)-1, n_components))
for i in range(len(windows)-1):
if mode == 'node':
_G = lambda x: self.genetic_relatedness_vector(
x, windows=windows[i:i+2], mode="branch", centre=centre, nodes=samples)[0]
_G = lambda x: _genetic_relatedness_vector_node(
x, samples, samples, centre=centre, windows=windows[i:i+2])
elif mode == 'individual':
_G = lambda x: _genetic_relatedness_vector(
_G = lambda x: _genetic_relatedness_vector_individual(
x, individuals, individuals, centre=centre, windows=windows[i:i+2])
U[i], D[i], _ = _rand_svd(
operator=_G,
Expand Down

0 comments on commit 2cdb9dd

Please sign in to comment.