diff options
| author | ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> | 2025-08-07 00:22:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-07 07:22:22 +0000 |
| commit | 063cbeaaea2fb00a10c6058ea4a9632092772ea5 (patch) | |
| tree | b4412347d6c264c3b1a84ec971921a5e2fe76134 /source/slang | |
| parent | 9e2685853033f4286feaf22d04a755a7395d95ce (diff) | |
Initial copy elision pass (#8042)
Fixes #7574
Changes:
* Add an initial (fairly simple) optimization pass which is able to
eliminate redundant copies.
* Our existing optimizer passes already remove redundant loads/stores very
robustly, so this pass focuses on other cases of copy elimination.
* The primary approach is to convert every function parameter of the form
`in T`, where `T` is trivial to copy, into a `__constref T` parameter. We then
(depending on the scenario) manually insert a variable+load if
pass-by-reference is not possible; otherwise we pass by `constref`.
* Added optimizations to eliminate redundant code that would otherwise cause
code using `constref` to fail to compile.
---------
Co-authored-by: Harsh Aggarwal <haaggarwal@nvidia.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: slangbot <ellieh+slangbot@nvidia.com>
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.cpp | 137 | ||||
| -rw-r--r-- | source/slang/slang-ir-transform-params-to-constref.cpp | 287 | ||||
| -rw-r--r-- | source/slang/slang-ir-transform-params-to-constref.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 5 |
7 files changed, 445 insertions, 24 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index ffc4b97ef..66829308d 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -307,6 +307,25 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S case kIROp_PtrType: case kIROp_ConstRefType: { + // Special note on `constref` types and why they are not emitted + // as a `const` pointer: + // + // We currently do not propegate/manage "constness" for locals. + // This is important since it means that we rely on opimization + // passes to remove all temporary pointer-variables created from + // our constref, otherwise we will generate invalid code like + // `T* var = const_ptr` or `T* var = &const_ptr->member`. + // + // If emitting `constref` fails due to this error, it is likely + // a missing compiler-optimization. + // + // Additionally, for C++/CUDA, downstream methods are required + // to be `const` if we want to use const pointers. This is currently + // not handled robustly. + // + // Due to these cascading issues, we do not emit const and instead + // emit as a regular pointer for the time being. 
+ auto elementType = (IRType*)type->getOperand(0); SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out)); out << "*"; @@ -599,6 +618,7 @@ CPPSourceEmitter::CPPSourceEmitter(const Desc& desc) void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) { + // For use the CPP-specific emitType implementation emitType(type, name); } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 405bca5a2..7d8f1438d 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -110,6 +110,7 @@ #include "slang-ir-strip-default-construct.h" #include "slang-ir-strip-legalization-insts.h" #include "slang-ir-synthesize-active-mask.h" +#include "slang-ir-transform-params-to-constref.h" #include "slang-ir-translate-global-varying-var.h" #include "slang-ir-undo-param-copy.h" #include "slang-ir-uniformity.h" @@ -1714,6 +1715,12 @@ Result linkAndOptimizeIR( // For CUDA/OptiX like targets, add our pass to replace inout parameter copies with direct // pointers undoParameterCopy(irModule); + // Transform struct parameters to use ConstRef for better performance + if (isCPUTarget(targetRequest) || isCUDATarget(targetRequest) || + isMetalTarget(targetRequest)) + { + transformParamsToConstRef(irModule, codeGenContext->getSink()); + } #if 0 dumpIRIfEnabled(codeGenContext, irModule, "PARAMETER COPIES REPLACED WITH DIRECT POINTERS"); #endif diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 93979bdbe..c8604f4fa 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3729,6 +3729,7 @@ public: IRInOutType* getInOutType(IRType* valueType); IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace); IRConstRefType* getConstRefType(IRType* valueType); + IRConstRefType* getConstRefType(IRType* valueType, AddressSpace addrSpace); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace); IRPtrType* 
getPtrType(IROp op, IRType* valueType, IRInst* addressSpace); diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index a6dac723e..1feab47dd 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -361,43 +361,132 @@ bool tryRemoveRedundantStore(IRGlobalValueWithCode* func, IRStore* store) return false; } +// Checks if we can change or have a modified rootVar +// at some point. +bool isExternallyModifiableAddr(IRInst* rootVar) +{ + if (!rootVar) + return false; + + auto ptr = as<IRConstRefType>(rootVar->getDataType()); + if (!ptr) + return true; + + // Only a UserPointer can potentially be modified and changed to point to a different address + // if constRef. This may happen from a different thread even if constref to the current thread. + auto addrSpace = ptr->getAddressSpace(); + if (addrSpace == AddressSpace::UserPointer) + return true; + + return false; +} + +bool tryRemoveRedundantLoad(IRGlobalValueWithCode* func, IRLoad* load) +{ + bool changed = false; + + // If the load is preceeded by a store without any side-effect insts + // in-between, remove the load. 
+ for (auto prev = load->getPrevInst(); prev; prev = prev->getPrevInst()) + { + if (auto store = as<IRStore>(prev)) + { + if (store->getPtr() == load->getPtr()) + { + auto value = store->getVal(); + load->replaceUsesWith(value); + load->removeAndDeallocate(); + changed = true; + break; + } + } + + if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr())) + { + break; + } + } + + return changed; +} + bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func) { bool changed = false; for (auto block : func->getBlocks()) { - for (auto inst = block->getFirstInst(); inst;) + IRInst* nextInst = nullptr; + for (auto inst = block->getFirstInst(); inst; inst = nextInst) { - auto nextInst = inst->getNextInst(); + nextInst = inst->getNextInst(); if (auto load = as<IRLoad>(inst)) { - for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst()) - { - if (auto store = as<IRStore>(prev)) - { - if (store->getPtr() == load->getPtr()) - { - // If the load is preceeded by a store without any side-effect insts - // in-between, remove the load. - auto value = store->getVal(); - load->replaceUsesWith(value); - load->removeAndDeallocate(); - changed = true; - break; - } - } - - if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr())) - { - break; - } - } + changed |= tryRemoveRedundantLoad(func, load); } else if (auto store = as<IRStore>(inst)) { changed |= tryRemoveRedundantStore(func, store); } - inst = nextInst; + else if (auto getElementPtr = as<IRGetElementPtr>(inst)) + { + auto rootAddr = getRootAddr(getElementPtr); + if (isExternallyModifiableAddr(rootAddr)) + continue; + + // GetElement(Load(GetElementPtr(x)))) ==> Load(GetElementPtr(GetElementPtr(x))) + // The benefit is that any GetAddr(Load(...)) can then transitively be optimized + // out. + // This can only be done if we have no side-effects. `constref` never has + // single-invocation side-effects. 
+ traverseUsers<IRLoad>( + getElementPtr, + [&](IRLoad* load) + { + traverseUsers<IRGetElement>( + load, + [&](IRGetElement* getElement) + { + IRBuilder builder(getElement); + builder.setInsertBefore(getElement); + auto newGetElementPtr = builder.emitElementAddress( + getElementPtr, + getElement->getIndex()); + auto newLoad = builder.emitLoad(newGetElementPtr); + getElement->replaceUsesWith(newLoad); + changed = true; + }); + }); + } + else if (auto fieldAddress = as<IRFieldAddress>(inst)) + { + auto rootAddr = getRootAddr(fieldAddress); + if (isExternallyModifiableAddr(rootAddr)) + continue; + + // ExtractField(Load(GetFieldAddr(x)))) ==> Load(GetFieldAddr(GetFieldAddr(x))) + // The benefit is that any GetAddr(Load(...)) can then transitively be optimized + // out. + // This can only be done if we have no side-effects. `constref` never has + // single-invocation side-effects. + traverseUsers<IRLoad>( + fieldAddress, + [&](IRLoad* load) + { + traverseUsers<IRFieldExtract>( + load, + [&](IRFieldExtract* fieldExtract) + { + IRBuilder builder(fieldExtract); + builder.setInsertBefore(fieldExtract); + auto newGetFieldAddress = builder.emitFieldAddress( + fieldAddress, + fieldExtract->getField()); + auto newLoad = builder.emitLoad(newGetFieldAddress); + fieldExtract->replaceUsesWith(newLoad); + changed = true; + }); + }); + } } } return changed; diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp new file mode 100644 index 000000000..8f4bcd037 --- /dev/null +++ b/source/slang/slang-ir-transform-params-to-constref.cpp @@ -0,0 +1,287 @@ +#include "slang-ir-transform-params-to-constref.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct TransformParamsToConstRefContext +{ + IRModule* module; + DiagnosticSink* sink; + IRBuilder builder; + bool changed = false; + + TransformParamsToConstRefContext(IRModule* module, DiagnosticSink* sink) + : 
module(module), sink(sink), builder(module) + { + } + + // Check if a type should be transformed (struct, array, or other composite types) + bool shouldTransformParam(IRParam* param) + { + auto type = param->getDataType(); + if (!type) + return false; + + switch (type->getOp()) + { + case kIROp_StructType: + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_TupleType: + case kIROp_CoopVectorType: + // valid type, continue to check + break; + default: + return false; + } + + return true; + } + + void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams) + { + // Traverse the uses of our updated params to rewrite them. + // Assume a `in` parameter has been converted to a `constref` parameter. + for (auto param : updatedParams) + { + traverseUses( + param, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_FieldExtract: + { + // Transform the IRFieldExtract into a IRFieldAddress + auto fieldExtract = as<IRFieldExtract>(use->getUser()); + builder.setInsertBefore(fieldExtract); + auto fieldAddr = builder.emitFieldAddress( + fieldExtract->getBase(), + fieldExtract->getField()); + auto loadInst = builder.emitLoad(fieldAddr); + fieldExtract->replaceUsesWith(loadInst); + fieldExtract->removeAndDeallocate(); + break; + } + case kIROp_GetElement: + { + // Transform the IRGetElement into a IRGetElementPtr + auto getElement = as<IRGetElement>(use->getUser()); + + builder.setInsertBefore(getElement); + auto elemAddr = builder.emitElementAddress( + getElement->getBase(), + getElement->getIndex()); + auto loadInst = builder.emitLoad(elemAddr); + getElement->replaceUsesWith(loadInst); + getElement->removeAndDeallocate(); + break; + } + default: + { + // Insert a load before the user and replace the user with the load + builder.setInsertBefore(user); + auto loadInst = builder.emitLoad(param); + use->set(loadInst); + break; + } + } + }); + } + } + + // 
Update call sites to pass an address instead of value for each updated-param + void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams) + { + // Find all calls which use `func`. + List<IRCall*> callsToUpdate; + traverseUsers<IRCall>(func, [&](IRCall* call) { callsToUpdate.add(call); }); + + // Update each call site + for (auto call : callsToUpdate) + { + builder.setInsertBefore(call); + List<IRInst*> newArgs; + + // Transform arguments to match the updated-parameter + IRParam* param = func->getFirstParam(); + UInt i = 0; + auto iterate = [&]() + { + param = param->getNextParam(); + i++; + }; + for (; param; iterate()) + { + auto arg = call->getArg(i); + if (!updatedParams.contains(param)) + { + newArgs.add(arg); + continue; + } + + auto tempVar = builder.emitVar(arg->getFullType()); + builder.emitStore(tempVar, arg); + newArgs.add(tempVar); + } + + // Create new call with updated arguments + auto newCall = builder.emitCallInst(call->getFullType(), func, newArgs); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + } + + // Check if function should be excluded from transformation + bool shouldProcessFunction(IRFunc* func) + { + // Skip functions without definitions + if (!func->isDefinition()) + return false; + + // Skip if we find any of these decorations + for (auto decoration : func->getDecorations()) + { + // Skip functions with target intrinsic decorations. + // These functions cannot be properly legalized after + // transformation. + if (as<IRTargetIntrinsicDecoration>(decoration)) + return false; + + // Skip entry-point and pseudo-entry-point functions + // since we cannot legalize the input parameters. 
+ if (as<IREntryPointDecoration>(decoration) || as<IRCudaKernelDecoration>(decoration) || + as<IRAutoPyBindCudaDecoration>(decoration)) + return false; + } + + // Skip functions with `kIROp_GenericAsm` since + // these instructions inject target specific code + // using parameters in an unpredictable way, relying + // on assumptions that parameters do not change type. + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (!as<IRGenericAsm>(inst)) + continue; + return false; + } + } + + return true; + } + + // Process a single function + void processFunc(IRFunc* func) + { + HashSet<IRParam*> updatedParams; + bool hasTransformedParams = false; + + // First pass: Transform parameter types + for (auto param = func->getFirstParam(); param; param = param->getNextParam()) + { + if (shouldTransformParam(param)) + { + // Our goal here is to transform `in T` parameters to const-ref. + // We are selective about what we will transform for a few reasons: + // 1. no reason to transform simple primitives like `int`. + // 2. not every type makes sense as constref. For example, `ParameterBlock`. + // 3. constref is not 100% stable, so we need to be selective on what we let + // transform into constref. + // + // This allows us to pass the address of variables directly into a function, + // giving us the choice to remove copies into a parameter. 
+ auto paramType = param->getDataType(); + auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal); + param->setFullType(constRefType); + + hasTransformedParams = true; + changed = true; + updatedParams.add(param); + } + } + + if (!hasTransformedParams) + { + return; + } + + // Second pass: Update function body according to the new `constref` parameters + rewriteParamUseSitesToSupportConstRefUsage(updatedParams); + + // Third pass: Update call sites + updateCallSites(func, updatedParams); + } + + void addFuncsToCallListInTopologicalOrder( + IRFunc* root, + List<IRFunc*>& functionsToProcess, + HashSet<IRFunc*>& visitedCandidates) + { + // We added 'root' already, leave + if (visitedCandidates.contains(root)) + return; + + visitedCandidates.add(root); + + for (auto block : root->getBlocks()) + { + for (auto blockInst : block->getChildren()) + { + auto call = as<IRCall>(blockInst); + if (!call) + continue; + + auto callee = as<IRFunc>(call->getCallee()); + if (!callee) + continue; + + addFuncsToCallListInTopologicalOrder(callee, functionsToProcess, visitedCandidates); + } + } + + if (!shouldProcessFunction(root)) + return; + functionsToProcess.add(root); + } + + SlangResult processModule() + { + // Collect all functions that need processing. 
+ // Process all callee's before callers; otherwise we introduce bugs + + HashSet<IRFunc*> visitedCandidates; + List<IRFunc*> functionsToProcess; + for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst()) + { + auto func = as<IRFunc>(inst); + if (!func) + continue; + addFuncsToCallListInTopologicalOrder(func, functionsToProcess, visitedCandidates); + } + + // Process each function + for (auto func : functionsToProcess) + { + processFunc(func); + } + + return SLANG_OK; + } +}; + +SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink) +{ + TransformParamsToConstRefContext context(module, sink); + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-transform-params-to-constref.h b/source/slang/slang-ir-transform-params-to-constref.h new file mode 100644 index 000000000..5bdf8a275 --- /dev/null +++ b/source/slang/slang-ir-transform-params-to-constref.h @@ -0,0 +1,12 @@ +// source\slang\slang-ir-transform-params-to-constref.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +class DiagnosticSink; + +SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b48dcc7e6..6b9273c15 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2943,6 +2943,11 @@ IRConstRefType* IRBuilder::getConstRefType(IRType* valueType) return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType); } +IRConstRefType* IRBuilder::getConstRefType(IRType* valueType, AddressSpace addrSpace) +{ + return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType, addrSpace); +} + IRSPIRVLiteralType* IRBuilder::getSPIRVLiteralType(IRType* type) { IRInst* operands[] = {type}; |
