Skip to content

Commit

Permalink
Merge pull request #192 from f3ndot/fix/Issue-125
Browse files Browse the repository at this point in the history
Fixes #125

Increased thread-safety by using Lock context manager syntax
  • Loading branch information
tayler6000 authored Jan 3, 2024
2 parents 1b59e70 + fbe8024 commit c5f8a55
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 118 deletions.
10 changes: 5 additions & 5 deletions pyVoIP/RTP.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,10 @@ def read(self, length: int = 160) -> bytes:
# This acts functionally as a lock while the buffer is being rebuilt.
while self.rebuilding:
time.sleep(0.01)
self.bufferLock.acquire()
packet = self.buffer.read(length)
if len(packet) < length:
packet = packet + (b"\x80" * (length - len(packet)))
self.bufferLock.release()
with self.bufferLock:
packet = self.buffer.read(length)
if len(packet) < length:
packet = packet + (b"\x80" * (length - len(packet)))
return packet

def rebuild(self, reset: bool, offset: int = 0, data: bytes = b"") -> None:
Expand All @@ -192,6 +191,7 @@ def rebuild(self, reset: bool, offset: int = 0, data: bytes = b"") -> None:
self.rebuilding = False

def write(self, offset: int, data: bytes) -> None:
# TODO: Can this safely be changed to use context manager syntax?
self.bufferLock.acquire()
self.log[offset] = data
bufferloc = self.buffer.tell()
Expand Down
197 changes: 99 additions & 98 deletions pyVoIP/SIP.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum, IntEnum
from threading import Timer, Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from pyVoIP.util import acquired_lock_and_unblocked_socket
from pyVoIP.VoIP.status import PhoneStatus
import pyVoIP
import hashlib
Expand Down Expand Up @@ -40,6 +41,10 @@ class SIPParseError(Exception):
pass


class RetryRequiredError(Exception):
pass


class Counter:
def __init__(self, start: int = 1):
self.x = start
Expand Down Expand Up @@ -842,40 +847,41 @@ def __init__(
self.registerFailures = 0
self.recvLock = Lock()

def recv(self) -> None:
def recv_loop(self) -> None:
while self.NSD:
self.recvLock.acquire()
self.s.setblocking(False)
try:
raw = self.s.recv(8192)
if raw != b"\x00\x00\x00\x00":
try:
message = SIPMessage(raw)
debug(message.summary())
self.parseMessage(message)
except Exception as ex:
debug(f"Error on header parsing: {ex}")
with acquired_lock_and_unblocked_socket(self.recvLock, self.s):
self.recv()
except BlockingIOError:
self.s.setblocking(True)
self.recvLock.release()
time.sleep(0.01)
continue
except SIPParseError as e:
if "SIP Version" in str(e):
request = self.genSIPVersionNotSupported(message)
self.out.sendto(
request.encode("utf8"), (self.server, self.port)
)
else:
debug(f"SIPParseError in SIP.recv: {type(e)}, {e}")
except Exception as e:
debug(f"SIP.recv error: {type(e)}, {e}\n\n{str(raw, 'utf8')}")
if pyVoIP.DEBUG:
self.s.setblocking(True)
self.recvLock.release()
raise
self.s.setblocking(True)
self.recvLock.release()

def recv(self) -> None:
try:
raw = self.s.recv(8192)
if raw != b"\x00\x00\x00\x00":
try:
message = SIPMessage(raw)
debug(message.summary())
self.parseMessage(message)
except Exception as ex:
debug(f"Error on header parsing: {ex}")
except SIPParseError as e:
if "SIP Version" in str(e):
request = self.genSIPVersionNotSupported(message)
self.out.sendto(
request.encode("utf8"), (self.server, self.port)
)
else:
debug(f"SIPParseError in SIP.recv: {type(e)}, {e}")
except BlockingIOError:
# Re-raise BlockingIOError so recv_loop() can release locks and
# continue
raise
except Exception as e:
debug(f"SIP.recv error: {type(e)}, {e}\n\n{str(raw, 'utf8')}")
if pyVoIP.DEBUG:
raise

def parseMessage(self, message: SIPMessage) -> None:
warnings.warn(
Expand Down Expand Up @@ -955,7 +961,7 @@ def start(self) -> None:
self.s.bind((self.myIP, self.myPort))
self.out = self.s
self.register()
t = Timer(1, self.recv)
t = Timer(1, self.recv_loop)
t.name = "SIP Recieve"
t.start()

Expand Down Expand Up @@ -1596,52 +1602,49 @@ def invite(
invite = self.genInvite(
number, str(sess_id), ms, sendtype, branch, call_id
)
self.recvLock.acquire()
self.out.sendto(invite.encode("utf8"), (self.server, self.port))
debug("Invited")
response = SIPMessage(self.s.recv(8192))

while (
response.status != SIPStatus(401)
and response.status != SIPStatus(100)
and response.status != SIPStatus(180)
) or response.headers["Call-ID"] != call_id:
if not self.NSD:
break
self.parseMessage(response)
with self.recvLock:
self.out.sendto(invite.encode("utf8"), (self.server, self.port))
debug("Invited")
response = SIPMessage(self.s.recv(8192))

if response.status == SIPStatus(100) or response.status == SIPStatus(
180
):
self.recvLock.release()
return SIPMessage(invite.encode("utf8")), call_id, sess_id
debug(f"Received Response: {response.summary()}")
ack = self.genAck(response)
self.out.sendto(ack.encode("utf8"), (self.server, self.port))
debug("Acknowledged")
authhash = self.genAuthorization(response)
nonce = response.authentication["nonce"]
realm = response.authentication["realm"]
auth = (
f'Authorization: Digest username="{self.username}",realm='
+ f'"{realm}",nonce="{nonce}",uri="sip:{self.server};'
+ f'transport=UDP",response="{str(authhash, "utf8")}",'
+ "algorithm=MD5\r\n"
)

invite = self.genInvite(
number, str(sess_id), ms, sendtype, branch, call_id
)
invite = invite.replace(
"\r\nContent-Length", f"\r\n{auth}Content-Length"
)
while (
response.status != SIPStatus(401)
and response.status != SIPStatus(100)
and response.status != SIPStatus(180)
) or response.headers["Call-ID"] != call_id:
if not self.NSD:
break
self.parseMessage(response)
response = SIPMessage(self.s.recv(8192))

if response.status == SIPStatus(
100
) or response.status == SIPStatus(180):
return SIPMessage(invite.encode("utf8")), call_id, sess_id
debug(f"Received Response: {response.summary()}")
ack = self.genAck(response)
self.out.sendto(ack.encode("utf8"), (self.server, self.port))
debug("Acknowledged")
authhash = self.genAuthorization(response)
nonce = response.authentication["nonce"]
realm = response.authentication["realm"]
auth = (
f'Authorization: Digest username="{self.username}",realm='
+ f'"{realm}",nonce="{nonce}",uri="sip:{self.server};'
+ f'transport=UDP",response="{str(authhash, "utf8")}",'
+ "algorithm=MD5\r\n"
)

self.out.sendto(invite.encode("utf8"), (self.server, self.port))
invite = self.genInvite(
number, str(sess_id), ms, sendtype, branch, call_id
)
invite = invite.replace(
"\r\nContent-Length", f"\r\n{auth}Content-Length"
)

self.recvLock.release()
self.out.sendto(invite.encode("utf8"), (self.server, self.port))

return SIPMessage(invite.encode("utf8")), call_id, sess_id
return SIPMessage(invite.encode("utf8")), call_id, sess_id

def bye(self, request: SIPMessage) -> None:
message = self.genBye(request)
Expand All @@ -1650,7 +1653,8 @@ def bye(self, request: SIPMessage) -> None:

def deregister(self) -> bool:
try:
deregistered = self.__deregister()
with self.recvLock:
deregistered = self.__deregister()
if not deregistered:
debug("DEREGISTERATION FAILED")
return False
Expand All @@ -1660,12 +1664,16 @@ def deregister(self) -> bool:
return deregistered
except BaseException as e:
debug(f"DEREGISTERATION ERROR: {e}")
# TODO: a maximum tries check should be implemented otherwise a
# RecursionError will throw
if isinstance(e, RetryRequiredError):
time.sleep(5)
return self.deregister()
if type(e) is OSError:
raise
return False

def __deregister(self) -> bool:
self.recvLock.acquire()
self.phone._status = PhoneStatus.DEREGISTERING
firstRequest = self.genFirstRequest(deregister=True)
self.out.sendto(firstRequest.encode("utf8"), (self.server, self.port))
Expand All @@ -1676,7 +1684,6 @@ def __deregister(self) -> bool:
if ready[0]:
resp = self.s.recv(8192)
else:
self.recvLock.release()
raise TimeoutError("Deregistering on SIP Server timed out")

response = SIPMessage(resp)
Expand All @@ -1696,7 +1703,6 @@ def __deregister(self) -> bool:
# At this point, it's reasonable to assume that
# this is caused by invalid credentials.
debug("Unauthorized")
self.recvLock.release()
raise InvalidAccountInfoError(
"Invalid Username or "
+ "Password for SIP server "
Expand All @@ -1710,23 +1716,20 @@ def __deregister(self) -> bool:
# with new urn:uuid or reply with expire 0
self._handle_bad_request()
else:
self.recvLock.release()
raise TimeoutError("Deregistering on SIP Server timed out")

if response.status == SIPStatus(500):
self.recvLock.release()
time.sleep(5)
return self.deregister()
# We raise so the calling function can sleep and try again
raise RetryRequiredError("Response SIP status of 500")

if response.status == SIPStatus.OK:
self.recvLock.release()
return True
self.recvLock.release()
return False

def register(self) -> bool:
try:
registered = self.__register()
with self.recvLock:
registered = self.__register()
if not registered:
debug("REGISTERATION FAILED")
self.registerFailures += 1
Expand All @@ -1749,6 +1752,9 @@ def register(self) -> bool:
self.stop()
self.fatalCallback()
return False
if isinstance(e, RetryRequiredError):
time.sleep(5)
return self.register()
self.__start_register_timer(delay=0)

def __start_register_timer(self, delay: Optional[int] = None):
Expand All @@ -1764,7 +1770,6 @@ def __start_register_timer(self, delay: Optional[int] = None):
self.registerThread.start()

def __register(self) -> bool:
self.recvLock.acquire()
self.phone._status = PhoneStatus.REGISTERING
firstRequest = self.genFirstRequest()
self.out.sendto(firstRequest.encode("utf8"), (self.server, self.port))
Expand All @@ -1775,7 +1780,6 @@ def __register(self) -> bool:
if ready[0]:
resp = self.s.recv(8192)
else:
self.recvLock.release()
raise TimeoutError("Registering on SIP Server timed out")

response = SIPMessage(resp)
Expand Down Expand Up @@ -1814,7 +1818,6 @@ def __register(self) -> bool:
debug("\nRECEIVED")
debug(response.summary())
debug("=" * 50)
self.recvLock.release()
raise InvalidAccountInfoError(
"Invalid Username or "
+ "Password for SIP server "
Expand All @@ -1828,7 +1831,6 @@ def __register(self) -> bool:
# with new urn:uuid or reply with expire 0
self._handle_bad_request()
else:
self.recvLock.release()
raise TimeoutError("Registering on SIP Server timed out")

if response.status == SIPStatus(407):
Expand All @@ -1844,17 +1846,15 @@ def __register(self) -> bool:
]:
# Unauthorized
if response.status == SIPStatus(500):
self.recvLock.release()
time.sleep(5)
return self.register()
# We raise so the calling function can sleep and try again
raise RetryRequiredError("Response SIP status of 500")
else:
# TODO: determine if needed here
self.parseMessage(response)

debug(response.summary())
debug(response.raw)

self.recvLock.release()
if response.status == SIPStatus.OK:
return True
else:
Expand All @@ -1873,16 +1873,17 @@ def _handle_bad_request(self) -> None:

def subscribe(self, lastresponse: SIPMessage) -> None:
# TODO: check if needed and maybe implement fully
self.recvLock.acquire()

subRequest = self.genSubscribe(lastresponse)
self.out.sendto(subRequest.encode("utf8"), (self.server, self.port))

response = SIPMessage(self.s.recv(8192))
with self.recvLock:
subRequest = self.genSubscribe(lastresponse)
self.out.sendto(
subRequest.encode("utf8"), (self.server, self.port)
)

debug(f'Got response to subscribe: {str(response.heading, "utf8")}')
response = SIPMessage(self.s.recv(8192))

self.recvLock.release()
debug(
f'Got response to subscribe: {str(response.heading, "utf8")}'
)

def trying_timeout_check(self, response: SIPMessage) -> SIPMessage:
"""
Expand Down
Loading

0 comments on commit c5f8a55

Please sign in to comment.