Skip to content

Commit

Permalink
Merge pull request #1103 from SpiNNakerManchester/connection
Browse files Browse the repository at this point in the history
always use connection to send_sdp
  • Loading branch information
rowleya authored Oct 10, 2023
2 parents 9b83424 + 27c8ce1 commit 9f90f97
Showing 1 changed file with 57 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ class DataSpeedUpPacketGatherMachineVertex(
"_coord_word",
# transaction id
"_transaction_id",
# socket
"_connection",
# path for the data in report
"_in_report_path",
# ipaddress
Expand Down Expand Up @@ -290,13 +288,11 @@ def __init__(self, x, y, ip_address):

self._missing_seq_nums_data_in = list()

# Create a connection to be used
self._x = x
self._y = y
self._coord_word = None
self._ip_address = ip_address
self._remote_tag = None
self._connection = None

# local provenance storage
self._run = 0
Expand All @@ -311,14 +307,16 @@ def __init__(self, x, y, ip_address):
# Stored reinjection status for resetting timeouts
self._last_status = None

def __throttled_send(self, message):
def __throttled_send(self, message, connection):
"""
Slows down transmissions to allow SpiNNaker to keep up.
:param ~.SDPMessage message: message to send
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
"""
# send first message
self._connection.send_sdp_message(message)
connection.send_sdp_message(message)
time.sleep(self._TRANSMISSION_THROTTLE_TIME)

@property
Expand Down Expand Up @@ -431,37 +429,6 @@ def _reserve_memory_regions(self, spec):
def get_binary_file_name(self):
return "data_speed_up_packet_gatherer.aplx"

@staticmethod
def locate_correct_write_data_function_for_chip_location(
uses_advanced_monitors, x, y, transceiver,
extra_monitor_cores_to_ethernet_connection_map):
"""
Supports other components figuring out which gatherer and function
to call for writing data onto SpiNNaker.
:param bool uses_advanced_monitors:
Whether the system is using advanced monitors
:param int x: the chip x coordinate to write data to
:param int y: the chip y coordinate to write data to
:param ~spinnman.transceiver.Transceiver transceiver:
the SpiNNMan instance
:param extra_monitor_cores_to_ethernet_connection_map:
mapping between cores and connections
:type extra_monitor_cores_to_ethernet_connection_map:
dict(tuple(int,int), DataSpeedUpPacketGatherMachineVertex)
:return: a write function of either a LPG or the spinnMan
:rtype: callable
"""
if not uses_advanced_monitors:
return transceiver.write_memory

chip = FecDataView.get_chip_at(x, y)
ethernet_connected_chip = FecDataView.get_chip_at(
chip.nearest_ethernet_x, chip.nearest_ethernet_y)
gatherer = extra_monitor_cores_to_ethernet_connection_map[
ethernet_connected_chip.x, ethernet_connected_chip.y]
return gatherer.send_data_into_spinnaker

def _generate_data_in_report(
self, time_diff, data_size, x, y,
address_written_to, missing_seq_nums):
Expand Down Expand Up @@ -609,7 +576,7 @@ def _send_data_via_extra_monitors(
:param int start_address: the base SDRAM address
"""
# Set up the connection
with self.__open_connection() as self._connection:
with self.__open_connection() as connection:
# how many packets after first one we need to send
self._max_seq_num = ceildiv(
len(data_to_write), BYTES_IN_FULL_PACKET_WITH_KEY)
Expand All @@ -634,7 +601,8 @@ def _send_data_via_extra_monitors(
while not received_confirmation:

# send initial attempt at sending all the data
self._send_all_data_based_packets(data_to_write, start_address)
self._send_all_data_based_packets(
data_to_write, start_address, connection)

# Don't create a missing buffer until at least one packet has
# come back.
Expand All @@ -644,7 +612,7 @@ def _send_data_via_extra_monitors(
try:
# try to receive a confirmation of some sort from
# spinnaker
data = self._connection.receive(
data = connection.receive(
timeout=self._TIMEOUT_PER_RECEIVE_IN_SECONDS)
time_out_count = 0

Expand Down Expand Up @@ -679,7 +647,7 @@ def _send_data_via_extra_monitors(
# to retransmit.
if seen_all or seen_last:
self._outgoing_retransmit_missing_seq_nums(
data_to_write, missing)
data_to_write, missing, connection)
missing.clear()

except SpinnmanTimeoutException as e:
Expand All @@ -699,7 +667,7 @@ def _send_data_via_extra_monitors(
break

self._outgoing_retransmit_missing_seq_nums(
data_to_write, missing)
data_to_write, missing, connection)
missing.clear()

def _read_in_missing_seq_nums(self, data, position, seq_nums):
Expand Down Expand Up @@ -736,13 +704,16 @@ def _read_in_missing_seq_nums(self, data, position, seq_nums):
return seen_last, seen_all

def _outgoing_retransmit_missing_seq_nums(
self, data_to_write, missing):
self, data_to_write, missing, connection):
"""
Transmits back into SpiNNaker the missing data based off missing
sequence numbers.
:param bytearray data_to_write: the data to write.
:param set(int) missing: a set of missing sequence numbers
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
"""

missing_seqs_as_list = list(missing)
Expand All @@ -753,10 +724,10 @@ def _outgoing_retransmit_missing_seq_nums(
message, _length = self._calculate_data_in_data_from_seq_number(
data_to_write, missing_seq_num,
DATA_IN_COMMANDS.SEND_SEQ_DATA.value, None)
self.__throttled_send(message)
self.__throttled_send(message, connection)

# request an update on what is missing
self._send_tell_flag()
self._send_tell_flag(connection)

@staticmethod
def _calculate_position_from_seq_number(seq_num):
Expand Down Expand Up @@ -824,13 +795,16 @@ def _calculate_data_in_data_from_seq_number(
# return message for sending, and the length in data sent
return message, packet_data_length

def _send_location(self, start_address):
def _send_location(self, start_address, connection):
"""
Send location as separate message.
:param int start_address: SDRAM location
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
"""
self._connection.send_sdp_message(self.__make_sdp_message(
connection.send_sdp_message(self.__make_sdp_message(
self._placement, SDP_PORTS.EXTRA_MONITOR_CORE_DATA_IN_SPEED_UP,
_FIVE_WORDS.pack(
DATA_IN_COMMANDS.SEND_DATA_TO_LOCATION.value,
Expand All @@ -840,24 +814,31 @@ def _send_location(self, start_address):
"start address for transaction {} is {}",
self._transaction_id, start_address)

def _send_tell_flag(self):
def _send_tell_flag(self, connection):
"""
Send tell flag as separate message.
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
"""
self._connection.send_sdp_message(self.__make_sdp_message(
connection.send_sdp_message(self.__make_sdp_message(
self._placement, SDP_PORTS.EXTRA_MONITOR_CORE_DATA_IN_SPEED_UP,
_TWO_WORDS.pack(
DATA_IN_COMMANDS.SEND_TELL.value, self._transaction_id)))

def _send_all_data_based_packets(self, data_to_write, start_address):
def _send_all_data_based_packets(
self, data_to_write, start_address, connection):
"""
Send all the data as one block.
:param bytearray data_to_write: the data to send
:param int start_address:
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
"""
# Send the location
self._send_location(start_address)
self._send_location(start_address, connection)

# where in the data we are currently up to
position_in_data = 0
Expand All @@ -872,11 +853,11 @@ def _send_all_data_based_packets(self, data_to_write, start_address):
position_in_data += length_to_send

# send the message
self.__throttled_send(message)
self.__throttled_send(message, connection)
log.debug("sent seq {} of {} bytes", seq_num, length_to_send)

# check for end flag
self._send_tell_flag()
self._send_tell_flag(connection)
log.debug("sent end flag")

def set_cores_for_data_streaming(self):
Expand Down Expand Up @@ -1041,8 +1022,6 @@ def get_data(
self._run, "No Extraction time", end - start)
return data

transceiver = FecDataView.get_transceiver()

# Update the IP Tag to work through a NAT firewall
with self.__open_connection() as connection:
# update transaction id for extra monitor
Expand All @@ -1061,7 +1040,7 @@ def get_data(
self._view = memoryview(self._output)
self._max_seq_num = self.calculate_max_seq_num()
lost_seq_nums = self._receive_data(
transceiver, placement, connection, transaction_id)
placement, connection, transaction_id)

# Stop anything else getting through (and reduce traffic)
connection.send_sdp_message(self.__make_sdp_message(
Expand Down Expand Up @@ -1089,12 +1068,11 @@ def get_data(

return self._output

def _receive_data(
self, transceiver, placement, connection, transaction_id):
def _receive_data(self, placement, connection, transaction_id):
"""
:param ~.Transceiver transceiver:
:param ~.Placement placement:
:param ~.UDPConnection connection:
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
:param int transaction_id:
:rtype: list(int)
"""
Expand All @@ -1110,8 +1088,8 @@ def _receive_data(
if transaction_id == response_transaction_id:
timeoutcount = 0
seq_nums, finished = self._process_data(
data, seq_nums, finished, placement, transceiver,
lost_seq_nums, transaction_id)
data, seq_nums, finished, placement, lost_seq_nums,
transaction_id, connection)
else:
log.info(
"ignoring packet as transaction id should be {}"
Expand All @@ -1127,8 +1105,8 @@ def _receive_data(
# self.__reset_connection()
if not finished:
finished = self._determine_and_retransmit_missing_seq_nums(
seq_nums, transceiver, placement, lost_seq_nums,
transaction_id)
seq_nums, placement, lost_seq_nums, transaction_id,
connection)
return lost_seq_nums

@staticmethod
Expand Down Expand Up @@ -1183,18 +1161,20 @@ def _calculate_missing_seq_nums(self, seq_nums):
return [sn for sn in range(self._max_seq_num) if sn not in seq_nums]

def _determine_and_retransmit_missing_seq_nums(
self, seq_nums, transceiver, placement, lost_seq_nums,
transaction_id):
self, seq_nums, placement, lost_seq_nums, transaction_id,
connection):
"""
Determine if there are any missing sequence numbers, and if so
retransmits the missing sequence numbers back to the core for
retransmission.
:param set(int) seq_nums: the sequence numbers already received
:param ~.Transceiver transceiver: spinnman instance
:param ~.Placement placement: placement instance
:param list(int) lost_seq_nums:
:param int transaction_id: transaction_id
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
:return: whether all packets are transmitted
:rtype: bool
"""
Expand Down Expand Up @@ -1276,7 +1256,7 @@ def _determine_and_retransmit_missing_seq_nums(
seq_num_offset += length_left_in_packet

# build SDP message and send it to the core
transceiver.send_sdp_message(self.__make_sdp_message(
connection.send_sdp_message(self.__make_sdp_message(
placement, SDP_PORTS.EXTRA_MONITOR_CORE_DATA_SPEED_UP, data))

# sleep for ensuring core doesn't lose packets
Expand All @@ -1285,8 +1265,8 @@ def _determine_and_retransmit_missing_seq_nums(
return False

def _process_data(
self, data, seq_nums, finished, placement, transceiver,
lost_seq_nums, transaction_id):
self, data, seq_nums, finished, placement, lost_seq_nums,
transaction_id, connection):
"""
Take a packet and process it see if we're finished yet.
Expand All @@ -1295,10 +1275,12 @@ def _process_data(
:param bool finished: bool which states if finished or not
:param ~.Placement placement:
placement object for location on machine
:param ~.Transceiver transceiver: spinnman instance
:param int transaction_id: the transaction ID for this stream
:param list(int) lost_seq_nums:
the list of n sequence numbers lost per iteration
:param int transaction_id: the transaction ID for this stream
:type connection:
~spinnman.connections.udp_packet_connections.SCAMPConnection
:return: set of data items, if its the first packet, the list of
sequence numbers, the sequence number received and if its finished
:rtype: tuple(set(int), bool)
Expand Down Expand Up @@ -1343,9 +1325,9 @@ def _process_data(
if is_end_of_stream:
if not self._check(seq_nums):
finished = self._determine_and_retransmit_missing_seq_nums(
placement=placement, transceiver=transceiver,
seq_nums=seq_nums, lost_seq_nums=lost_seq_nums,
transaction_id=transaction_id)
placement=placement, seq_nums=seq_nums,
lost_seq_nums=lost_seq_nums, transaction_id=transaction_id,
connection=connection)
else:
finished = True
return seq_nums, finished
Expand Down

0 comments on commit 9f90f97

Please sign in to comment.