summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2025-08-07 00:22:22 -0700
committerGitHub <noreply@github.com>2025-08-07 07:22:22 +0000
commit063cbeaaea2fb00a10c6058ea4a9632092772ea5 (patch)
treeb4412347d6c264c3b1a84ec971921a5e2fe76134 /source/slang
parent9e2685853033f4286feaf22d04a755a7395d95ce (diff)
Initial copy elision pass (#8042)
Fixes #7574 Changes: * Add an initial (fairly simple) optimization pass which is able to eliminate redundant copies. * Our current existing optimizer passes remove redundant load/store very robustly, this pass will focus on other cases of copy elimination * Primary approach is to make all functions which are `in T` and `T` is trivial to copy into a `__constref T`. We then (depending on scenario) manually insert a variable+load if a pass-by-reference is not possible; otherwise we pass by `constref`. * Added optimizations to eliminate redundant code which causes `constref` to fail to compile --------- Co-authored-by: Harsh Aggarwal <haaggarwal@nvidia.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: slangbot <ellieh+slangbot@nvidia.com> Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-emit-cpp.cpp20
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp137
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.cpp287
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.h12
-rw-r--r--source/slang/slang-ir.cpp5
7 files changed, 445 insertions, 24 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index ffc4b97ef..66829308d 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -307,6 +307,25 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
case kIROp_PtrType:
case kIROp_ConstRefType:
{
+ // Special note on `constref` types and why they are not emitted
+ // as a `const` pointer:
+ //
+ // We currently do not propegate/manage "constness" for locals.
+ // This is important since it means that we rely on opimization
+ // passes to remove all temporary pointer-variables created from
+ // our constref, otherwise we will generate invalid code like
+ // `T* var = const_ptr` or `T* var = &const_ptr->member`.
+ //
+ // If emitting `constref` fails due to this error, it is likely
+ // a missing compiler-optimization.
+ //
+ // Additionally, for C++/CUDA, downstream methods are required
+ // to be `const` if we want to use const pointers. This is currently
+ // not handled robustly.
+ //
+ // Due to these cascading issues, we do not emit const and instead
+ // emit as a regular pointer for the time being.
+
auto elementType = (IRType*)type->getOperand(0);
SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out));
out << "*";
@@ -599,6 +618,7 @@ CPPSourceEmitter::CPPSourceEmitter(const Desc& desc)
void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
{
+ // For use the CPP-specific emitType implementation
emitType(type, name);
}
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 405bca5a2..7d8f1438d 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -110,6 +110,7 @@
#include "slang-ir-strip-default-construct.h"
#include "slang-ir-strip-legalization-insts.h"
#include "slang-ir-synthesize-active-mask.h"
+#include "slang-ir-transform-params-to-constref.h"
#include "slang-ir-translate-global-varying-var.h"
#include "slang-ir-undo-param-copy.h"
#include "slang-ir-uniformity.h"
@@ -1714,6 +1715,12 @@ Result linkAndOptimizeIR(
// For CUDA/OptiX like targets, add our pass to replace inout parameter copies with direct
// pointers
undoParameterCopy(irModule);
+ // Transform struct parameters to use ConstRef for better performance
+ if (isCPUTarget(targetRequest) || isCUDATarget(targetRequest) ||
+ isMetalTarget(targetRequest))
+ {
+ transformParamsToConstRef(irModule, codeGenContext->getSink());
+ }
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "PARAMETER COPIES REPLACED WITH DIRECT POINTERS");
#endif
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 93979bdbe..c8604f4fa 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3729,6 +3729,7 @@ public:
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
IRConstRefType* getConstRefType(IRType* valueType);
+ IRConstRefType* getConstRefType(IRType* valueType, AddressSpace addrSpace);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace);
IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace);
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index a6dac723e..1feab47dd 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -361,43 +361,132 @@ bool tryRemoveRedundantStore(IRGlobalValueWithCode* func, IRStore* store)
return false;
}
+// Checks if we can change or have a modified rootVar
+// at some point.
+bool isExternallyModifiableAddr(IRInst* rootVar)
+{
+ if (!rootVar)
+ return false;
+
+ auto ptr = as<IRConstRefType>(rootVar->getDataType());
+ if (!ptr)
+ return true;
+
+ // Only a UserPointer can potentially be modified and changed to point to a different address
+ // if constRef. This may happen from a different thread even if constref to the current thread.
+ auto addrSpace = ptr->getAddressSpace();
+ if (addrSpace == AddressSpace::UserPointer)
+ return true;
+
+ return false;
+}
+
+bool tryRemoveRedundantLoad(IRGlobalValueWithCode* func, IRLoad* load)
+{
+ bool changed = false;
+
+ // If the load is preceeded by a store without any side-effect insts
+ // in-between, remove the load.
+ for (auto prev = load->getPrevInst(); prev; prev = prev->getPrevInst())
+ {
+ if (auto store = as<IRStore>(prev))
+ {
+ if (store->getPtr() == load->getPtr())
+ {
+ auto value = store->getVal();
+ load->replaceUsesWith(value);
+ load->removeAndDeallocate();
+ changed = true;
+ break;
+ }
+ }
+
+ if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr()))
+ {
+ break;
+ }
+ }
+
+ return changed;
+}
+
bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func)
{
bool changed = false;
for (auto block : func->getBlocks())
{
- for (auto inst = block->getFirstInst(); inst;)
+ IRInst* nextInst = nullptr;
+ for (auto inst = block->getFirstInst(); inst; inst = nextInst)
{
- auto nextInst = inst->getNextInst();
+ nextInst = inst->getNextInst();
if (auto load = as<IRLoad>(inst))
{
- for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst())
- {
- if (auto store = as<IRStore>(prev))
- {
- if (store->getPtr() == load->getPtr())
- {
- // If the load is preceeded by a store without any side-effect insts
- // in-between, remove the load.
- auto value = store->getVal();
- load->replaceUsesWith(value);
- load->removeAndDeallocate();
- changed = true;
- break;
- }
- }
-
- if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr()))
- {
- break;
- }
- }
+ changed |= tryRemoveRedundantLoad(func, load);
}
else if (auto store = as<IRStore>(inst))
{
changed |= tryRemoveRedundantStore(func, store);
}
- inst = nextInst;
+ else if (auto getElementPtr = as<IRGetElementPtr>(inst))
+ {
+ auto rootAddr = getRootAddr(getElementPtr);
+ if (isExternallyModifiableAddr(rootAddr))
+ continue;
+
+ // GetElement(Load(GetElementPtr(x)))) ==> Load(GetElementPtr(GetElementPtr(x)))
+ // The benefit is that any GetAddr(Load(...)) can then transitively be optimized
+ // out.
+ // This can only be done if we have no side-effects. `constref` never has
+ // single-invocation side-effects.
+ traverseUsers<IRLoad>(
+ getElementPtr,
+ [&](IRLoad* load)
+ {
+ traverseUsers<IRGetElement>(
+ load,
+ [&](IRGetElement* getElement)
+ {
+ IRBuilder builder(getElement);
+ builder.setInsertBefore(getElement);
+ auto newGetElementPtr = builder.emitElementAddress(
+ getElementPtr,
+ getElement->getIndex());
+ auto newLoad = builder.emitLoad(newGetElementPtr);
+ getElement->replaceUsesWith(newLoad);
+ changed = true;
+ });
+ });
+ }
+ else if (auto fieldAddress = as<IRFieldAddress>(inst))
+ {
+ auto rootAddr = getRootAddr(fieldAddress);
+ if (isExternallyModifiableAddr(rootAddr))
+ continue;
+
+ // ExtractField(Load(GetFieldAddr(x)))) ==> Load(GetFieldAddr(GetFieldAddr(x)))
+ // The benefit is that any GetAddr(Load(...)) can then transitively be optimized
+ // out.
+ // This can only be done if we have no side-effects. `constref` never has
+ // single-invocation side-effects.
+ traverseUsers<IRLoad>(
+ fieldAddress,
+ [&](IRLoad* load)
+ {
+ traverseUsers<IRFieldExtract>(
+ load,
+ [&](IRFieldExtract* fieldExtract)
+ {
+ IRBuilder builder(fieldExtract);
+ builder.setInsertBefore(fieldExtract);
+ auto newGetFieldAddress = builder.emitFieldAddress(
+ fieldAddress,
+ fieldExtract->getField());
+ auto newLoad = builder.emitLoad(newGetFieldAddress);
+ fieldExtract->replaceUsesWith(newLoad);
+ changed = true;
+ });
+ });
+ }
}
}
return changed;
diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp
new file mode 100644
index 000000000..8f4bcd037
--- /dev/null
+++ b/source/slang/slang-ir-transform-params-to-constref.cpp
@@ -0,0 +1,287 @@
+#include "slang-ir-transform-params-to-constref.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+struct TransformParamsToConstRefContext
+{
+ IRModule* module;
+ DiagnosticSink* sink;
+ IRBuilder builder;
+ bool changed = false;
+
+ TransformParamsToConstRefContext(IRModule* module, DiagnosticSink* sink)
+ : module(module), sink(sink), builder(module)
+ {
+ }
+
+ // Check if a type should be transformed (struct, array, or other composite types)
+ bool shouldTransformParam(IRParam* param)
+ {
+ auto type = param->getDataType();
+ if (!type)
+ return false;
+
+ switch (type->getOp())
+ {
+ case kIROp_StructType:
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ case kIROp_TupleType:
+ case kIROp_CoopVectorType:
+ // valid type, continue to check
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+ }
+
+ void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ {
+ // Traverse the uses of our updated params to rewrite them.
+ // Assume a `in` parameter has been converted to a `constref` parameter.
+ for (auto param : updatedParams)
+ {
+ traverseUses(
+ param,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ switch (user->getOp())
+ {
+ case kIROp_FieldExtract:
+ {
+ // Transform the IRFieldExtract into a IRFieldAddress
+ auto fieldExtract = as<IRFieldExtract>(use->getUser());
+ builder.setInsertBefore(fieldExtract);
+ auto fieldAddr = builder.emitFieldAddress(
+ fieldExtract->getBase(),
+ fieldExtract->getField());
+ auto loadInst = builder.emitLoad(fieldAddr);
+ fieldExtract->replaceUsesWith(loadInst);
+ fieldExtract->removeAndDeallocate();
+ break;
+ }
+ case kIROp_GetElement:
+ {
+ // Transform the IRGetElement into a IRGetElementPtr
+ auto getElement = as<IRGetElement>(use->getUser());
+
+ builder.setInsertBefore(getElement);
+ auto elemAddr = builder.emitElementAddress(
+ getElement->getBase(),
+ getElement->getIndex());
+ auto loadInst = builder.emitLoad(elemAddr);
+ getElement->replaceUsesWith(loadInst);
+ getElement->removeAndDeallocate();
+ break;
+ }
+ default:
+ {
+ // Insert a load before the user and replace the user with the load
+ builder.setInsertBefore(user);
+ auto loadInst = builder.emitLoad(param);
+ use->set(loadInst);
+ break;
+ }
+ }
+ });
+ }
+ }
+
+ // Update call sites to pass an address instead of value for each updated-param
+ void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams)
+ {
+ // Find all calls which use `func`.
+ List<IRCall*> callsToUpdate;
+ traverseUsers<IRCall>(func, [&](IRCall* call) { callsToUpdate.add(call); });
+
+ // Update each call site
+ for (auto call : callsToUpdate)
+ {
+ builder.setInsertBefore(call);
+ List<IRInst*> newArgs;
+
+ // Transform arguments to match the updated-parameter
+ IRParam* param = func->getFirstParam();
+ UInt i = 0;
+ auto iterate = [&]()
+ {
+ param = param->getNextParam();
+ i++;
+ };
+ for (; param; iterate())
+ {
+ auto arg = call->getArg(i);
+ if (!updatedParams.contains(param))
+ {
+ newArgs.add(arg);
+ continue;
+ }
+
+ auto tempVar = builder.emitVar(arg->getFullType());
+ builder.emitStore(tempVar, arg);
+ newArgs.add(tempVar);
+ }
+
+ // Create new call with updated arguments
+ auto newCall = builder.emitCallInst(call->getFullType(), func, newArgs);
+ call->replaceUsesWith(newCall);
+ call->removeAndDeallocate();
+ }
+ }
+
+ // Check if function should be excluded from transformation
+ bool shouldProcessFunction(IRFunc* func)
+ {
+ // Skip functions without definitions
+ if (!func->isDefinition())
+ return false;
+
+ // Skip if we find any of these decorations
+ for (auto decoration : func->getDecorations())
+ {
+ // Skip functions with target intrinsic decorations.
+ // These functions cannot be properly legalized after
+ // transformation.
+ if (as<IRTargetIntrinsicDecoration>(decoration))
+ return false;
+
+ // Skip entry-point and pseudo-entry-point functions
+ // since we cannot legalize the input parameters.
+ if (as<IREntryPointDecoration>(decoration) || as<IRCudaKernelDecoration>(decoration) ||
+ as<IRAutoPyBindCudaDecoration>(decoration))
+ return false;
+ }
+
+ // Skip functions with `kIROp_GenericAsm` since
+ // these instructions inject target specific code
+ // using parameters in an unpredictable way, relying
+ // on assumptions that parameters do not change type.
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (!as<IRGenericAsm>(inst))
+ continue;
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // Process a single function
+ void processFunc(IRFunc* func)
+ {
+ HashSet<IRParam*> updatedParams;
+ bool hasTransformedParams = false;
+
+ // First pass: Transform parameter types
+ for (auto param = func->getFirstParam(); param; param = param->getNextParam())
+ {
+ if (shouldTransformParam(param))
+ {
+ // Our goal here is to transform `in T` parameters to const-ref.
+ // We are selective about what we will transform for a few reasons:
+ // 1. no reason to transform simple primitives like `int`.
+ // 2. not every type makes sense as constref. For example, `ParameterBlock`.
+ // 3. constref is not 100% stable, so we need to be selective on what we let
+ // transform into constref.
+ //
+ // This allows us to pass the address of variables directly into a function,
+ // giving us the choice to remove copies into a parameter.
+ auto paramType = param->getDataType();
+ auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
+ param->setFullType(constRefType);
+
+ hasTransformedParams = true;
+ changed = true;
+ updatedParams.add(param);
+ }
+ }
+
+ if (!hasTransformedParams)
+ {
+ return;
+ }
+
+ // Second pass: Update function body according to the new `constref` parameters
+ rewriteParamUseSitesToSupportConstRefUsage(updatedParams);
+
+ // Third pass: Update call sites
+ updateCallSites(func, updatedParams);
+ }
+
+ void addFuncsToCallListInTopologicalOrder(
+ IRFunc* root,
+ List<IRFunc*>& functionsToProcess,
+ HashSet<IRFunc*>& visitedCandidates)
+ {
+ // We added 'root' already, leave
+ if (visitedCandidates.contains(root))
+ return;
+
+ visitedCandidates.add(root);
+
+ for (auto block : root->getBlocks())
+ {
+ for (auto blockInst : block->getChildren())
+ {
+ auto call = as<IRCall>(blockInst);
+ if (!call)
+ continue;
+
+ auto callee = as<IRFunc>(call->getCallee());
+ if (!callee)
+ continue;
+
+ addFuncsToCallListInTopologicalOrder(callee, functionsToProcess, visitedCandidates);
+ }
+ }
+
+ if (!shouldProcessFunction(root))
+ return;
+ functionsToProcess.add(root);
+ }
+
+ SlangResult processModule()
+ {
+ // Collect all functions that need processing.
+ // Process all callee's before callers; otherwise we introduce bugs
+
+ HashSet<IRFunc*> visitedCandidates;
+ List<IRFunc*> functionsToProcess;
+ for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst())
+ {
+ auto func = as<IRFunc>(inst);
+ if (!func)
+ continue;
+ addFuncsToCallListInTopologicalOrder(func, functionsToProcess, visitedCandidates);
+ }
+
+ // Process each function
+ for (auto func : functionsToProcess)
+ {
+ processFunc(func);
+ }
+
+ return SLANG_OK;
+ }
+};
+
+SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink)
+{
+ TransformParamsToConstRefContext context(module, sink);
+ return context.processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-transform-params-to-constref.h b/source/slang/slang-ir-transform-params-to-constref.h
new file mode 100644
index 000000000..5bdf8a275
--- /dev/null
+++ b/source/slang/slang-ir-transform-params-to-constref.h
@@ -0,0 +1,12 @@
+// source\slang\slang-ir-transform-params-to-constref.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+class DiagnosticSink;
+
+SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b48dcc7e6..6b9273c15 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2943,6 +2943,11 @@ IRConstRefType* IRBuilder::getConstRefType(IRType* valueType)
return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType);
}
+IRConstRefType* IRBuilder::getConstRefType(IRType* valueType, AddressSpace addrSpace)
+{
+ return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType, addrSpace);
+}
+
IRSPIRVLiteralType* IRBuilder::getSPIRVLiteralType(IRType* type)
{
IRInst* operands[] = {type};