Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fifo communication for large testing projects #24690

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def install_python_libs(session: nox.Session):
)

session.install("packaging")
session.install("debugpy")

# Download get-pip script
session.run(
Expand Down
56 changes: 17 additions & 39 deletions python_files/testing_tools/socket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,24 @@ def __exit__(self, *_):
self.close()

def connect(self):
if sys.platform == "win32":
self._writer = open(self.name, "w", encoding="utf-8") # noqa: SIM115, PTH123
# reader created in read method
else:
self._socket = _SOCKET(socket.AF_UNIX, socket.SOCK_STREAM)
self._socket.connect(self.name)
self._writer = open(self.name, "w", encoding="utf-8") # noqa: SIM115, PTH123
# reader created in read method
return self

def close(self):
if sys.platform == "win32":
self._writer.close()
else:
# add exception catch
self._socket.close()
self._writer.close()
if hasattr(self, "_reader"):
self._reader.close()

def write(self, data: str):
if sys.platform == "win32":
try:
# for windows, is should only use \n\n
request = (
f"""content-length: {len(data)}\ncontent-type: application/json\n\n{data}"""
)
self._writer.write(request)
self._writer.flush()
except Exception as e:
print("error attempting to write to pipe", e)
raise (e)
else:
# must include the carriage-return defined (as \r\n) for unix systems
request = (
f"""content-length: {len(data)}\r\ncontent-type: application/json\r\n\r\n{data}"""
)
self._socket.send(request.encode("utf-8"))
try:
# for windows, is should only use \n\n
request = f"""content-length: {len(data)}\ncontent-type: application/json\n\n{data}"""
self._writer.write(request)
self._writer.flush()
except Exception as e:
print("error attempting to write to pipe", e)
raise (e)

def read(self, bufsize=1024) -> str:
"""Read data from the socket.
Expand All @@ -63,17 +48,10 @@ def read(self, bufsize=1024) -> str:
Returns:
data (str): Data received from the socket.
"""
if sys.platform == "win32":
# returns a string automatically from read
if not hasattr(self, "_reader"):
self._reader = open(self.name, encoding="utf-8") # noqa: SIM115, PTH123
return self._reader.read(bufsize)
else:
# receive bytes and convert to string
while True:
part: bytes = self._socket.recv(bufsize)
data: str = part.decode("utf-8")
return data
# returns a string automatically from read
if not hasattr(self, "_reader"):
self._reader = open(self.name, encoding="utf-8") # noqa: SIM115, PTH123
return self._reader.read(bufsize)


class SocketManager:
Expand Down
33 changes: 27 additions & 6 deletions python_files/tests/pytestadapter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ def parse_rpc_message(data: str) -> Tuple[Dict[str, str], str]:
print("json decode error")


def _listen_on_fifo(pipe_name: str, result: List[str], completed: threading.Event):
# Open the FIFO for reading
fifo_path = pathlib.Path(pipe_name)
with fifo_path.open() as fifo:
print("Waiting for data...")
while True:
if completed.is_set():
break # Exit loop if completed event is set
data = fifo.read() # This will block until data is available
if len(data) == 0:
# If data is empty, assume EOF
break
print(f"Received: {data}")
result.append(data)


def _listen_on_pipe_new(listener, result: List[str], completed: threading.Event):
"""Listen on the named pipe or Unix domain socket for JSON data from the server.
Expand Down Expand Up @@ -307,14 +323,19 @@ def runner_with_cwd_env(
# if additional environment variables are passed, add them to the environment
if env_add:
env.update(env_add)
server = UnixPipeServer(pipe_name)
server.start()
# server = UnixPipeServer(pipe_name)
# server.start()
#################
# Create the FIFO (named pipe) if it doesn't exist
# if not pathlib.Path.exists(pipe_name):
os.mkfifo(pipe_name)
#################

completed = threading.Event()

result = [] # result is a string array to store the data during threading
t1: threading.Thread = threading.Thread(
target=_listen_on_pipe_new, args=(server, result, completed)
target=_listen_on_fifo, args=(pipe_name, result, completed)
)
t1.start()

Expand Down Expand Up @@ -364,14 +385,14 @@ def generate_random_pipe_name(prefix=""):

# For Windows, named pipes have a specific naming convention.
if sys.platform == "win32":
return f"\\\\.\\pipe\\{prefix}-{random_suffix}-sock"
return f"\\\\.\\pipe\\{prefix}-{random_suffix}"

# For Unix-like systems, use either the XDG_RUNTIME_DIR or a temporary directory.
xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
if xdg_runtime_dir:
return os.path.join(xdg_runtime_dir, f"{prefix}-{random_suffix}.sock") # noqa: PTH118
return os.path.join(xdg_runtime_dir, f"{prefix}-{random_suffix}") # noqa: PTH118
else:
return os.path.join(tempfile.gettempdir(), f"{prefix}-{random_suffix}.sock") # noqa: PTH118
return os.path.join(tempfile.gettempdir(), f"{prefix}-{random_suffix}") # noqa: PTH118


class UnixPipeServer:
Expand Down
19 changes: 13 additions & 6 deletions python_files/unittestadapter/pvsc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from typing_extensions import NotRequired # noqa: E402

from testing_tools import socket_manager # noqa: E402

# Types


Expand Down Expand Up @@ -331,10 +329,10 @@ def send_post_request(

if __writer is None:
try:
__writer = socket_manager.PipeManager(test_run_pipe)
__writer.connect()
__writer = open(test_run_pipe, "wb") # noqa: SIM115, PTH123
except Exception as error:
error_msg = f"Error attempting to connect to extension named pipe {test_run_pipe}[vscode-unittest]: {error}"
print(error_msg, file=sys.stderr)
__writer = None
raise VSCodeUnittestError(error_msg) from error

Expand All @@ -343,10 +341,19 @@ def send_post_request(
"params": payload,
}
data = json.dumps(rpc)

try:
if __writer:
__writer.write(data)
request = (
f"""content-length: {len(data)}\r\ncontent-type: application/json\r\n\r\n{data}"""
)
size = 4096
encoded = request.encode("utf-8")
bytes_written = 0
while bytes_written < len(encoded):
print("writing more bytes!")
segment = encoded[bytes_written : bytes_written + size]
bytes_written += __writer.write(segment)
__writer.flush()
else:
print(
f"Connection error[vscode-unittest], writer is None \n[vscode-unittest] data: \n{data} \n",
Expand Down
50 changes: 27 additions & 23 deletions python_files/vscode_pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@

import pytest

script_dir = pathlib.Path(__file__).parent.parent
sys.path.append(os.fspath(script_dir))
sys.path.append(os.fspath(script_dir / "lib" / "python"))
from testing_tools import socket_manager # noqa: E402

if TYPE_CHECKING:
from pluggy import Result

Expand Down Expand Up @@ -171,7 +166,7 @@ def pytest_exception_interact(node, call, report):
collected_test = TestRunResultDict()
collected_test[node_id] = item_result
cwd = pathlib.Path.cwd()
execution_post(
send_execution_message(
os.fsdecode(cwd),
"success",
collected_test if collected_test else None,
Expand Down Expand Up @@ -295,7 +290,7 @@ def pytest_report_teststatus(report, config): # noqa: ARG001
)
collected_test = TestRunResultDict()
collected_test[absolute_node_id] = item_result
execution_post(
send_execution_message(
os.fsdecode(cwd),
"success",
collected_test if collected_test else None,
Expand Down Expand Up @@ -329,7 +324,7 @@ def pytest_runtest_protocol(item, nextitem): # noqa: ARG001
)
collected_test = TestRunResultDict()
collected_test[absolute_node_id] = item_result
execution_post(
send_execution_message(
os.fsdecode(cwd),
"success",
collected_test if collected_test else None,
Expand Down Expand Up @@ -405,15 +400,15 @@ def pytest_sessionfinish(session, exitstatus):
"children": [],
"id_": "",
}
post_response(os.fsdecode(cwd), error_node)
send_discovery_message(os.fsdecode(cwd), error_node)
try:
session_node: TestNode | None = build_test_tree(session)
if not session_node:
raise VSCodePytestError(
"Something went wrong following pytest finish, \
no session node was created"
)
post_response(os.fsdecode(cwd), session_node)
send_discovery_message(os.fsdecode(cwd), session_node)
except Exception as e:
ERRORS.append(
f"Error Occurred, traceback: {(traceback.format_exc() if e.__traceback__ else '')}"
Expand All @@ -425,7 +420,7 @@ def pytest_sessionfinish(session, exitstatus):
"children": [],
"id_": "",
}
post_response(os.fsdecode(cwd), error_node)
send_discovery_message(os.fsdecode(cwd), error_node)
else:
if exitstatus == 0 or exitstatus == 1:
exitstatus_bool = "success"
Expand All @@ -435,7 +430,7 @@ def pytest_sessionfinish(session, exitstatus):
)
exitstatus_bool = "error"

execution_post(
send_execution_message(
os.fsdecode(cwd),
exitstatus_bool,
None,
Expand Down Expand Up @@ -489,7 +484,7 @@ def pytest_sessionfinish(session, exitstatus):
result=file_coverage_map,
error=None,
)
send_post_request(payload)
send_message(payload)


def build_test_tree(session: pytest.Session) -> TestNode:
Expand Down Expand Up @@ -857,8 +852,10 @@ def get_node_path(node: Any) -> pathlib.Path:
atexit.register(lambda: __writer.close() if __writer else None)


def execution_post(cwd: str, status: Literal["success", "error"], tests: TestRunResultDict | None):
"""Sends a POST request with execution payload details.
def send_execution_message(
cwd: str, status: Literal["success", "error"], tests: TestRunResultDict | None
):
"""Sends message execution payload details.
Args:
cwd (str): Current working directory.
Expand All @@ -870,10 +867,10 @@ def execution_post(cwd: str, status: Literal["success", "error"], tests: TestRun
)
if ERRORS:
payload["error"] = ERRORS
send_post_request(payload)
send_message(payload)


def post_response(cwd: str, session_node: TestNode) -> None:
def send_discovery_message(cwd: str, session_node: TestNode) -> None:
"""
Sends a POST request with test session details in payload.
Expand All @@ -889,7 +886,7 @@ def post_response(cwd: str, session_node: TestNode) -> None:
}
if ERRORS is not None:
payload["error"] = ERRORS
send_post_request(payload, cls_encoder=PathEncoder)
send_message(payload, cls_encoder=PathEncoder)


class PathEncoder(json.JSONEncoder):
Expand All @@ -901,7 +898,7 @@ def default(self, o):
return super().default(o)


def send_post_request(
def send_message(
payload: ExecutionPayloadDict | DiscoveryPayloadDict | CoveragePayloadDict,
cls_encoder=None,
):
Expand All @@ -926,8 +923,7 @@ def send_post_request(

if __writer is None:
try:
__writer = socket_manager.PipeManager(TEST_RUN_PIPE)
__writer.connect()
__writer = open(TEST_RUN_PIPE, "wb") # noqa: SIM115, PTH123
except Exception as error:
error_msg = f"Error attempting to connect to extension named pipe {TEST_RUN_PIPE}[vscode-pytest]: {error}"
print(error_msg, file=sys.stderr)
Expand All @@ -945,10 +941,18 @@ def send_post_request(
"params": payload,
}
data = json.dumps(rpc, cls=cls_encoder)

try:
if __writer:
__writer.write(data)
request = (
f"""content-length: {len(data)}\r\ncontent-type: application/json\r\n\r\n{data}"""
)
size = 4096
encoded = request.encode("utf-8")
bytes_written = 0
while bytes_written < len(encoded):
segment = encoded[bytes_written : bytes_written + size]
bytes_written += __writer.write(segment)
__writer.flush()
else:
print(
f"Plugin error connection error[vscode-pytest], writer is None \n[vscode-pytest] data: \n{data} \n",
Expand Down
2 changes: 2 additions & 0 deletions python_files/vscode_pytest/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# def send_post_request():
# return
Loading
Loading