Skip to content

Commit

Permalink
function rearrangement
Browse files Browse the repository at this point in the history
  • Loading branch information
hanbin973 committed Oct 22, 2024
1 parent af16340 commit e81db15
Showing 1 changed file with 109 additions and 108 deletions.
217 changes: 109 additions & 108 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -8592,6 +8592,60 @@ def genetic_relatedness_vector(
)
return out

def _genetic_relatedness_vector_node(
self,
arr: np.ndarray,
indices: np.ndarray,
mode: str,
centre: bool = True,
windows = None,
) -> np.ndarray:
x = arr - arr.mean(axis=0) if centre else arr
x = self.genetic_relatedness_vector(
W=x, windows=windows, mode=mode, centre=False, nodes=indices,
)[0]
x = x - x.mean(axis=0) if centre else x

return x

def _genetic_relatedness_vector_individual(
self,
arr: np.ndarray,
indices: np.ndarray,
mode: str,
centre: bool = True,
windows = None,
) -> np.ndarray:
ij = np.vstack(
[
[n, k]
for k, i in enumerate(indices)
for n in self.individual(i).nodes
]
)
samples, sample_individuals = (
ij[:, 0],
ij[:, 1],
) # sample node index, individual of those nodes
x = (
arr - arr.mean(axis=0) if centre else arr
) # centering within index in rows
x = self.genetic_relatedness_vector(
W=x[sample_individuals],
windows=windows,
mode=mode,
centre=False,
nodes=samples,
)[0]

def bincount_fn(w):
return np.bincount(sample_individuals, w)

x = np.apply_along_axis(bincount_fn, axis=0, arr=x)
x = x - x.mean(axis=0) if centre else x # centering within index in cols

return x

def pca(
self,
num_components: int,
Expand Down Expand Up @@ -8669,6 +8723,58 @@ def pca(
"the number of samples (or individuals, if specified)."
)

def _rand_pow_range_finder(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> np.ndarray:
"""
Algorithm 9 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0, "num_vectors should be larger than rank"
if range_sketch is None:
test_vectors = rng.normal(size=(operator_dim, num_vectors))
Q = test_vectors
else:
Q = range_sketch
for _ in range(depth):
Q = np.linalg.qr(Q).Q
Q = operator(Q)
Q = np.linalg.qr(Q).Q
return Q[:, :rank]

def _rand_svd(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> (np.ndarray, np.ndarray, np.ndarray, float):
"""
Algorithm 8 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0
Q = _rand_pow_range_finder(
operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch
)
C = operator(Q).T
U_hat, D, V = np.linalg.svd(C, full_matrices=False)
U = Q @ U_hat

error_factor = np.power(
1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)),
1 / (2 * depth + 1)
)
error_bound = D[-1] * (1 + error_factor)
return U[:, :rank], D[:rank], V[:rank], Q, error_bound


random_state = np.random.default_rng(random_seed)
drop_windows = windows is None
windows = self.parse_windows(windows)
Expand All @@ -8683,17 +8789,17 @@ def pca(
for i in range(num_windows):
this_window = windows[i : i + 2]
_f = (
_genetic_relatedness_vector_node
self._genetic_relatedness_vector_node
if output_type == "node"
else _genetic_relatedness_vector_individual
else self._genetic_relatedness_vector_individual
)
indices = (
samples
if output_type == "node"
else individuals
)
def _G(x):
return _f(tree_sequence=self, arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023
return _f(arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023

U[i], D[i], _, Q[i], E[i] = _rand_svd(
operator=_G,
Expand Down Expand Up @@ -10292,111 +10398,6 @@ def write_ms(
else:
print(file=output)

def _rand_pow_range_finder(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> np.ndarray:
"""
Algorithm 9 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0, "num_vectors should be larger than rank"
if range_sketch is None:
test_vectors = rng.normal(size=(operator_dim, num_vectors))
Q = test_vectors
else:
Q = range_sketch
for _ in range(depth):
Q = np.linalg.qr(Q).Q
Q = operator(Q)
Q = np.linalg.qr(Q).Q
return Q[:, :rank]

def _rand_svd(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> (np.ndarray, np.ndarray, np.ndarray, float):
"""
Algorithm 8 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0
Q = _rand_pow_range_finder(
operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch
)
C = operator(Q).T
U_hat, D, V = np.linalg.svd(C, full_matrices=False)
U = Q @ U_hat

error_factor = np.power(
1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)),
1 / (2 * depth + 1)
)
error_bound = D[-1] * (1 + error_factor)
return U[:, :rank], D[:rank], V[:rank], Q, error_bound

def _genetic_relatedness_vector_individual(
tree_sequence: tskit.TreeSequence,
arr: np.ndarray,
indices: np.ndarray,
mode: str,
centre: bool = True,
windows = None,
) -> np.ndarray:
ij = np.vstack(
[
[n, k]
for k, i in enumerate(indices)
for n in tree_sequence.individual(i).nodes
]
)
samples, sample_individuals = (
ij[:, 0],
ij[:, 1],
) # sample node index, individual of those nodes
x = (
arr - arr.mean(axis=0) if centre else arr
) # centering within index in rows
x = tree_sequence.genetic_relatedness_vector(
W=x[sample_individuals],
windows=windows,
mode=mode,
centre=False,
nodes=samples,
)[0]

def bincount_fn(w):
return np.bincount(sample_individuals, w)

x = np.apply_along_axis(bincount_fn, axis=0, arr=x)
x = x - x.mean(axis=0) if centre else x # centering within index in cols

return x

def _genetic_relatedness_vector_node(
tree_sequence: tskit.TreeSequence,
arr: np.ndarray,
indices: np.ndarray,
mode: str,
centre: bool = True,
windows = None,
) -> np.ndarray:
x = arr - arr.mean(axis=0) if centre else arr
x = tree_sequence.genetic_relatedness_vector(
W=x, windows=windows, mode=mode, centre=False, nodes=indices,
)[0]
x = x - x.mean(axis=0) if centre else x

return x

@dataclass
class PCAResult:
"""
Expand Down

0 comments on commit e81db15

Please sign in to comment.