From 063cbeaaea2fb00a10c6058ea4a9632092772ea5 Mon Sep 17 00:00:00 2001 From: ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> Date: Thu, 7 Aug 2025 00:22:22 -0700 Subject: Initial copy elision pass (#8042) Fixes #7574 Changes: * Add an initial (fairly simple) optimization pass which is able to eliminate redundant copies. * Our current existing optimizer passes remove redundant load/store very robustly, this pass will focus on other cases of copy elimination * Primary approach is to make all functions which are `in T` and `T` is trivial to copy into a `__constref T`. We then (depending on scenario) manually insert a variable+load if a pass-by-reference is not possible; otherwise we pass by `constref`. * Added optimizations to eliminate redundant code which causes `constref` to fail to compile --------- Co-authored-by: Harsh Aggarwal Co-authored-by: Claude Co-authored-by: slangbot Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- .../slang-ir-transform-params-to-constref.cpp | 287 +++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 source/slang/slang-ir-transform-params-to-constref.cpp (limited to 'source/slang/slang-ir-transform-params-to-constref.cpp') diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp new file mode 100644 index 000000000..8f4bcd037 --- /dev/null +++ b/source/slang/slang-ir-transform-params-to-constref.cpp @@ -0,0 +1,287 @@ +#include "slang-ir-transform-params-to-constref.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct TransformParamsToConstRefContext +{ + IRModule* module; + DiagnosticSink* sink; + IRBuilder builder; + bool changed = false; + + TransformParamsToConstRefContext(IRModule* module, DiagnosticSink* sink) + : module(module), sink(sink), builder(module) + { + } + + // Check if a type should be transformed (struct, array, or other composite types) + bool shouldTransformParam(IRParam* param) + { + auto type = param->getDataType(); + if (!type) + return false; + + switch (type->getOp()) + { + case kIROp_StructType: + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_TupleType: + case kIROp_CoopVectorType: + // valid type, continue to check + break; + default: + return false; + } + + return true; + } + + void rewriteParamUseSitesToSupportConstRefUsage(HashSet& updatedParams) + { + // Traverse the uses of our updated params to rewrite them. + // Assume a `in` parameter has been converted to a `constref` parameter. + for (auto param : updatedParams) + { + traverseUses( + param, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_FieldExtract: + { + // Transform the IRFieldExtract into a IRFieldAddress + auto fieldExtract = as(use->getUser()); + builder.setInsertBefore(fieldExtract); + auto fieldAddr = builder.emitFieldAddress( + fieldExtract->getBase(), + fieldExtract->getField()); + auto loadInst = builder.emitLoad(fieldAddr); + fieldExtract->replaceUsesWith(loadInst); + fieldExtract->removeAndDeallocate(); + break; + } + case kIROp_GetElement: + { + // Transform the IRGetElement into a IRGetElementPtr + auto getElement = as(use->getUser()); + + builder.setInsertBefore(getElement); + auto elemAddr = builder.emitElementAddress( + getElement->getBase(), + getElement->getIndex()); + auto loadInst = builder.emitLoad(elemAddr); + getElement->replaceUsesWith(loadInst); + getElement->removeAndDeallocate(); + break; + } + default: + { + // Insert a load before the user and replace the user with the load + builder.setInsertBefore(user); + auto loadInst = builder.emitLoad(param); + use->set(loadInst); + break; + } + } + }); + } + } + + // Update call sites to pass an address instead of value for each updated-param + void updateCallSites(IRFunc* func, HashSet& updatedParams) + { + // Find all calls which use `func`. + List callsToUpdate; + traverseUsers(func, [&](IRCall* call) { callsToUpdate.add(call); }); + + // Update each call site + for (auto call : callsToUpdate) + { + builder.setInsertBefore(call); + List newArgs; + + // Transform arguments to match the updated-parameter + IRParam* param = func->getFirstParam(); + UInt i = 0; + auto iterate = [&]() + { + param = param->getNextParam(); + i++; + }; + for (; param; iterate()) + { + auto arg = call->getArg(i); + if (!updatedParams.contains(param)) + { + newArgs.add(arg); + continue; + } + + auto tempVar = builder.emitVar(arg->getFullType()); + builder.emitStore(tempVar, arg); + newArgs.add(tempVar); + } + + // Create new call with updated arguments + auto newCall = builder.emitCallInst(call->getFullType(), func, newArgs); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + } + + // Check if function should be excluded from transformation + bool shouldProcessFunction(IRFunc* func) + { + // Skip functions without definitions + if (!func->isDefinition()) + return false; + + // Skip if we find any of these decorations + for (auto decoration : func->getDecorations()) + { + // Skip functions with target intrinsic decorations. + // These functions cannot be properly legalized after + // transformation. + if (as(decoration)) + return false; + + // Skip entry-point and pseudo-entry-point functions + // since we cannot legalize the input parameters. + if (as(decoration) || as(decoration) || + as(decoration)) + return false; + } + + // Skip functions with `kIROp_GenericAsm` since + // these instructions inject target specific code + // using parameters in an unpredictable way, relying + // on assumptions that parameters do not change type. + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (!as(inst)) + continue; + return false; + } + } + + return true; + } + + // Process a single function + void processFunc(IRFunc* func) + { + HashSet updatedParams; + bool hasTransformedParams = false; + + // First pass: Transform parameter types + for (auto param = func->getFirstParam(); param; param = param->getNextParam()) + { + if (shouldTransformParam(param)) + { + // Our goal here is to transform `in T` parameters to const-ref. + // We are selective about what we will transform for a few reasons: + // 1. no reason to transform simple primitives like `int`. + // 2. not every type makes sense as constref. For example, `ParameterBlock`. + // 3. constref is not 100% stable, so we need to be selective on what we let + // transform into constref. + // + // This allows us to pass the address of variables directly into a function, + // giving us the choice to remove copies into a parameter. + auto paramType = param->getDataType(); + auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal); + param->setFullType(constRefType); + + hasTransformedParams = true; + changed = true; + updatedParams.add(param); + } + } + + if (!hasTransformedParams) + { + return; + } + + // Second pass: Update function body according to the new `constref` parameters + rewriteParamUseSitesToSupportConstRefUsage(updatedParams); + + // Third pass: Update call sites + updateCallSites(func, updatedParams); + } + + void addFuncsToCallListInTopologicalOrder( + IRFunc* root, + List& functionsToProcess, + HashSet& visitedCandidates) + { + // We added 'root' already, leave + if (visitedCandidates.contains(root)) + return; + + visitedCandidates.add(root); + + for (auto block : root->getBlocks()) + { + for (auto blockInst : block->getChildren()) + { + auto call = as(blockInst); + if (!call) + continue; + + auto callee = as(call->getCallee()); + if (!callee) + continue; + + addFuncsToCallListInTopologicalOrder(callee, functionsToProcess, visitedCandidates); + } + } + + if (!shouldProcessFunction(root)) + return; + functionsToProcess.add(root); + } + + SlangResult processModule() + { + // Collect all functions that need processing. + // Process all callee's before callers; otherwise we introduce bugs + + HashSet visitedCandidates; + List functionsToProcess; + for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst()) + { + auto func = as(inst); + if (!func) + continue; + addFuncsToCallListInTopologicalOrder(func, functionsToProcess, visitedCandidates); + } + + // Process each function + for (auto func : functionsToProcess) + { + processFunc(func); + } + + return SLANG_OK; + } +}; + +SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink) +{ + TransformParamsToConstRefContext context(module, sink); + return context.processModule(); +} + +} // namespace Slang -- cgit v1.2.3