From 3565e3c9d71f93432d75f785ae2f814c9ce12c6d Mon Sep 17 00:00:00 2001 From: muhd-umer Date: Sun, 14 Apr 2024 23:15:53 +0500 Subject: [PATCH] Fix Rician arguments --- comyx/network/links.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/comyx/network/links.py b/comyx/network/links.py index 1cf8974..662bdea 100644 --- a/comyx/network/links.py +++ b/comyx/network/links.py @@ -8,14 +8,16 @@ from ..fading import get_rvs from ..propagation import get_pathloss from ..utils import db2pow, ensure_list, get_distance +from .ris import RIS, STAR_RIS if TYPE_CHECKING: - from .ris import RIS from .transceiver import Transceiver NDArrayFloat = npt.NDArray[np.floating[Any]] NDArrayComplex = npt.NDArray[np.complexfloating[Any, Any]] +EPSILON = np.finfo(float).eps + class Link: r"""Represents a link in the modelled environment. @@ -129,7 +131,7 @@ def generate_rvs( Not private to allow for the generation of new channel gains for more flexible simulations. """ - if custom_rvs is None: + if custom_rvs is None and self._rician_args is None: self.rvs = get_rvs(self.shape, **self._fading_args, seed=seed) elif self._rician_args is not None: self.rvs = self.rician_fading(**self._rician_args) @@ -207,12 +209,12 @@ def rician_fading( los = [] if order == "post": assert isinstance( - self.rx, RIS + self.rx, (RIS, STAR_RIS) ), "The receiver must be an RIS for the post-order Rician fading." n_elements = self.rx.n_elements elif order == "pre": assert isinstance( - self.tx, RIS + self.tx, (RIS, STAR_RIS) ), "The transmitter must be an RIS for the pre-order Rician fading." n_elements = self.tx.n_elements else: @@ -226,15 +228,15 @@ def rician_fading( * np.pi * (self.rx.position[1] - self.tx.position[1]) / ( - np.sqrt( - (self.rx.position[0] - self.tx.position[0]) ** 2 - + (self.rx.position[1] - self.tx.position[1]) ** 2 + np.sqrt( # EPSILON is a small value to avoid division by zero + (self.rx.position[0] - self.tx.position[0] + EPSILON) ** 2 + + (self.rx.position[1] - self.tx.position[1] + EPSILON) ** 2 ) ) ) ) - los = np.array(los).reshape(self.shape) + los = np.array(np.repeat(los, self.shape[-1])).reshape(self.shape) nlos = get_rvs(self.shape, **self._fading_args, seed=self.seed) rvs = los * (np.sqrt(K / (K + 1))) + nlos * (1 / (np.sqrt(K + 1)))