Skip to content

Commit

Permalink
Removed explicit dependency on device.
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Jul 2, 2024
1 parent b776beb commit d58569b
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions dect/directions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch


def generate_uniform_directions(num_thetas: int = 64, d: int = 3, device: str = "cpu"):
def generate_uniform_directions(num_thetas: int = 64, d: int = 3):
"""
Generate randomly sampled directions from a sphere in d dimensions.
Expand All @@ -22,15 +22,13 @@ def generate_uniform_directions(num_thetas: int = 64, d: int = 3, device: str =
The number of directions to generate.
d: int
The dimension of the unit sphere. Default is 3 (hence R^3)
device: str
The device to put the tensor on.
"""
v = torch.randn(size=(d, num_thetas), device=device)
v = torch.randn(size=(d, num_thetas))
v /= v.pow(2).sum(axis=0).sqrt().unsqueeze(1)
return v


def generate_uniform_2d_directions(num_thetas: int = 64, device: str = "cpu"):
def generate_uniform_2d_directions(num_thetas: int = 64):
"""
Generate uniformly sampled directions on the unit circle in two dimensions.
Expand All @@ -44,13 +42,11 @@ def generate_uniform_2d_directions(num_thetas: int = 64, device: str = "cpu"):
The number of directions to generate.
d: int
The dimension of the unit sphere. Default is 3 (hence R^3)
device: str
The device to put the tensor on.
"""
v = torch.vstack(
[
torch.sin(torch.linspace(0, 2 * torch.pi, num_thetas, device=device)),
torch.cos(torch.linspace(0, 2 * torch.pi, num_thetas, device=device)),
torch.sin(torch.linspace(0, 2 * torch.pi, num_thetas)),
torch.cos(torch.linspace(0, 2 * torch.pi, num_thetas)),
]
)

Expand Down

0 comments on commit d58569b

Please sign in to comment.