diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-lower-generics.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-lower-generics.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 487 |
1 files changed, 250 insertions, 237 deletions
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index dc267aaeb..4fd162e53 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -1,298 +1,311 @@ // slang-ir-lower-generics.cpp #include "slang-ir-lower-generics.h" -#include "slang-ir-any-value-marshalling.h" +#include "../core/slang-func-ptr.h" +#include "../core/slang-performance-profiler.h" #include "slang-ir-any-value-inference.h" +#include "slang-ir-any-value-marshalling.h" #include "slang-ir-augment-make-existential.h" #include "slang-ir-generics-lowering-context.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-layout.h" #include "slang-ir-lower-existential.h" -#include "slang-ir-lower-tuple-types.h" -#include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-call.h" +#include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-type.h" -#include "slang-ir-inst-pass-base.h" +#include "slang-ir-lower-tuple-types.h" #include "slang-ir-specialize-dispatch.h" #include "slang-ir-specialize-dynamic-associatedtype-lookup.h" -#include "slang-ir-witness-table-wrapper.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" -#include "slang-ir-layout.h" - -#include "../core/slang-performance-profiler.h" -#include "../core/slang-func-ptr.h" +#include "slang-ir-witness-table-wrapper.h" namespace Slang { - // Replace all uses of RTTI objects with its sequential ID. - // Currently we don't use RTTI objects at all, so all of them - // are 0. - void specializeRTTIObjectReferences(SharedGenericsLoweringContext* sharedContext) +// Replace all uses of RTTI objects with its sequential ID. +// Currently we don't use RTTI objects at all, so all of them +// are 0. +void specializeRTTIObjectReferences(SharedGenericsLoweringContext* sharedContext) +{ + uint32_t id = 0; + for (auto rtti : sharedContext->mapTypeToRTTIObject) { - uint32_t id = 0; - for (auto rtti : sharedContext->mapTypeToRTTIObject) + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(rtti.value); + IRUse* nextUse = nullptr; + auto uint2Type = builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 2)); + IRInst* uint2Args[] = { + builder.getIntValue(builder.getUIntType(), id), + builder.getIntValue(builder.getUIntType(), 0)}; + auto idOperand = builder.emitMakeVector(uint2Type, 2, uint2Args); + for (auto use = rtti.value->firstUse; use; use = nextUse) { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(rtti.value); - IRUse* nextUse = nullptr; - auto uint2Type = builder.getVectorType( - builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - builder.getIntValue(builder.getUIntType(), id), - builder.getIntValue(builder.getUIntType(), 0)}; - auto idOperand = builder.emitMakeVector(uint2Type, 2, uint2Args); - for (auto use = rtti.value->firstUse; use; use = nextUse) + nextUse = use->nextUse; + if (use->getUser()->getOp() == kIROp_GetAddr) { - nextUse = use->nextUse; - if (use->getUser()->getOp() == kIROp_GetAddr) - { - use->getUser()->replaceUsesWith(idOperand); - } + use->getUser()->replaceUsesWith(idOperand); } } } +} - // Replace all WitnessTableID type or RTTIHandleType with `uint2`. - void cleanUpRTTIHandleTypes(SharedGenericsLoweringContext* sharedContext) +// Replace all WitnessTableID type or RTTIHandleType with `uint2`. +void cleanUpRTTIHandleTypes(SharedGenericsLoweringContext* sharedContext) +{ + List<IRInst*> instsToRemove; + for (auto inst : sharedContext->module->getGlobalInsts()) { - List<IRInst*> instsToRemove; - for (auto inst : sharedContext->module->getGlobalInsts()) + switch (inst->getOp()) { - switch (inst->getOp()) + case kIROp_WitnessTableIDType: + if (isComInterfaceType((IRType*)inst->getOperand(0))) + continue; + // fall through + case kIROp_RTTIHandleType: { - case kIROp_WitnessTableIDType: - if (isComInterfaceType((IRType*)inst->getOperand(0))) - continue; - // fall through - case kIROp_RTTIHandleType: - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); - inst->replaceUsesWith(uint2Type); - instsToRemove.add(inst); - } - break; + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(inst); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 2)); + inst->replaceUsesWith(uint2Type); + instsToRemove.add(inst); } + break; } - for (auto inst : instsToRemove) - inst->removeAndDeallocate(); } + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); +} - // Remove all interface types from module. - void cleanUpInterfaceTypes(SharedGenericsLoweringContext* sharedContext) +// Remove all interface types from module. +void cleanUpInterfaceTypes(SharedGenericsLoweringContext* sharedContext) +{ + IRBuilder builder(sharedContext->module); + builder.setInsertInto(sharedContext->module->getModuleInst()); + auto dummyInterfaceObj = builder.getIntValue(builder.getIntType(), 0); + List<IRInst*> interfaceInsts; + for (auto inst : sharedContext->module->getGlobalInsts()) { - IRBuilder builder(sharedContext->module); - builder.setInsertInto(sharedContext->module->getModuleInst()); - auto dummyInterfaceObj = builder.getIntValue(builder.getIntType(), 0); - List<IRInst*> interfaceInsts; - for (auto inst : sharedContext->module->getGlobalInsts()) + if (inst->getOp() == kIROp_InterfaceType) { - if (inst->getOp() == kIROp_InterfaceType) - { - if (inst->findDecoration<IRComInterfaceDecoration>()) - continue; + if (inst->findDecoration<IRComInterfaceDecoration>()) + continue; - interfaceInsts.add(inst); - } - } - for (auto inst : interfaceInsts) - { - inst->replaceUsesWith(dummyInterfaceObj); - inst->removeAndDeallocate(); + interfaceInsts.add(inst); } } - - void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext) + for (auto inst : interfaceInsts) { - InstPassBase pass(sharedContext->module); - pass.processInstsOfType<IRIsType>(kIROp_IsType, [&](IRIsType* inst) - { - auto witnessTableType = as<IRWitnessTableTypeBase>(inst->getValueWitness()->getDataType()); - if (witnessTableType && isComInterfaceType((IRType*)witnessTableType->getConformanceType())) - return; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto eqlInst = builder.emitEql(builder.emitGetSequentialIDInst(inst->getValueWitness()), - builder.emitGetSequentialIDInst(inst->getTargetWitness())); - inst->replaceUsesWith(eqlInst); - inst->removeAndDeallocate(); - }); + inst->replaceUsesWith(dummyInterfaceObj); + inst->removeAndDeallocate(); } +} - // Turn all references of witness table or RTTI objects into integer IDs, generate - // specialized `switch` based dispatch functions based on witness table IDs, and remove - // all original witness table, RTTI object and interface definitions from IR module. - // With these transformations, the resulting code is compatible with D3D/Vulkan where - // no pointers are involved in RTTI / dynamic dispatch logic. - void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, DiagnosticSink* sink) - { - specializeDispatchFunctions(sharedContext); - if (sink->getErrorCount() != 0) - return; +void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext) +{ + InstPassBase pass(sharedContext->module); + pass.processInstsOfType<IRIsType>( + kIROp_IsType, + [&](IRIsType* inst) + { + auto witnessTableType = + as<IRWitnessTableTypeBase>(inst->getValueWitness()->getDataType()); + if (witnessTableType && + isComInterfaceType((IRType*)witnessTableType->getConformanceType())) + return; + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(inst); + auto eqlInst = builder.emitEql( + builder.emitGetSequentialIDInst(inst->getValueWitness()), + builder.emitGetSequentialIDInst(inst->getTargetWitness())); + inst->replaceUsesWith(eqlInst); + inst->removeAndDeallocate(); + }); +} + +// Turn all references of witness table or RTTI objects into integer IDs, generate +// specialized `switch` based dispatch functions based on witness table IDs, and remove +// all original witness table, RTTI object and interface definitions from IR module. +// With these transformations, the resulting code is compatible with D3D/Vulkan where +// no pointers are involved in RTTI / dynamic dispatch logic. +void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, DiagnosticSink* sink) +{ + specializeDispatchFunctions(sharedContext); + if (sink->getErrorCount() != 0) + return; - lowerIsTypeInsts(sharedContext); + lowerIsTypeInsts(sharedContext); - specializeDynamicAssociatedTypeLookup(sharedContext); - if (sink->getErrorCount() != 0) - return; + specializeDynamicAssociatedTypeLookup(sharedContext); + if (sink->getErrorCount() != 0) + return; - sharedContext->mapInterfaceRequirementKeyValue.clear(); + sharedContext->mapInterfaceRequirementKeyValue.clear(); - specializeRTTIObjectReferences(sharedContext); + specializeRTTIObjectReferences(sharedContext); - cleanUpRTTIHandleTypes(sharedContext); + cleanUpRTTIHandleTypes(sharedContext); - cleanUpInterfaceTypes(sharedContext); - } + cleanUpInterfaceTypes(sharedContext); +} - void checkTypeConformanceExists(SharedGenericsLoweringContext* context) - { - HashSet<IRInst*> implementedInterfaces; +void checkTypeConformanceExists(SharedGenericsLoweringContext* context) +{ + HashSet<IRInst*> implementedInterfaces; - // Add all interface type that are implemented by at least one type to a set. - for (auto inst : context->module->getGlobalInsts()) + // Add all interface type that are implemented by at least one type to a set. + for (auto inst : context->module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_WitnessTable) { - if (inst->getOp() == kIROp_WitnessTable) - { - auto interfaceType = cast<IRWitnessTableType>(inst->getDataType())->getConformanceType(); - implementedInterfaces.add(interfaceType); - } + auto interfaceType = + cast<IRWitnessTableType>(inst->getDataType())->getConformanceType(); + implementedInterfaces.add(interfaceType); } - // Check if an interface type has any implementations. - workOnModule(context, [&](IRInst* inst) + } + // Check if an interface type has any implementations. + workOnModule( + context, + [&](IRInst* inst) + { + if (auto lookupWitnessMethod = as<IRLookupWitnessMethod>(inst)) { - if (auto lookupWitnessMethod = as<IRLookupWitnessMethod>(inst)) + auto witnessTableType = lookupWitnessMethod->getWitnessTable()->getDataType(); + if (!witnessTableType) + return; + auto interfaceType = + cast<IRWitnessTableType>(witnessTableType)->getConformanceType(); + if (isComInterfaceType((IRType*)interfaceType)) + return; + if (!implementedInterfaces.contains(interfaceType)) { - auto witnessTableType = lookupWitnessMethod->getWitnessTable()->getDataType(); - if (!witnessTableType) - return; - auto interfaceType = cast<IRWitnessTableType>(witnessTableType)->getConformanceType(); - if (isComInterfaceType((IRType*)interfaceType)) - return; - if (!implementedInterfaces.contains(interfaceType)) - { - context->sink->diagnose(interfaceType->sourceLoc, Diagnostics::noTypeConformancesFoundForInterface, interfaceType); - // Add to set to prevent duplicate diagnostic messages. - implementedInterfaces.add(interfaceType); - } + context->sink->diagnose( + interfaceType->sourceLoc, + Diagnostics::noTypeConformancesFoundForInterface, + interfaceType); + // Add to set to prevent duplicate diagnostic messages. + implementedInterfaces.add(interfaceType); } - }); - } + } + }); +} - void stripWrapExistential(IRModule* module) - { - InstWorkList workList(module); +void stripWrapExistential(IRModule* module) +{ + InstWorkList workList(module); - workList.add(module->getModuleInst()); - for (Index i = 0; i < workList.getCount(); i++) + workList.add(module->getModuleInst()); + for (Index i = 0; i < workList.getCount(); i++) + { + auto inst = workList[i]; + switch (inst->getOp()) { - auto inst = workList[i]; - switch (inst->getOp()) + case kIROp_WrapExistential: { - case kIROp_WrapExistential: - { - auto operand = inst->getOperand(0); - inst->replaceUsesWith(operand); - inst->removeAndDeallocate(); - } - break; - default: - for (auto child : inst->getChildren()) - workList.add(child); - break; + auto operand = inst->getOperand(0); + inst->replaceUsesWith(operand); + inst->removeAndDeallocate(); } + break; + default: + for (auto child : inst->getChildren()) + workList.add(child); + break; } } +} - void lowerGenerics( - TargetProgram* targetProgram, - IRModule* module, - DiagnosticSink* sink) - { - SLANG_PROFILE; - - SharedGenericsLoweringContext sharedContext(module); - sharedContext.targetProgram = targetProgram; - sharedContext.sink = sink; - - checkTypeConformanceExists(&sharedContext); - - inferAnyValueSizeWhereNecessary(targetProgram, module); - - // Replace all `makeExistential` insts with `makeExistentialWithRTTI` - // before making any other changes. This is necessary because a parameter of - // generic type will be lowered into `AnyValueType`, and after that we can no longer - // access the original generic type parameter from the lowered parameter value. - // This steps ensures that the generic type parameter is available via an - // explicit operand in `makeExistentialWithRTTI`, so that type parameter - // can be translated into an RTTI object during `lower-generic-type`, - // and used to create a tuple representing the existential value. - augmentMakeExistentialInsts(module); - - lowerGenericFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerGenericType(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerExistentials(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerGenericCalls(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - generateWitnessTableWrapperFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // This optional step replaces all uses of witness tables and RTTI objects with - // sequential IDs. Without this step, we will emit code that uses function pointers and - // real RTTI objects and witness tables. - specializeRTTIObjects(&sharedContext, sink); - - simplifyIR(sharedContext.targetProgram, module, IRSimplificationOptions::getFast(sharedContext.targetProgram)); - - lowerTuples(module, sink); - if (sink->getErrorCount() != 0) - return; - - generateAnyValueMarshallingFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // At this point, we should no longer need to care any `WrapExistential` insts, - // although they could still exist in the IR in order to call generic core module functions, - // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). - // We should remove them now. - stripWrapExistential(module); - } - - void cleanupGenerics(TargetProgram* program, IRModule* module, DiagnosticSink* sink) - { - SharedGenericsLoweringContext sharedContext(module); - sharedContext.targetProgram = program; - sharedContext.sink = sink; - - specializeRTTIObjects(&sharedContext, sink); - - lowerTuples(module, sink); - if (sink->getErrorCount() != 0) - return; - - generateAnyValueMarshallingFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // At this point, we should no longer need to care any `WrapExistential` insts, - // although they could still exist in the IR in order to call generic core module functions, - // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). - // We should remove them now. - stripWrapExistential(module); - } +void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink) +{ + SLANG_PROFILE; + + SharedGenericsLoweringContext sharedContext(module); + sharedContext.targetProgram = targetProgram; + sharedContext.sink = sink; + + checkTypeConformanceExists(&sharedContext); + + inferAnyValueSizeWhereNecessary(targetProgram, module); + + // Replace all `makeExistential` insts with `makeExistentialWithRTTI` + // before making any other changes. This is necessary because a parameter of + // generic type will be lowered into `AnyValueType`, and after that we can no longer + // access the original generic type parameter from the lowered parameter value. + // This steps ensures that the generic type parameter is available via an + // explicit operand in `makeExistentialWithRTTI`, so that type parameter + // can be translated into an RTTI object during `lower-generic-type`, + // and used to create a tuple representing the existential value. + augmentMakeExistentialInsts(module); + + lowerGenericFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + lowerGenericType(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + lowerExistentials(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + lowerGenericCalls(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + generateWitnessTableWrapperFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + // This optional step replaces all uses of witness tables and RTTI objects with + // sequential IDs. Without this step, we will emit code that uses function pointers and + // real RTTI objects and witness tables. + specializeRTTIObjects(&sharedContext, sink); + + simplifyIR( + sharedContext.targetProgram, + module, + IRSimplificationOptions::getFast(sharedContext.targetProgram)); + + lowerTuples(module, sink); + if (sink->getErrorCount() != 0) + return; + + generateAnyValueMarshallingFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + // At this point, we should no longer need to care any `WrapExistential` insts, + // although they could still exist in the IR in order to call generic core module functions, + // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). + // We should remove them now. + stripWrapExistential(module); +} + +void cleanupGenerics(TargetProgram* program, IRModule* module, DiagnosticSink* sink) +{ + SharedGenericsLoweringContext sharedContext(module); + sharedContext.targetProgram = program; + sharedContext.sink = sink; + + specializeRTTIObjects(&sharedContext, sink); + + lowerTuples(module, sink); + if (sink->getErrorCount() != 0) + return; + + generateAnyValueMarshallingFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + // At this point, we should no longer need to care any `WrapExistential` insts, + // although they could still exist in the IR in order to call generic core module functions, + // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). + // We should remove them now. + stripWrapExistential(module); +} } // namespace Slang |
