Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for SPIR-V code generation #652

Merged
merged 7 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 55 additions & 26 deletions src/tcompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,34 +762,36 @@ struct CCallingConv {
lua_State *L;
terra_CompilerState *C;
Types *Ty;
bool pass_struct_as_exploded_values;
bool return_empty_struct_as_void;
bool wasm_cconv;
bool aarch64_cconv;
bool amdgpu_cconv;
bool ppc64_cconv;
int ppc64_float_limit;
int ppc64_int_limit;
bool ppc64_count_used;
bool spirv_cconv;
bool wasm_cconv;

CCallingConv(TerraCompilationUnit *CU_, Types *Ty_)
: CU(CU_),
T(CU_->T),
L(CU_->T->L),
C(CU_->T->C),
Ty(Ty_),
pass_struct_as_exploded_values(false),
return_empty_struct_as_void(false),
wasm_cconv(false),
aarch64_cconv(false),
amdgpu_cconv(false),
ppc64_cconv(false),
ppc64_float_limit(0),
ppc64_int_limit(0),
ppc64_count_used(false) {
ppc64_count_used(false),
spirv_cconv(false),
wasm_cconv(false) {
auto Triple = CU->TT->tm->getTargetTriple();
switch (Triple.getArch()) {
case Triple::ArchType::amdgcn: {
return_empty_struct_as_void = true;
pass_struct_as_exploded_values = true;
amdgpu_cconv = true;
} break;
case Triple::ArchType::aarch64:
case Triple::ArchType::aarch64_be: {
Expand All @@ -806,16 +808,27 @@ struct CCallingConv {
ppc64_int_limit = 8;
ppc64_count_used = true;
} break;
#if LLVM_VERSION >= 150
case Triple::ArchType::spirv32:
case Triple::ArchType::spirv64: {
return_empty_struct_as_void = true;
spirv_cconv = true;
} break;
#endif
case Triple::ArchType::wasm32:
case Triple::ArchType::wasm64: {
wasm_cconv = true;
} break;
default:
break;
}

switch (Triple.getOS()) {
case Triple::OSType::Win32: {
return_empty_struct_as_void = true;
} break;
default:
break;
}
}

Expand Down Expand Up @@ -1088,11 +1101,11 @@ struct CCallingConv {
return Argument(C_PRIMITIVE, t, usei1 ? Type::getInt1Ty(*CU->TT->ctx) : NULL);
}

if (wasm_cconv && !WasmIsSingletonOrEmpty(t->type)) {
if ((wasm_cconv && !WasmIsSingletonOrEmpty(t->type)) || spirv_cconv) {
return Argument(C_AGGREGATE_MEM, t);
}

if (pass_struct_as_exploded_values) {
if (amdgpu_cconv) {
return Argument(C_AGGREGATE_REG, t, t->type);
}

Expand Down Expand Up @@ -1128,7 +1141,7 @@ struct CCallingConv {

return Argument(C_AGGREGATE_REG, t, StructType::get(*CU->TT->ctx, elements));
}
void Classify(Obj *ftype, Obj *params, Classification *info) {
void Classify(Obj *ftype, CallingConv::ID cconv, Obj *params, Classification *info) {
Obj fparams, returntype;
ftype->obj("parameters", &fparams);
ftype->obj("returntype", &returntype);
Expand Down Expand Up @@ -1161,13 +1174,13 @@ struct CCallingConv {
CreateFunctionType(info, fparams.size(), ftype->boolean("isvararg"));
}

Classification *ClassifyFunction(Obj *fntyp) {
Classification *ClassifyFunction(Obj *fntyp, CallingConv::ID cconv) {
Classification *info = (Classification *)CU->symbols->getud(fntyp);
if (!info) {
info = new Classification(); // TODO: fix leak
Obj params;
fntyp->obj("parameters", &params);
Classify(fntyp, &params, info);
Classify(fntyp, cconv, &params, info);
CU->symbols->setud(fntyp, info);
}
return info;
Expand Down Expand Up @@ -1279,8 +1292,9 @@ struct CCallingConv {
}
}

Function *CreateFunction(Module *M, Obj *ftype, const Twine &name) {
Classification *info = ClassifyFunction(ftype);
Function *CreateFunction(Module *M, Obj *ftype, CallingConv::ID cconv,
const Twine &name) {
Classification *info = ClassifyFunction(ftype, cconv);
Function *fn = Function::Create(info->fntype, Function::InternalLinkage, name, M);
AttributeFnOrCall(fn, info);
return fn;
Expand Down Expand Up @@ -1311,7 +1325,7 @@ struct CCallingConv {
}
void EmitEntry(IRBuilder<> *B, Obj *ftype, Function *func,
std::vector<Value *> *variables) {
Classification *info = ClassifyFunction(ftype);
Classification *info = ClassifyFunction(ftype, func->getCallingConv());
assert(info->paramtypes.size() == variables->size());
Function::arg_iterator ai = func->arg_begin();
if (info->returntype.kind == C_AGGREGATE_MEM)
Expand Down Expand Up @@ -1359,7 +1373,7 @@ struct CCallingConv {
}
}
void EmitReturn(IRBuilder<> *B, Obj *ftype, Function *function, Value *result) {
Classification *info = ClassifyFunction(ftype);
Classification *info = ClassifyFunction(ftype, function->getCallingConv());
ArgumentKind kind = info->returntype.kind;

if (C_AGGREGATE_REG == kind &&
Expand Down Expand Up @@ -1420,10 +1434,10 @@ struct CCallingConv {
}
}

Value *EmitCall(IRBuilder<> *B, Obj *ftype, Obj *paramtypes, Value *callee,
std::vector<Value *> *actuals) {
Value *EmitCall(IRBuilder<> *B, Obj *ftype, CallingConv::ID cconv, Obj *paramtypes,
Value *callee, std::vector<Value *> *actuals) {
Classification info;
Classify(ftype, paramtypes, &info);
Classify(ftype, cconv, paramtypes, &info);

std::vector<Value *> arguments;

Expand Down Expand Up @@ -1867,16 +1881,25 @@ struct FunctionEmitter {
if (fstate->func) return fstate;
}

CallingConv::ID callingconv = CallingConv::MaxID;
if (funcobj->hasfield("callingconv")) {
callingconv = ParseCallingConv(funcobj->string("callingconv"));
}

Obj ftype;
funcobj->obj("type", &ftype);
// function name is $+name so that it can't conflict with any symbols imported
// from the C namespace
fstate->func = CC->CreateFunction(
M, &ftype, Twine(StringRef((isextern) ? "" : "$"), name));
fstate->func =
CC->CreateFunction(M, &ftype, callingconv,
Twine(StringRef((isextern) ? "" : "$"), name));
if (isextern) {
// Set external linkage for extern functions.
fstate->func->setLinkage(GlobalValue::ExternalLinkage);
}
if (callingconv != CallingConv::MaxID) {
fstate->func->setCallingConv(callingconv);
}

if (funcobj->hasfield("alwaysinline")) {
if (funcobj->boolean("alwaysinline")) {
Expand All @@ -1891,10 +1914,6 @@ struct FunctionEmitter {
fstate->func->addFnAttr(Attribute::NoInline);
}
}
if (funcobj->hasfield("callingconv")) {
const char *callingconv = funcobj->string("callingconv");
fstate->func->setCallingConv(ParseCallingConv(callingconv));
}
if (funcobj->hasfield("noreturn")) {
if (funcobj->boolean("noreturn")) {
fstate->func->addFnAttr(Attribute::NoReturn);
Expand Down Expand Up @@ -2810,7 +2829,12 @@ struct FunctionEmitter {
#if LLVM_VERSION < 170
return B->CreateBitCast(v, toT->type);
#else
return v;
if (fromT->type->getPointerAddressSpace() !=
toT->type->getPointerAddressSpace()) {
return B->CreateAddrSpaceCast(v, toT->type);
} else {
return v;
}
#endif
} else {
assert(toT->type->isIntegerTy());
Expand Down Expand Up @@ -3205,6 +3229,11 @@ struct FunctionEmitter {

call->obj("value", &func);

CallingConv::ID callingconv = CallingConv::MaxID;
if (func.hasfield("callingconv")) {
callingconv = ParseCallingConv(func.string("callingconv"));
}

Value *fn = emitExp(&func);

Obj fnptrtyp;
Expand All @@ -3220,7 +3249,7 @@ struct FunctionEmitter {
setInsertBlock(bb);
deferred.push_back(bb);
}
Value *r = CC->EmitCall(B, &fntyp, &paramtypes, fn, &actuals);
Value *r = CC->EmitCall(B, &fntyp, callingconv, &paramtypes, fn, &actuals);
setInsertBlock(cur); // defer may have changed it
return r;
}
Expand Down
13 changes: 9 additions & 4 deletions src/terralib.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,12 @@ do
if status then return r end
end
return self.name
elseif self:ispointer() then return "&"..tostring(self.type)
elseif self:ispointer() then
if not self.addressspace or self.addressspace == 0 then
return "&"..tostring(self.type)
else
return "pointer("..tostring(self.type)..","..tostring(self.addressspace)..")"
end
elseif self:isvector() then return "vector("..tostring(self.type)..","..tostring(self.N)..")"
elseif self:isfunction() then return mkstring(self.parameters,"{",",",self.isvararg and " ...}" or "}").." -> "..tostring(self.returntype)
elseif self:isarray() then
Expand Down Expand Up @@ -2455,15 +2460,15 @@ function typecheck(topexp,luaenv,simultaneousdefinitions)
end
local function ascompletepointer(exp) --convert pointer like things into pointers to _complete_ types
exp.type.type:tcomplete(exp)
return (insertcast(exp,terra.types.pointer(exp.type.type))) --parens are to truncate to 1 argument
return (insertcast(exp,terra.types.pointer(exp.type.type, exp.type.addressspace))) --parens are to truncate to 1 argument
end
-- subtracting 2 pointers
if pointerlike(l.type) and pointerlike(r.type) and l.type.type == r.type.type and e.operator == tokens["-"] then
return e:copy { operands = List {ascompletepointer(l),ascompletepointer(r)} }:withtype(terra.types.ptrdiff)
elseif pointerlike(l.type) and r.type:isintegral() then -- adding or subtracting a int to a pointer
return e:copy {operands = List {ascompletepointer(l),r} }:withtype(terra.types.pointer(l.type.type))
return e:copy {operands = List {ascompletepointer(l),r} }:withtype(terra.types.pointer(l.type.type, l.type.addressspace))
elseif l.type:isintegral() and pointerlike(r.type) then
return e:copy {operands = List {ascompletepointer(r),l} }:withtype(terra.types.pointer(r.type.type))
return e:copy {operands = List {ascompletepointer(r),l} }:withtype(terra.types.pointer(r.type.type, r.type.addressspace))
else
return meetbinary(e,"isarithmeticorvector",l,r)
end
Expand Down
26 changes: 26 additions & 0 deletions tests/addressspace.t
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- Tests of pointers with address spaces.

-- The exact meaning of this depends on the target, but at least basic
-- code compilation should work.

local function ptr1(ty)
-- A pointer in address space 1.
return terralib.types.pointer(ty, 1)
end

terra test(x : &int, y : ptr1(int))
-- Should be able to do math on pointers with non-zero address spaces:
var a = [ptr1(int8)](y)
var b = a + 8
var c = [ptr1(int)](b)
var d = c - y
y = c

-- Casts should work:
y = [ptr1(int)](x)
x = [&int](y)

return d
end
test:compile()
print(test)
Loading