Skip to content

Commit

Permalink
boots: move register allocation code into new file
Browse files Browse the repository at this point in the history
  • Loading branch information
dinfuehr committed Aug 29, 2023
1 parent 57355cb commit 37ed74d
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 315 deletions.
318 changes: 5 additions & 313 deletions dora-boots/boots.dora
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ use std::collections::HashMap;

use package::assembler::{RegMap, RegSet, Register};
use package::codegen::CodeGen;
use package::codegen::x64::CodeGenX64;
use package::codegen::arm64::CodeGenArm64;
use package::codegen::{emitInst, CodeDescriptor, setLocationConstraints};
use package::codegen::{createCodeGen, emitInst, CodeDescriptor, generateCode, setLocationConstraints};
use package::dominator::buildDominatorTree;
use package::graph::{createGotoInst, createMove, Block, Edge, Graph, Inst, Location, LocationData, Op, opName};
use package::graph::dump::dumpGraph;
use package::graph_builder::createGraph;
use package::interface::{Architecture, CompilationInfo};
use package::interface::{compile, config, getSystemConfig};
use package::regalloc::performRegisterAllocation;
use package::liveness::computeLiveness;

mod assembler;
Expand All @@ -25,6 +24,7 @@ mod graph_builder_tests;
mod graph_builder;
mod interface;
mod liveness;
mod regalloc;
mod serializer;

pub fn compileFunction(ci: CompilationInfo): CodeDescriptor {
Expand All @@ -51,16 +51,14 @@ pub fn compileFunction(ci: CompilationInfo): CodeDescriptor {
println(dumpGraph(graph, config.architecture));
}

ensureLocationData(graph, codegen);
allocateRegisters(graph, codegen);
ssaDestruction(graph, codegen);
performRegisterAllocation(graph, codegen);

if (ci.emitGraph) {
println("after register allocation:");
println(dumpGraph(graph, config.architecture));
}

generate(graph, codegen)
generateCode(graph, codegen)
}

fn splitCriticalEdges(graph: Graph) {
Expand Down Expand Up @@ -113,309 +111,3 @@ fn splitCriticalEdge(graph: Graph, edge: Edge) {
assert(intermediate.successors.isEmpty());
intermediate.successors.push(intermediateToSuccessorEdge);
}

fn ensureLocationData(graph: Graph, codegen: CodeGen) {
for block in graph.insertionOrderIterator() {
for inst in block.phisIterator() {
setLocationConstraints(codegen, inst);
}

for inst in block.instructionsIterator() {
setLocationConstraints(codegen, inst);
}
}
}

fn allocateRegisters(graph: Graph, codegen: CodeGen) {
for block in graph.reversePostOrderIterator() {
let dominator = block.getDominator();

let inRegs = if dominator.isSome() {
let dominator = dominator.getOrPanic();
dominator.getInRegs().clone()
} else {
assert(block === graph.getEntryBlock());
RegMap[Inst]::new(codegen.allocatableRegisters())
};

let dropRegisters: Vec[Register] = Vec[Register]::new();

for (reg, inst) in inRegs {
if !block.getLiveIn().contains(inst.id().toInt64()) {
dropRegisters.push(reg);
}
}

for reg in dropRegisters {
inRegs.free(reg);
}

for inst in block.phisIterator() {
let loc = inst.getLocationData();

/*for input in inst.getInputs() {
let value = input.getValue();

if value.getLastUse() === input {
let output = value.getLocationData().getOutput();
if !output.hasRegister() {
println("no register for inst ${value.id()}");
}
inRegs.free(output.getRegister());
}
}*/

// pick a register
let register = inRegs.acquire(inst).getOrPanic();
loc.getOutput().setRegister(register);

}

for inst in block.instructionsIterator() {
let loc = inst.getLocationData();
let mut idx = 0;

for input in inst.getInputs() {
let value = input.getValue();
let output = value.getLocationData().getOutput();
loc.getInput(idx).setRegister(output.getRegister());
idx = idx + 1;
}

if loc.hasOutput() {
let output = loc.getOutput();

if output.hasRegister() {
// fixed register was already assigned
inRegs.pick(output.getRegister(), inst);
} else {
// pick a register
let register = inRegs.acquire(inst).getOrPanic();
output.setRegister(register);
inRegs.acquire(inst).getOrPanic();
}
}

for input in inst.getInputs() {
let value = input.getValue();

if value.getLastUse() === input {
let output = value.getLocationData().getOutput();
inRegs.free(output.getRegister());
}
}

}

block.setInRegs(inRegs);
}
}

fn ssaDestruction(graph: Graph, codegen: CodeGen) {
for block in graph.insertionOrderIterator() {
for predecessorEdge in block.predecessors {
let predecessor = predecessorEdge.source;
let moves = ParallelMoveResolver::new();

for phi in block.phisIterator() {
let phiDestRegister = phi.getLocationData().getOutput().getRegister();
let value = phi.getInput(predecessorEdge.targetIdx).getValue();

// Critical edges should already be split.
assert(predecessor.successors.size() == 1);

let valueRegister = value.getLocationData().getOutput().getRegister();

moves.add(phiDestRegister, valueRegister);
}

moves.resolve(codegen.getScratchRegister());
moves.emitMoves(predecessor);
}
}
}

class ParallelMoveResolver {
moves: Vec[(Register, Register)],
orderedMoves: Vec[(Register, Register)],
// All written registers.
directpreds: HashMap[Register, Register],
// All used variables and their current location.
locations: HashMap[Register, Register],
todo: RegSet,
}

impl ParallelMoveResolver {
static fn new(): ParallelMoveResolver {
ParallelMoveResolver(
Vec[(Register, Register)]::new(),
Vec[(Register, Register)]::new(),
HashMap[Register, Register]::new(),
HashMap[Register, Register]::new(),
RegSet::new(),
)
}

fn add(dest: Register, src: Register) {
if dest == src { return; }
self.moves.push((dest, src));
}

fn resolve(scratch: Register) {
for (dest, src) in self.moves {
assert(!self.directpreds.contains(dest));
self.directpreds.insert(dest, src);
self.locations.insert(src, src);
assert(!self.todo.contains(dest));
self.todo.add(dest);
}

for (dest, src) in self.moves {
if !self.locations.contains(dest) {
assert(self.todo.contains(dest));
self.emitMoveChainStartingAt(dest);
}
}

while !self.todo.isEmpty() {
let dest = self.todo.getFirst().getOrPanic();
assert(self.inCycle(dest));
self.emitOrderedMove(scratch, dest);
self.locations.insert(dest, scratch);
self.emitMoveChainStartingAt(dest);
}
}

fn inCycle(reg: Register): Bool {
let start = reg;
let mut reg = reg;
assert(self.todo.contains(reg));

while self.directpreds.contains(reg) {
reg = self.directpreds.get(reg).getOrPanic();
assert(self.todo.contains(reg));

if reg == start {
return true;
}
}

false
}

fn emitMoveChainStartingAt(dest: Register) {
let mut dest = dest;
assert(self.todo.contains(dest));

while self.todo.contains(dest) {
let original_src = self.directpreds.get(dest).getOrPanic();
let actual_src = self.locations.get(original_src).getOrPanic();
self.emitOrderedMove(dest, actual_src);
self.locations.insert(original_src, dest);
assert(self.todo.remove(dest));

if !self.directpreds.contains(original_src) {
return;
}

dest = original_src;
}
}

fn emitOrderedMove(dest: Register, src: Register) {
self.orderedMoves.push((dest, src));
}

fn emitMoves(block: Block) {
for (dest, src) in self.orderedMoves {
let inst = createMove(dest, src);
let terminator = block.lastInst();
assert(terminator.isTerminator());
block.insertBefore(inst, terminator);
}
}
}

@Test
fn testMovesWithoutDependencies() {
let moves = ParallelMoveResolver::new();
moves.add(Register(1u8), Register(0u8));
moves.add(Register(3u8), Register(2u8));
moves.resolve(Register(10u8));

let mut idx = 0;
for move in moves.orderedMoves {
assert(move.0 == moves.moves(idx).0);
assert(move.1 == moves.moves(idx).1);
idx = idx + 1;
}
}

@Test
fn testMovesWithDependency() {
let moves = ParallelMoveResolver::new();
moves.add(Register(1u8), Register(0u8));
moves.add(Register(2u8), Register(1u8));
moves.resolve(Register(10u8));

let expected = Array[(Register, Register)]::new(
(Register(2u8), Register(1u8)),
(Register(1u8), Register(0u8)),
);

let mut idx = 0;
for move in moves.orderedMoves {
assert(move.0 == expected(idx).0);
assert(move.1 == expected(idx).1);
idx = idx + 1;
}
}

@Test
fn testMovesWithCycle() {
let moves = ParallelMoveResolver::new();
moves.add(Register(1u8), Register(0u8));
moves.add(Register(2u8), Register(1u8));
moves.add(Register(0u8), Register(2u8));
moves.resolve(Register(3u8));

let expected = Array[(Register, Register)]::new(
(Register(3u8), Register(0u8)),
(Register(0u8), Register(2u8)),
(Register(2u8), Register(1u8)),
(Register(1u8), Register(3u8)),
);

let mut idx = 0;
for move in moves.orderedMoves {
assert(move.0 == expected(idx).0);
assert(move.1 == expected(idx).1);
idx = idx + 1;
}
}

fn generate(graph: Graph, codegen: CodeGen): CodeDescriptor {
codegen.prolog();

for block in graph.insertionOrderIterator() {
block.setLabel(codegen.createLabel());
}

for block in graph.reversePostOrderIterator() {
codegen.emitComment("Block ${block.id()}");
codegen.bindLabel(block.getLabel());

for inst in block.instructionsIterator() {
emitInst(codegen, inst);
}
}

codegen.finalize()
}

fn createCodeGen(ci: CompilationInfo): CodeGen {
match config.architecture {
Architecture::X64 => CodeGenX64::new(ci) as CodeGen,
Architecture::Arm64 => CodeGenArm64::new(ci) as CodeGen,
}
}
Loading

0 comments on commit 37ed74d

Please sign in to comment.