From c93a35cbcdde52f40506eb2cb3f323f8cb34c674 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Tue, 25 Jul 2023 12:57:13 +0200 Subject: [PATCH] Add more docs --- src/tad_dftd3/util/misc.py | 39 +++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/tad_dftd3/util/misc.py b/src/tad_dftd3/util/misc.py index ffb76c9..336c629 100644 --- a/src/tad_dftd3/util/misc.py +++ b/src/tad_dftd3/util/misc.py @@ -27,10 +27,42 @@ def real_atoms(numbers: Tensor) -> Tensor: + """ + Create a mask for atoms, discerning padding and actual atoms. + Padding value is zero. + + Parameters + ---------- + numbers : Tensor + Atomic numbers for all atoms. + + Returns + ------- + Tensor + Mask for atoms that discerns padding and real atoms. + """ return numbers != 0 def real_pairs(numbers: Tensor, diagonal: bool = False) -> Tensor: + """ + Create a mask for pairs of atoms from atomic numbers, discerning padding + and actual atoms. Padding value is zero. + + Parameters + ---------- + numbers : Tensor + Atomic numbers for all atoms. + diagonal : bool, optional + Flag for also writing `False` to the diagonal, i.e., to all pairs + with the same indices. Defaults to `False`, i.e., writing False + to the diagonal. + + Returns + ------- + Tensor + Mask for atom pairs that discerns padding and real atoms. + """ real = real_atoms(numbers) mask = real.unsqueeze(-2) * real.unsqueeze(-1) if diagonal is False: @@ -42,7 +74,7 @@ def real_triples( numbers: torch.Tensor, diagonal: bool = False, self: bool = True ) -> Tensor: """ - Create a mask for triples from atomic numbers. + Create a mask for triples from atomic numbers. Padding value is zero. Parameters ---------- @@ -53,12 +85,13 @@ def real_triples( triples with the same indices. Defaults to `False`, i.e., writing False to the diagonal. self : bool, optional - Flag for also writing `False` to all triples where at least two indices are identical. Defaults to `True`, i.e., not writing `False`. + Flag for also writing `False` to all triples where at least two indices + are identical. Defaults to `True`, i.e., not writing `False`. Returns ------- Tensor - Mask. + Mask for triples. """ real = real_pairs(numbers, diagonal=True) mask = real.unsqueeze(-3) * real.unsqueeze(-2) * real.unsqueeze(-1)