diff --git a/lake/modules/mu2f_io_core.py b/lake/modules/mu2f_io_core.py index c9eba91a..25071366 100644 --- a/lake/modules/mu2f_io_core.py +++ b/lake/modules/mu2f_io_core.py @@ -10,13 +10,13 @@ from lake.attributes.control_signal_attr import ControlSignalAttr from _kratos import create_wrapper_flatten from lake.modules.reg_fifo import RegFIFO +from lake.modules.mux import Mux class IOCore_mu2f(Generator): def __init__(self, matrix_unit_data_width=16, tile_array_data_width=17, - num_ios=2, fifo_depth=2, allow_bypass=False, use_almost_full=False, @@ -34,6 +34,7 @@ def __init__(self, self.use_almost_full = use_almost_full self.total_sets = 0 + num_tracks = 5 # inputs self._clk = self.clock("clk") @@ -58,87 +59,217 @@ def __init__(self, mu_data_width = matrix_unit_data_width + mu_tile_array_datawidth_difference = tile_array_data_width - mu_data_width + assert mu_tile_array_datawidth_difference >=0, "Error: Matrix unit bus cannot drive CGRA bus because MU datawidth > CGRA datawidth" + + ######################################## + # FIFO ZERO + ######################################## + # Valid in from matrix unit + mu2io_v_0 = self.input(f"mu2io_{mu_data_width}_0_valid", 1) + mu2io_v_0.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) + + # Ready out to matrix unit + mu2io_r_0 = self.output(f"mu2io_{mu_data_width}_0_ready", 1) + mu2io_r_0.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + + # R-V interface with fabric + io2f_r_0 = self.var(f"io2f_{tile_array_data_width}_0_ready", 1) + io2f_r_0.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) + io2f_v_0 = self.var(f"io2f_{tile_array_data_width}_0_valid", 1) + io2f_v_0.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + + mu2io_0 = self.input(f"mu2io_{mu_data_width}_0", mu_data_width, packed=True) + mu2io_0.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) + + io2f_0 = self.var(f"io2f_{tile_array_data_width}_0", tile_array_data_width, packed=True) + io2f_0.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) + + # mu2io -> io2f fifo + mu2io_2_io2f_fifo_0 = RegFIFO(data_width=tile_array_data_width, + width_mult=1, + depth=self.fifo_depth, + mod_name_suffix=self.fifo_name_suffix, + almost_full_diff=1) + + self.add_child(f"mu2io_2_io2f_{tile_array_data_width}_0", + mu2io_2_io2f_fifo_0, + clk=self._gclk, + rst_n=self._rst_n, + clk_en=self._clk_en, + push=mu2io_v_0, + pop=io2f_r_0) + + # Append 0s at MSBs if CGRA bitwidth exceeds Matrix unit bitwidth + if mu_tile_array_datawidth_difference > 0: + self.wire(mu2io_2_io2f_fifo_0.ports.data_in, kts.concat(kts.const(0, mu_tile_array_datawidth_difference), mu2io_0)) + else: + self.wire(mu2io_2_io2f_fifo_0.ports.data_in, mu2io_0) + + ######################################## + # END FIFO ZERO + ######################################## + + + ######################################## + # FIFO ONE + ######################################## + # Valid in from matrix unit + mu2io_v_1 = self.input(f"mu2io_{mu_data_width}_1_valid", 1) + mu2io_v_1.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) + + # Ready out to matrix unit + mu2io_r_1 = self.output(f"mu2io_{mu_data_width}_1_ready", 1) + mu2io_r_1.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + + # R-V interface with fabric + io2f_r_1 = self.var(f"io2f_{tile_array_data_width}_1_ready", 1) + io2f_r_1.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) + io2f_v_1 = self.var(f"io2f_{tile_array_data_width}_1_valid", 1) + io2f_v_1.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + + mu2io_1 = self.input(f"mu2io_{mu_data_width}_1", mu_data_width, packed=True) + mu2io_1.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) + + io2f_1 = self.var(f"io2f_{tile_array_data_width}_1", tile_array_data_width, packed=True) + io2f_1.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) + + # mu2io -> io2f fifo + mu2io_2_io2f_fifo_1 = RegFIFO(data_width=tile_array_data_width, + width_mult=1, + depth=self.fifo_depth, + mod_name_suffix=self.fifo_name_suffix, + almost_full_diff=1) + + self.add_child(f"mu2io_2_io2f_{tile_array_data_width}_1", + mu2io_2_io2f_fifo_1, + clk=self._gclk, + rst_n=self._rst_n, + clk_en=self._clk_en, + push=mu2io_v_1, + pop=io2f_r_1) + + # Append 0s at MSBs if CGRA bitwidth exceeds Matrix unit bitwidth + if mu_tile_array_datawidth_difference > 0: + self.wire(mu2io_2_io2f_fifo_1.ports.data_in, kts.concat(kts.const(0, mu_tile_array_datawidth_difference), mu2io_1)) + else: + self.wire(mu2io_2_io2f_fifo_1.ports.data_in, mu2io_1) - # # Valid in from matrix unit - # tmp_mu2io_v = self.input(f"mu2io_{mu_data_width}_valid", 1) - # tmp_mu2io_v.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) + ######################################## + # END FIFO ONE + ######################################## - - for io_num in range(num_ios): - # Valid in from matrix unit - tmp_mu2io_v = self.input(f"mu2io_{mu_data_width}_{io_num}_valid", 1) - tmp_mu2io_v.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) + #TODO: This could potentially be replaced with AND gates + ######################################## + # READY SELECT + ######################################## - # Ready out to matrix unit - tmp_mu2io_r = self.output(f"mu2io_{mu_data_width}_{io_num}_ready", 1) - tmp_mu2io_r.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + # Create ready select muxes + ready_mux_0 = Mux(height=num_tracks, width=1) + self.add_child(f"ready_0_mux", ready_mux_0) - # R-V interface with fabric - tmp_io2f_r = self.input(f"io2f_{tile_array_data_width}_{io_num}_ready", 1) - tmp_io2f_r.add_attribute(ControlSignalAttr(is_control=True, full_bus=False)) - tmp_io2f_v = self.output(f"io2f_{tile_array_data_width}_{io_num}_valid", 1) - tmp_io2f_v.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + ready_mux_1 = Mux(height=num_tracks, width=1) + self.add_child(f"ready_1_mux", ready_mux_1) - tmp_mu2io = self.input(f"mu2io_{mu_data_width}_{io_num}", mu_data_width, packed=True) - tmp_mu2io.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) - - tmp_io2f = self.output(f"io2f_{tile_array_data_width}_{io_num}", tile_array_data_width, packed=True) - tmp_io2f.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) - - # mu2io -> io2f fifo - mu2io_2_io2f_fifo = RegFIFO(data_width=tile_array_data_width, - width_mult=1, - depth=self.fifo_depth, - mod_name_suffix=self.fifo_name_suffix, - almost_full_diff=1) - - self.add_child(f"mu2io_2_io2f_{tile_array_data_width}_{io_num}", - mu2io_2_io2f_fifo, - clk=self._gclk, - rst_n=self._rst_n, - clk_en=self._clk_en, - #clk_en=kts.const(1, 1), - push=tmp_mu2io_v, - pop=tmp_io2f_r) + # Ready select config regs (the actual select signal) + self._ready_select_0 = self.input(f"ready_select_0", 3) + self._ready_select_0.add_attribute(ConfigRegAttr("Track select config register. Selects driver for that track.")) + self._ready_select_1 = self.input(f"ready_select_1", 3) + self._ready_select_1.add_attribute(ConfigRegAttr("Track select config register. Selects driver for that track.")) + + # Wire up the select signals + self.wire(ready_mux_0.ports.S, self._ready_select_0) + self.wire(ready_mux_1.ports.S, self._ready_select_1) + + # Create readys and wire them to mux inputs + for track_num in range(num_tracks): + tmp_track_out_r = self.input(f"io2f_{tile_array_data_width}_T{track_num}_ready", 1) + tmp_track_out_r.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) + self.wire(ready_mux_0.ports.I[track_num], tmp_track_out_r) + self.wire(ready_mux_1.ports.I[track_num], tmp_track_out_r) - - - mu_tile_array_datawidth_difference = tile_array_data_width - mu_data_width - assert mu_tile_array_datawidth_difference >=0, "Error: Matrix unit bus cannot drive CGRA bus because MU datawidth > CGRA datawidth" - # Append 0s at MSBs if CGRA bitwidth exceeds Matrix unit bitwidth - if mu_tile_array_datawidth_difference > 0: - self.wire(mu2io_2_io2f_fifo.ports.data_in, kts.concat(kts.const(0, mu_tile_array_datawidth_difference), tmp_mu2io)) + self.wire(io2f_r_0, ready_mux_0.ports.O) + self.wire(io2f_r_1, ready_mux_1.ports.O) + + # If MU_inactive, set ready_out = 0 + # If dense bypass, send data straight through, bypassing FIFOs + if self.allow_bypass: + if self.use_almost_full: + self.wire(mu2io_r_0, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, + io2f_r_0, + ~mu2io_2_io2f_fifo_0.ports.almost_full)), kts.const(0, 1)) + self.wire(mu2io_r_1, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, + io2f_r_1, + ~mu2io_2_io2f_fifo_1.ports.almost_full)), kts.const(0, 1)) else: - self.wire(mu2io_2_io2f_fifo.ports.data_in, tmp_mu2io) + self.wire(mu2io_r_0, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, + io2f_r_0, + ~mu2io_2_io2f_fifo_0.ports.full)), kts.const(0, 1)) + self.wire(mu2io_r_1, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, + io2f_r_1, + ~mu2io_2_io2f_fifo_1.ports.full)), kts.const(0, 1)) + else: + if self.use_almost_full: + self.wire(mu2io_r_0, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo_0.ports.almost_full, kts.const(0, 1))) + self.wire(mu2io_r_1, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo_1.ports.almost_full, kts.const(0, 1))) + else: + self.wire(mu2io_r_0, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo_0.ports.full, kts.const(0, 1))) + self.wire(mu2io_r_1, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo_1.ports.full, kts.const(0, 1))) + + + + ######################################## + # TRACK SELECT + ######################################## + for track_num in range(num_tracks): + # Create track select config reg + self._tmp_track_select = self.input(f"track_select_T{track_num}", 2) + self._tmp_track_select.add_attribute(ConfigRegAttr("Track select config register. Selects driver for that track.")) + + # Create track output and its valid interface + tmp_track_out = self.output(f"io2f_{tile_array_data_width}_T{track_num}", tile_array_data_width) + tmp_track_out.add_attribute(ControlSignalAttr(is_control=False, full_bus=True)) + + tmp_track_out_v = self.output(f"io2f_{tile_array_data_width}_T{track_num}_valid", 1) + tmp_track_out_v.add_attribute(ControlSignalAttr(is_control=False, full_bus=False)) + + + # Create 3-to-1 mux (track_select mux) + track_mux = Mux(height=3, width=tile_array_data_width) + self.add_child(f"T_{track_num}_mux", track_mux) + + # Create 3-to-1 mux (track_select valid mux) + track_valid_mux = Mux(height=3, width=1) + self.add_child(f"T_{track_num}_valid_mux", track_valid_mux) - # If tile_en (matrix unit inactive) false, send 0 to fabric. - # If dense bypass, send data straight through, bypassing FIFOs + # Wire track select signal + self.wire(track_mux.ports.S, self._tmp_track_select) + self.wire(track_valid_mux.ports.S, self._tmp_track_select) + + self.wire(track_mux.ports.I[0], kts.const(0, tile_array_data_width)) + self.wire(track_valid_mux.ports.I[0], kts.const(0, 1)) + + # FIFO -> 3-to-1 mux connnections if self.allow_bypass: - self.wire(tmp_io2f, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, - tmp_mu2io, - mu2io_2_io2f_fifo.ports.data_out)), kts.const(0, tile_array_data_width)) - - if self.use_almost_full: - self.wire(tmp_mu2io_r, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, - tmp_io2f_r, - ~mu2io_2_io2f_fifo.ports.almost_full)), kts.const(0, 1)) - else: - self.wire(tmp_mu2io_r, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, - tmp_io2f_r, - ~mu2io_2_io2f_fifo.ports.full)), kts.const(0, 1)) - - self.wire(tmp_io2f_v, kts.ternary(self._tile_en, kts.ternary(self._dense_bypass, - tmp_mu2io_v, - ~mu2io_2_io2f_fifo.ports.empty)), kts.const(0, 1)) + self.wire(track_mux.ports.I[1], kts.ternary(self._dense_bypass, mu2io_0, mu2io_2_io2f_fifo_0.ports.data_out)) + self.wire(track_mux.ports.I[2], kts.ternary(self._dense_bypass, mu2io_1, mu2io_2_io2f_fifo_1.ports.data_out)) + + self.wire(track_valid_mux.ports.I[1], kts.ternary(self._dense_bypass, mu2io_v_0, ~mu2io_2_io2f_fifo_0.ports.empty)) + self.wire(track_valid_mux.ports.I[2], kts.ternary(self._dense_bypass, mu2io_v_1, ~mu2io_2_io2f_fifo_1.ports.empty)) else: - self.wire(tmp_io2f, kts.ternary(self._tile_en, mu2io_2_io2f_fifo.ports.data_out, kts.const(0, tile_array_data_width))) - if self.use_almost_full: - self.wire(tmp_mu2io_r, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo.ports.almost_full, kts.const(0, 1))) - else: - self.wire(tmp_mu2io_r, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo.ports.full, kts.const(0, 1))) - self.wire(tmp_io2f_v, kts.ternary(self._tile_en, ~mu2io_2_io2f_fifo.ports.empty, kts.const(0, 1))) + self.wire(track_mux.ports.I[1], mu2io_2_io2f_fifo_0.ports.data_out) + self.wire(track_mux.ports.I[2], mu2io_2_io2f_fifo_1.ports.data_out) + + self.wire(track_valid_mux.ports.I[1], ~mu2io_2_io2f_fifo_0.ports.empty) + self.wire(track_valid_mux.ports.I[2], ~mu2io_2_io2f_fifo_1.ports.empty) + + + # MU active mux (wire output) + self.wire(tmp_track_out, kts.ternary(self._tile_en, track_mux.ports.O, kts.const(0, 1))) + self.wire(tmp_track_out_v, kts.ternary(self._tile_en, track_valid_mux.ports.O, kts.const(0, 1))) + if self.add_clk_enable: kts.passes.auto_insert_clock_enable(self.internal_generator) @@ -156,6 +287,44 @@ def get_bitstream(self, config_dict): # Store all configurations here config = [("tile_en", 1)] + track_select_T0_val = 0 + track_select_T1_val = 0 + track_select_T2_val = 0 + track_select_T3_val = 0 + track_select_T4_val = 0 + + if 'track_select_T0' in config_dict: + track_select_T0_val = config_dict['track_select_T0'] + + if 'track_select_T1' in config_dict: + track_select_T1_val = config_dict['track_select_T1'] + + if 'track_select_T2' in config_dict: + track_select_T2_val = config_dict['track_select_T2'] + + if 'track_select_T3' in config_dict: + track_select_T3_val = config_dict['track_select_T3'] + + if 'track_select_T4' in config_dict: + track_select_T4_val = config_dict['track_select_T4'] + + + config += [("track_select_T0", track_select_T0_val), ("track_select_T1", track_select_T1_val), + ("track_select_T2", track_select_T2_val), ("track_select_T3", track_select_T3_val), + ("track_select_T4", track_select_T4_val)] + + + ready_select_0_val = 0 + ready_select_1_val = 0 + + if 'ready_select_0' in config_dict: + ready_select_0_val = config_dict['ready_select_0'] + + if 'ready_select_1' in config_dict: + ready_select_1_val = config_dict['ready_select_1'] + + config += [("ready_select_0", ready_select_0_val), ("ready_select_1", ready_select_1_val)] + if self.allow_bypass: dense_bypass_val = 0 diff --git a/lake/modules/mux.py b/lake/modules/mux.py new file mode 100644 index 00000000..41fe20c4 --- /dev/null +++ b/lake/modules/mux.py @@ -0,0 +1,36 @@ +import kratos as kts +from kratos import * + +class Mux(Generator): + def __init__(self, height: int, width: int): + name = "Mux_{0}_{1}".format(width, height) + super().__init__(name) + + # pass through wires + if height == 1: + self.in_ = self.input("I", width) + self.out_ = self.output("O", width) + self.wire(self.out_, self.in_) + return + + self.sel_size = clog2(height) + input_ = self.input("I", width, size=height) + self.out_ = self.output("O", width) + self._sel = self.input("S", self.sel_size) + + # add a combinational block + comb = self.combinational() + + # add a case statement + switch_ = comb.switch_(self.ports.S) + for i in range(height): + switch_.case_(i, self.out_(input_[i])) + # add default + switch_.case_(None, self.out_(0)) + +if __name__ == "__main__": + + mux_dut = Mux(height=4, width=16) + + verilog(mux_dut, filename="mux.sv", + optimize_if=False)