Skip to content

Commit

Permalink
Add more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jul 25, 2023
1 parent 5c61565 commit c93a35c
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions src/tad_dftd3/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,42 @@


def real_atoms(numbers: Tensor) -> Tensor:
"""
Create a mask for atoms, discerning padding and actual atoms.
Padding value is zero.
Parameters
----------
numbers : Tensor
Atomic numbers for all atoms.
Returns
-------
Tensor
Mask for atoms that discerns padding and real atoms.
"""
return numbers != 0


def real_pairs(numbers: Tensor, diagonal: bool = False) -> Tensor:
"""
Create a mask for pairs of atoms from atomic numbers, discerning padding
and actual atoms. Padding value is zero.
Parameters
----------
numbers : Tensor
Atomic numbers for all atoms.
diagonal : bool, optional
Flag for also writing `False` to the diagonal, i.e., to all pairs
with the same indices. Defaults to `False`, i.e., writing False
to the diagonal.
Returns
-------
Tensor
Mask for atom pairs that discerns padding and real atoms.
"""
real = real_atoms(numbers)
mask = real.unsqueeze(-2) * real.unsqueeze(-1)
if diagonal is False:
Expand All @@ -42,7 +74,7 @@ def real_triples(
numbers: torch.Tensor, diagonal: bool = False, self: bool = True
) -> Tensor:
"""
Create a mask for triples from atomic numbers.
Create a mask for triples from atomic numbers. Padding value is zero.
Parameters
----------
Expand All @@ -53,12 +85,13 @@ def real_triples(
triples with the same indices. Defaults to `False`, i.e., writing False
to the diagonal.
self : bool, optional
Flag for also writing `False` to all triples where at least two indices are identical. Defaults to `True`, i.e., not writing `False`.
Flag for also writing `False` to all triples where at least two indices
are identical. Defaults to `True`, i.e., not writing `False`.
Returns
-------
Tensor
Mask.
Mask for triples.
"""
real = real_pairs(numbers, diagonal=True)
mask = real.unsqueeze(-3) * real.unsqueeze(-2) * real.unsqueeze(-1)
Expand Down

0 comments on commit c93a35c

Please sign in to comment.