Skip to content

Commit

Permalink
Fix capturing for bools/arrays/matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
nyorain committed Jan 5, 2025
1 parent 5cc8bf1 commit 7082171
Show file tree
Hide file tree
Showing 8 changed files with 755 additions and 109 deletions.
361 changes: 361 additions & 0 deletions docs/own/stash.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,364 @@

/*
struct CaptureType {
Type* type {};
bool converted {};
};
CaptureType buildCaptureType(ShaderPatch& patch, const CaptureProcess& cpt) {
}
struct CaptureProcess {
u32 typeID;
u32 loadedID;
const spc::Meta::Decoration* memberDeco;
};
Type* processCapture(ShaderPatch& patch, const CaptureProcess& cpt) {
auto [typeID, loadedID, memberDeco] = cpt;
auto& alloc = patch.alloc;
auto& compiler = patch.compiler;
auto stype = &compiler.get_type(typeID);
if(stype->pointer) {
dlg_assert(stype->parent_type);
typeID = stype->parent_type;
stype = &compiler.get_type(typeID);
// TODO load
}
auto& dst = alloc.construct<Type>();
dst.deco.typeID = typeID;
auto* meta = compiler.get_ir().find_meta(typeID);
if(meta) {
dst.deco.name = copy(alloc, meta->decoration.alias);
}
if(memberDeco) {
if(memberDeco->decoration_flags.get(spv::DecorationRowMajor)) {
dst.deco.flags |= Decoration::Bits::rowMajor;
}
if(memberDeco->decoration_flags.get(spv::DecorationColMajor)) {
dst.deco.flags |= Decoration::Bits::colMajor;
}
if(memberDeco->decoration_flags.get(spv::DecorationMatrixStride)) {
dst.deco.matrixStride = memberDeco->matrix_stride;
}
}
// handle array
if(!stype->array.empty()) {
if(meta && meta->decoration.decoration_flags.get(spv::DecorationArrayStride)) {
dst.deco.arrayStride = meta->decoration.array_stride;
}
dlg_assert(stype->array.size() == stype->array_size_literal.size());
dst.array = alloc.alloc<u32>(stype->array.size());
for(auto d = 0u; d < stype->array.size(); ++d) {
if(stype->array_size_literal[d] == true) {
dst.array[d] = stype->array[d];
} else {
dst.array[d] = compiler.evaluate_constant_u32(stype->array[d]);
}
}
dst.deco.arrayTypeID = typeID;
dlg_assert(stype->parent_type);
typeID = stype->parent_type;
stype = &compiler.get_type(typeID);
meta = compiler.get_ir().find_meta(typeID);
dst.deco.typeID = typeID;
}
if(stype->basetype == spc::SPIRType::Struct) {
// handle struct
dst.members = alloc.alloc<Type::Member>(stype->member_types.size());
for(auto i = 0u; i < stype->member_types.size(); ++i) {
auto memTypeID = stype->member_types[i];
const spc::Meta::Decoration* deco {};
if(meta && meta->members.size() > i) {
deco = &meta->members[i];
}
// TODO PERF: remove allocation via dlg format here,
// use linearAllocator instead if needed
auto name = dlg::format("?{}", i);
if(deco && !deco->alias.empty()) {
// TODO PERF: we copy here with new, terrible
name = deco->alias;
}
auto& mdst = dst.members[i];
mdst.type = processCapture(patch, memTypeID, alloc, deco);
mdst.name = copy(alloc, name);
mdst.offset = deco ? deco->offset : 0u;
if(!mdst.type) {
return nullptr;
}
}
dst.type = Type::typeStruct;
return &dst;
}
// handle atom
auto getBaseType = [](spc::SPIRType::BaseType t) -> std::optional<Type::BaseType> {
switch(t) {
case spc::SPIRType::Double:
case spc::SPIRType::Float:
case spc::SPIRType::Half:
return Type::typeFloat;
case spc::SPIRType::Int:
case spc::SPIRType::Short:
case spc::SPIRType::Int64:
case spc::SPIRType::SByte:
return Type::typeInt;
case spc::SPIRType::UInt:
case spc::SPIRType::UShort:
case spc::SPIRType::UInt64:
case spc::SPIRType::UByte:
return Type::typeUint;
case spc::SPIRType::Boolean:
return Type::typeBool;
default:
return std::nullopt;
}
};
auto bt = getBaseType(stype->basetype);
if(!bt) {
dlg_error("Unsupported shader type: {}", u32(stype->basetype));
return nullptr;
}
dst.type = *bt;
dst.width = stype->width;
dst.vecsize = stype->vecsize;
dst.columns = stype->columns;
return &dst;
}
*/

/*
ProcessedCapture processCaptureNonArray(ShaderPatch& patch, LinAllocScope& tms,
Type& type, span<const u32> loadedIDs) {
u32 copiedTypeID = type.deco.typeID;
span<const u32> retIDs = loadedIDs;
if(!type.members.empty()) {
dlg_assert(type.type == Type::typeStruct);
auto copied = tms.alloc<u32>(loadedIDs.size());
span<u32> typeIDs = tms.alloc<u32>(loadedIDs.size());
span<span<const u32>> memberIDs =
tms.alloc<span<const u32>>(loadedIDs.size());
for(auto [i, member] : enumerate(type.members)) {
span<u32> loadedMembers = tms.alloc<u32>(loadedIDs.size());
for(auto [j, id] : enumerate(loadedIDs)) {
loadedMembers[j] = patch.genOp(spv::OpCompositeExtract,
member.type->array.empty() ? member.type->deco.typeID : member.type->deco.arrayTypeID,
id, i);
}
auto capture = processCapture(patch, tms, *member.type, loadedMembers);
memberIDs[i] = capture.ids;
typeIDs[i] = capture.typeID;
}
copiedTypeID = ++patch.freeID;
patch.decl<spv::OpTypeStruct>()
.push(copiedTypeID)
.push(typeIDs);
// TODO offset deco
// TODO copy other member decos!
for(auto [i, ids] : enumerate(memberIDs)) {
copied[i] = patch.genOp(spv::OpCompositeConstruct, copiedTypeID, ids);
}
retIDs = copied;
} else if(type.type == Type::typeBool) {
type.type = Type::typeUint;
type.width = 32u;
type.deco.typeID = patch.typeUint;
copiedTypeID = patch.typeUint;
auto copied = tms.alloc<u32>(loadedIDs.size());
for(auto [i, src] : enumerate(loadedIDs)) {
copied[i] = patch.genOp(spv::OpSelect, patch.typeUint,
src, patch.const1, patch.const0);
}
retIDs = copied;
}
ProcessedCapture ret;
ret.typeID = copiedTypeID;
ret.ids = retIDs;
return ret;
}
ProcessedCapture processCapture(ShaderPatch& patch, LinAllocScope& tms,
Type& type, span<const u32> loadedIDs) {
if(type.array.empty()) {
return processCaptureNonArray(patch, tms, type, loadedIDs);
}
auto totalCount = 1u;
for(auto dimSize : type.array) {
totalCount *= dimSize;
}
span<u32> atomIDs = tms.alloc<u32>(loadedIDs.size() * totalCount);
u32 typeID = type.deco.arrayTypeID;
auto* spcType = &patch.compiler.get_type(typeID);
for(auto [i, id] : enumerate(loadedIDs)) {
atomIDs[i * totalCount] = id;
}
auto lastCount = loadedIDs.size();
auto stride = totalCount;
for(auto dimSize : reversed(type.array)) {
dlg_assert(dimSize <= stride);
for(auto srcOff = 0u; srcOff < lastCount; ++srcOff) {
auto srcID = srcOff * stride;
for(auto dstOff = 0u; dstOff < dimSize; ++dstOff) {
auto dstID = srcID + dstOff;
atomIDs[dstID] = patch.genOp(spv::OpCompositeExtract,
typeID, atomIDs[srcID], dstOff);
}
}
dlg_assert(spcType->parent_type);
u32 typeID = spcType->parent_type;
spcType = &patch.compiler.get_type(typeID);
lastCount *= dimSize;
stride /= dimSize;
}
dlg_assert(stride == 1u);
dlg_assert(lastCount == totalCount * loadedIDs.size());
auto baseCapture = processCaptureNonArray(patch, tms, type, atomIDs);
auto copiedTypeID = baseCapture.typeID;
std::copy(baseCapture.ids.begin(), baseCapture.ids.end(), atomIDs.begin());
for(auto dimSize : type.array) {
auto id = ++patch.freeID;
patch.decl<spv::OpTypeArray>()
.push(id)
.push(copiedTypeID)
.push(u32(dimSize));
// TODO stride deco. member?
copiedTypeID = id;
dlg_assert(lastCount % dimSize == 0u);
auto dstCount = lastCount / dimSize;
for(auto dstOff = 0u; dstOff < dstCount; ++dstOff) {
auto dstID = ++patch.freeID;
auto builder = patch.instr(spv::OpCompositeConstruct);
builder.push(copiedTypeID);
builder.push(dstID);
for(auto srcOff = 0u; srcOff < dimSize; ++srcOff) {
auto srcID = dstOff * dimSize + srcOff;
builder.push(atomIDs[srcID]);
}
atomIDs[dstOff] = dstID;
}
lastCount = dstCount;
}
dlg_assert(lastCount == loadedIDs.size());
ProcessedCapture ret;
ret.typeID = copiedTypeID;
ret.ids = atomIDs.first(lastCount);
return ret;
}
void fixDecorateCaptureType(ShaderPatch& patch, Type& type) {
const auto& ir = patch.compiler.get_ir();
if(!type.members.empty()) {
dlg_assert(type.type == Type::typeStruct);
auto* meta = ir.find_meta(type.deco.typeID);
dlg_assert(meta && meta->members.size() == type.members.size());
auto needsOffsetDeco = !meta->members[0].decoration_flags.get(spv::DecorationOffset);
auto offset = 0u;
for(auto [i, member] : enumerate(type.members)) {
fixDecorateCaptureType(patch, *const_cast<Type*>(member.type));
if(needsOffsetDeco) {
dlg_assert(!meta->members[0].decoration_flags.get(spv::DecorationOffset));
offset = vil::alignPOT(offset, align(type, patch.bufLayout));
member.offset = offset;
patch.decl<spv::OpMemberDecorate>()
.push(type.deco.typeID)
.push(u32(i))
.push(spv::DecorationOffset)
.push(offset);
auto dstSize = size(*member.type, patch.bufLayout);
offset += dstSize;
}
}
}
if(!type.array.empty()) {
dlg_assert(type.deco.arrayTypeID != 0u);
auto* meta = ir.find_meta(type.deco.arrayTypeID);
if(!meta || !meta->decoration.decoration_flags.get(spv::DecorationArrayStride)) {
dlg_assert(type.deco.arrayStride == 0u);
auto tarray = type.array;
type.array = {};
type.deco.arrayStride = align(
size(type, patch.bufLayout),
align(type, patch.bufLayout));
type.array = tarray;
patch.decl<spv::OpDecorate>()
.push(type.deco.arrayTypeID)
.push(spv::DecorationArrayStride)
.push(type.deco.arrayStride);
} else {
dlg_assert(type.deco.arrayStride);
}
}
// TODO: matrixStride
if(type.columns > 1u) {
dlg_error("TODO: add matrixstride deco");
}
}
*/



////
///
#include <fwd.hpp>
#include <commandDesc.hpp>
#include <cb.hpp>
Expand Down
Loading

0 comments on commit 7082171

Please sign in to comment.