diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-11 15:33:28 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-11 15:33:28 -0800 |
| commit | a3ac6e71cbc922b7c941c45f23ee18a9fc274d1f (patch) | |
| tree | acf8c18601f124e9290494f8b379d2420369fc35 /source/slang/slang-ir-remove-unused-generic-param.cpp | |
| parent | 20262684bcbb707d16669b2670039df870b65ca8 (diff) | |
Make backward differentiation work with generics. (#2586)
* Make backward differentiation work with generics.
* Fix.
* Another fix.
* More fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-remove-unused-generic-param.cpp')
| -rw-r--r-- | source/slang/slang-ir-remove-unused-generic-param.cpp | 134 |
1 files changed, 134 insertions, 0 deletions
diff --git a/source/slang/slang-ir-remove-unused-generic-param.cpp b/source/slang/slang-ir-remove-unused-generic-param.cpp new file mode 100644 index 000000000..9337a00bb --- /dev/null +++ b/source/slang/slang-ir-remove-unused-generic-param.cpp @@ -0,0 +1,134 @@ +#include "slang-ir-remove-unused-generic-param.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ +struct RemoveUnusedGenericParamContext : InstPassBase +{ + RemoveUnusedGenericParamContext(IRModule* inModule) + : InstPassBase(inModule) + {} + + bool processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->init(module); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + IRBuilder builder(sharedBuilder); + bool changed = false; + for (auto inst : module->getModuleInst()->getChildren()) + { + if (auto genInst = as<IRGeneric>(inst)) + { + auto returnVal = findGenericReturnVal(genInst); + switch (returnVal->getOp()) + { + case kIROp_StructType: + case kIROp_ClassType: + break; + case kIROp_Func: + case kIROp_FuncType: + default: + // Don't simplify functions since this can break signature compatiblity with the + // interface. For example, if we have + // interface IFoo { void genFunc<T>(int x); } + // We can't simplify this by removing `T` even when the function type here does not depend on T. + continue; + } + if (returnVal->findDecoration<IRTargetIntrinsicDecoration>()) + continue; + + List<UInt> paramToPreserve; + UInt id = 0; + List<IRInst*> paramsToRemove; + for (auto param : genInst->getParams()) + { + if (param->hasUses()) + { + paramToPreserve.add(id); + } + else + { + paramsToRemove.add(param); + } + id++; + } + if (paramsToRemove.getCount() == 0) + continue; + changed = true; + if (paramToPreserve.getCount() == 0) + { + // Special case: the generic return value is not dependent on the generic param, + // we can hoist to global scope safely. + for (auto child = genInst->getFirstBlock()->getFirstOrdinaryInst(); child; ) + { + auto next = child->getNextInst(); + if (child->getOp() == kIROp_Return) + { + break; + } + child->insertBefore(genInst); + child = next; + } + SLANG_ASSERT(returnVal); + List<IRUse*> uses; + for (auto use = genInst->firstUse; use; use = use->nextUse) + uses.add(use); + for (auto use : uses) + { + if (use->getUser()->getOp() == kIROp_Specialize && + use == use->getUser()->getOperands()) + { + use->getUser()->replaceUsesWith(returnVal); + } + } + genInst->replaceUsesWith(returnVal); + genInst->removeAndDeallocate(); + } + else + { + // General case: remove unnecessary specialization arguments. + // Disabled this optimization for now since we still need to take care + // of the type of the generic, or change other passes to not + // use type info on a generic at all. + List<IRUse*> uses; + for (auto use = genInst->firstUse; use; use = use->nextUse) + uses.add(use); + for (auto use : uses) + { + if (use->getUser()->getOp() == kIROp_Specialize && + use == use->getUser()->getOperands()) + { + auto specialize = as<IRSpecialize>(use->getUser()); + builder.setInsertBefore(specialize); + List<IRInst*> newArgs; + for (auto i : paramToPreserve) + newArgs.add(specialize->getArg(i)); + auto newSpecialize = builder.emitSpecializeInst( + specialize->getFullType(), + specialize->getBase(), + newArgs.getCount(), + newArgs.getBuffer()); + specialize->transferDecorationsTo(newSpecialize); + specialize->replaceUsesWith(newSpecialize); + specialize->removeAndDeallocate(); + } + } + for (auto param : paramsToRemove) + param->removeAndDeallocate(); + } + } + } + return changed; + } +}; + +bool removeUnusedGenericParam(IRModule* module) +{ + RemoveUnusedGenericParamContext context = RemoveUnusedGenericParamContext(module); + return context.processModule(); +} + +} // namespace Slang |
