summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2025-08-07 00:22:22 -0700
committerGitHub <noreply@github.com>2025-08-07 07:22:22 +0000
commit063cbeaaea2fb00a10c6058ea4a9632092772ea5 (patch)
treeb4412347d6c264c3b1a84ec971921a5e2fe76134
parent9e2685853033f4286feaf22d04a755a7395d95ce (diff)
Initial copy elision pass (#8042)
Fixes #7574 Changes: * Add an initial (fairly simple) optimization pass which is able to eliminate redundant copies. * Our current existing optimizer passes remove redundant load/store very robustly, this pass will focus on other cases of copy elimination * Primary approach is to make all functions which are `in T` and `T` is trivial to copy into a `__constref T`. We then (depending on scenario) manually insert a variable+load if a pass-by-reference is not possible; otherwise we pass by `constref`. * Added optimizations to eliminate redundant code which causes `constref` to fail to compile --------- Co-authored-by: Harsh Aggarwal <haaggarwal@nvidia.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: slangbot <ellieh+slangbot@nvidia.com> Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--prelude/slang-cpp-types.h14
-rw-r--r--prelude/slang-cuda-prelude.h39
-rw-r--r--source/slang/slang-emit-cpp.cpp20
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp137
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.cpp287
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.h12
-rw-r--r--source/slang/slang-ir.cpp5
-rw-r--r--tests/cuda/copy-elision-this-1.slang28
-rw-r--r--tests/cuda/copy-elision-this-2.slang141
-rw-r--r--tests/language-feature/pointer/const-ref.slang8
12 files changed, 647 insertions, 52 deletions
diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h
index 010ab8d6c..491438c80 100644
--- a/prelude/slang-cpp-types.h
+++ b/prelude/slang-cpp-types.h
@@ -440,7 +440,7 @@ struct Texture1D
texture->Sample(samplerState, &loc, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, float loc, float level)
+ T SampleLevel(SamplerState samplerState, float loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc, level, &out, sizeof(out));
@@ -500,7 +500,7 @@ struct Texture2D
texture->Sample(samplerState, &loc.x, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, const float2& loc, float level)
+ T SampleLevel(SamplerState samplerState, const float2& loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc.x, level, &out, sizeof(out));
@@ -566,7 +566,7 @@ struct Texture3D
texture->Sample(samplerState, &loc.x, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, const float3& loc, float level)
+ T SampleLevel(SamplerState samplerState, const float3& loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc.x, level, &out, sizeof(out));
@@ -620,7 +620,7 @@ struct TextureCube
texture->Sample(samplerState, &loc.x, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, const float3& loc, float level)
+ T SampleLevel(SamplerState samplerState, const float3& loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc.x, level, &out, sizeof(out));
@@ -680,7 +680,7 @@ struct Texture1DArray
texture->Sample(samplerState, &loc.x, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, const float2& loc, float level)
+ T SampleLevel(SamplerState samplerState, const float2& loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc.x, level, &out, sizeof(out));
@@ -747,7 +747,7 @@ struct Texture2DArray
texture->Sample(samplerState, &loc.x, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, const float3& loc, float level)
+ T SampleLevel(SamplerState samplerState, const float3& loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc.x, level, &out, sizeof(out));
@@ -808,7 +808,7 @@ struct TextureCubeArray
texture->Sample(samplerState, &loc.x, &out, sizeof(out));
return out;
}
- T SampleLevel(SamplerState samplerState, const float4& loc, float level)
+ T SampleLevel(SamplerState samplerState, const float4& loc, float level) const
{
T out;
texture->SampleLevel(samplerState, &loc.x, level, &out, sizeof(out));
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 178c12f5f..a66fa15cb 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -348,22 +348,22 @@ SLANG_VECTOR_GET_ELEMENT(ulonglong)
SLANG_VECTOR_GET_ELEMENT(float)
SLANG_VECTOR_GET_ELEMENT(double)
-#define SLANG_VECTOR_GET_ELEMENT_PTR(T) \
- SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##1 * x, int index) \
- { \
- return ((T*)(x)) + index; \
- } \
- SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##2 * x, int index) \
- { \
- return ((T*)(x)) + index; \
- } \
- SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##3 * x, int index) \
- { \
- return ((T*)(x)) + index; \
- } \
- SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##4 * x, int index) \
- { \
- return ((T*)(x)) + index; \
+#define SLANG_VECTOR_GET_ELEMENT_PTR(T) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##1 * x, int index) \
+ { \
+ return ((T*)(x)) + index; \
+ } \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##2 * x, int index) \
+ { \
+ return ((T*)(x)) + index; \
+ } \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##3 * x, int index) \
+ { \
+ return ((T*)(x)) + index; \
+ } \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##4 * x, int index) \
+ { \
+ return ((T*)(x)) + index; \
}
SLANG_VECTOR_GET_ELEMENT_PTR(int)
SLANG_VECTOR_GET_ELEMENT_PTR(bool)
@@ -689,6 +689,11 @@ struct Matrix
{
return rows[index];
}
+
+    // Read-only row accessor: the `const` overload, returning a const reference
+    // to row `index` so rows can be read from a `const` matrix.
+    SLANG_FORCE_INLINE SLANG_CUDA_CALL const Vector<T, COLS>& operator[](size_t index) const
+    {
+        const Vector<T, COLS>& row = rows[index];
+        return row;
+    }
};
@@ -2312,7 +2317,7 @@ struct StructuredBuffer
}
#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
- SLANG_CUDA_CALL void GetDimensions(uint32_t* outNumStructs, uint32_t* outStride)
+ SLANG_CUDA_CALL void GetDimensions(uint32_t* outNumStructs, uint32_t* outStride) const
{
*outNumStructs = uint32_t(count);
*outStride = uint32_t(sizeof(T));
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index ffc4b97ef..66829308d 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -307,6 +307,25 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
case kIROp_PtrType:
case kIROp_ConstRefType:
{
+ // Special note on `constref` types and why they are not emitted
+ // as a `const` pointer:
+ //
+ // We currently do not propagate/manage "constness" for locals.
+ // This is important since it means that we rely on optimization
+ // passes to remove all temporary pointer-variables created from
+ // our constref, otherwise we will generate invalid code like
+ // `T* var = const_ptr` or `T* var = &const_ptr->member`.
+ //
+ // If emitting `constref` fails due to this error, it is likely
+ // a missing compiler-optimization.
+ //
+ // Additionally, for C++/CUDA, downstream methods are required
+ // to be `const` if we want to use const pointers. This is currently
+ // not handled robustly.
+ //
+ // Due to these cascading issues, we do not emit const and instead
+ // emit as a regular pointer for the time being.
+
auto elementType = (IRType*)type->getOperand(0);
SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out));
out << "*";
@@ -599,6 +618,7 @@ CPPSourceEmitter::CPPSourceEmitter(const Desc& desc)
void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
{
+ // Use the CPP-specific emitType implementation
emitType(type, name);
}
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 405bca5a2..7d8f1438d 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -110,6 +110,7 @@
#include "slang-ir-strip-default-construct.h"
#include "slang-ir-strip-legalization-insts.h"
#include "slang-ir-synthesize-active-mask.h"
+#include "slang-ir-transform-params-to-constref.h"
#include "slang-ir-translate-global-varying-var.h"
#include "slang-ir-undo-param-copy.h"
#include "slang-ir-uniformity.h"
@@ -1714,6 +1715,12 @@ Result linkAndOptimizeIR(
// For CUDA/OptiX like targets, add our pass to replace inout parameter copies with direct
// pointers
undoParameterCopy(irModule);
+ // Transform struct parameters to use ConstRef for better performance
+ if (isCPUTarget(targetRequest) || isCUDATarget(targetRequest) ||
+ isMetalTarget(targetRequest))
+ {
+ transformParamsToConstRef(irModule, codeGenContext->getSink());
+ }
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "PARAMETER COPIES REPLACED WITH DIRECT POINTERS");
#endif
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 93979bdbe..c8604f4fa 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3729,6 +3729,7 @@ public:
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
IRConstRefType* getConstRefType(IRType* valueType);
+ IRConstRefType* getConstRefType(IRType* valueType, AddressSpace addrSpace);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace);
IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace);
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index a6dac723e..1feab47dd 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -361,43 +361,132 @@ bool tryRemoveRedundantStore(IRGlobalValueWithCode* func, IRStore* store)
return false;
}
+// Reports whether the memory rooted at `rootVar` has to be treated as modifiable
+// behind the back of the code being optimized.
+//
+// Returns:
+//  - false when `rootVar` is null,
+//  - true when the type of `rootVar` is not a `constref` pointer,
+//  - for a `constref` root, true only if it lives in the `UserPointer` address space.
+bool isExternallyModifiableAddr(IRInst* rootVar)
+{
+    if (rootVar == nullptr)
+        return false;
+
+    auto constRefType = as<IRConstRefType>(rootVar->getDataType());
+    if (constRefType == nullptr)
+        return true;
+
+    // Among `constref` roots only a user pointer is treated as modifiable: its
+    // contents may still change (for example from a different thread) even
+    // though the current thread only sees it as `constref`.
+    return constRefType->getAddressSpace() == AddressSpace::UserPointer;
+}
+
+
+// Tries to forward a previously stored value to `load`.
+//
+// Walks backwards over the instructions preceding `load`. If a store to the same
+// pointer is reached before any instruction that may have a side effect at that
+// address, every use of `load` is replaced with the stored value and `load` is
+// removed. Returns true if the load was removed, false otherwise.
+bool tryRemoveRedundantLoad(IRGlobalValueWithCode* func, IRLoad* load)
+{
+    auto loadedPtr = load->getPtr();
+
+    for (auto cursor = load->getPrevInst(); cursor; cursor = cursor->getPrevInst())
+    {
+        // The load is preceded by a store to the same address with nothing
+        // side-effecting in between: reuse the stored value directly.
+        auto store = as<IRStore>(cursor);
+        if (store && store->getPtr() == loadedPtr)
+        {
+            load->replaceUsesWith(store->getVal());
+            load->removeAndDeallocate();
+            return true;
+        }
+
+        // Anything that may touch the address ends the search.
+        if (canInstHaveSideEffectAtAddress(func, cursor, loadedPtr))
+            return false;
+    }
+
+    return false;
+}
+
+
bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func)
{
bool changed = false;
for (auto block : func->getBlocks())
{
- for (auto inst = block->getFirstInst(); inst;)
+ IRInst* nextInst = nullptr;
+ for (auto inst = block->getFirstInst(); inst; inst = nextInst)
{
- auto nextInst = inst->getNextInst();
+ nextInst = inst->getNextInst();
if (auto load = as<IRLoad>(inst))
{
- for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst())
- {
- if (auto store = as<IRStore>(prev))
- {
- if (store->getPtr() == load->getPtr())
- {
- // If the load is preceeded by a store without any side-effect insts
- // in-between, remove the load.
- auto value = store->getVal();
- load->replaceUsesWith(value);
- load->removeAndDeallocate();
- changed = true;
- break;
- }
- }
-
- if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr()))
- {
- break;
- }
- }
+ changed |= tryRemoveRedundantLoad(func, load);
}
else if (auto store = as<IRStore>(inst))
{
changed |= tryRemoveRedundantStore(func, store);
}
- inst = nextInst;
+ else if (auto getElementPtr = as<IRGetElementPtr>(inst))
+ {
+ auto rootAddr = getRootAddr(getElementPtr);
+ if (isExternallyModifiableAddr(rootAddr))
+ continue;
+
+ // GetElement(Load(GetElementPtr(x))) ==> Load(GetElementPtr(GetElementPtr(x)))
+ // The benefit is that any GetAddr(Load(...)) can then transitively be optimized
+ // out.
+ // This can only be done if we have no side-effects. `constref` never has
+ // single-invocation side-effects.
+ traverseUsers<IRLoad>(
+ getElementPtr,
+ [&](IRLoad* load)
+ {
+ traverseUsers<IRGetElement>(
+ load,
+ [&](IRGetElement* getElement)
+ {
+ IRBuilder builder(getElement);
+ builder.setInsertBefore(getElement);
+ auto newGetElementPtr = builder.emitElementAddress(
+ getElementPtr,
+ getElement->getIndex());
+ auto newLoad = builder.emitLoad(newGetElementPtr);
+ getElement->replaceUsesWith(newLoad);
+ changed = true;
+ });
+ });
+ }
+ else if (auto fieldAddress = as<IRFieldAddress>(inst))
+ {
+ auto rootAddr = getRootAddr(fieldAddress);
+ if (isExternallyModifiableAddr(rootAddr))
+ continue;
+
+ // ExtractField(Load(GetFieldAddr(x))) ==> Load(GetFieldAddr(GetFieldAddr(x)))
+ // The benefit is that any GetAddr(Load(...)) can then transitively be optimized
+ // out.
+ // This can only be done if we have no side-effects. `constref` never has
+ // single-invocation side-effects.
+ traverseUsers<IRLoad>(
+ fieldAddress,
+ [&](IRLoad* load)
+ {
+ traverseUsers<IRFieldExtract>(
+ load,
+ [&](IRFieldExtract* fieldExtract)
+ {
+ IRBuilder builder(fieldExtract);
+ builder.setInsertBefore(fieldExtract);
+ auto newGetFieldAddress = builder.emitFieldAddress(
+ fieldAddress,
+ fieldExtract->getField());
+ auto newLoad = builder.emitLoad(newGetFieldAddress);
+ fieldExtract->replaceUsesWith(newLoad);
+ changed = true;
+ });
+ });
+ }
}
}
return changed;
diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp
new file mode 100644
index 000000000..8f4bcd037
--- /dev/null
+++ b/source/slang/slang-ir-transform-params-to-constref.cpp
@@ -0,0 +1,287 @@
+#include "slang-ir-transform-params-to-constref.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+// State and helpers for the pass that turns eligible by-value (`in T`) function
+// parameters into `constref` pointer parameters, then patches the function
+// bodies and the call sites to match.
+struct TransformParamsToConstRefContext
+{
+    IRModule* module;
+    DiagnosticSink* sink;
+    IRBuilder builder;
+
+    // Set to true as soon as any parameter has been rewritten.
+    bool changed = false;
+
+    TransformParamsToConstRefContext(IRModule* module, DiagnosticSink* sink)
+        : module(module), sink(sink), builder(module)
+    {
+    }
+
+    // Returns true if `param` has one of the composite types (struct, array,
+    // vector, matrix, tuple, cooperative vector) that this pass rewrites.
+    // Parameters without a data type, or of any other type, are left alone.
+    bool shouldTransformParam(IRParam* param)
+    {
+        auto type = param->getDataType();
+        if (!type)
+            return false;
+
+        switch (type->getOp())
+        {
+        case kIROp_StructType:
+        case kIROp_ArrayType:
+        case kIROp_UnsizedArrayType:
+        case kIROp_VectorType:
+        case kIROp_MatrixType:
+        case kIROp_TupleType:
+        case kIROp_CoopVectorType:
+            return true;
+        default:
+            return false;
+        }
+    }
+
+    // Rewrites every use of each parameter in `updatedParams` so that it is valid
+    // for the parameter's new `constref` (pointer) type:
+    //  - a field extract becomes a field-address followed by a load,
+    //  - an element read becomes an element-address followed by a load,
+    //  - any other user receives a load of the whole parameter, inserted right
+    //    before that user.
+    void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+    {
+        for (auto param : updatedParams)
+        {
+            traverseUses(
+                param,
+                [&](IRUse* use)
+                {
+                    auto user = use->getUser();
+
+                    // All replacement instructions go immediately before the user.
+                    builder.setInsertBefore(user);
+
+                    switch (user->getOp())
+                    {
+                    case kIROp_FieldExtract:
+                        {
+                            // FieldExtract(param, key) ==> Load(FieldAddress(param, key))
+                            auto fieldExtract = as<IRFieldExtract>(user);
+                            auto fieldAddr = builder.emitFieldAddress(
+                                fieldExtract->getBase(),
+                                fieldExtract->getField());
+                            auto loadedField = builder.emitLoad(fieldAddr);
+                            fieldExtract->replaceUsesWith(loadedField);
+                            fieldExtract->removeAndDeallocate();
+                            break;
+                        }
+                    case kIROp_GetElement:
+                        {
+                            // GetElement(param, i) ==> Load(GetElementPtr(param, i))
+                            auto getElement = as<IRGetElement>(user);
+                            auto elementAddr = builder.emitElementAddress(
+                                getElement->getBase(),
+                                getElement->getIndex());
+                            auto loadedElement = builder.emitLoad(elementAddr);
+                            getElement->replaceUsesWith(loadedElement);
+                            getElement->removeAndDeallocate();
+                            break;
+                        }
+                    default:
+                        {
+                            // The user needs the whole value: load it from the
+                            // `constref` parameter and redirect only this use.
+                            auto loadedParam = builder.emitLoad(param);
+                            use->set(loadedParam);
+                            break;
+                        }
+                    }
+                });
+        }
+    }
+
+    // Rebuilds every call to `func` so that, for each parameter in
+    // `updatedParams`, the call passes the address of a temporary variable
+    // holding the argument value instead of the value itself. Arguments for
+    // untouched parameters are forwarded unchanged.
+    void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams)
+    {
+        // Snapshot the calls first, since the loop below replaces and deletes them.
+        List<IRCall*> callsToUpdate;
+        traverseUsers<IRCall>(
+            func,
+            [&](IRCall* call)
+            {
+                // A call may reference `func` through an operand other than the
+                // callee (e.g. as an argument). Such a call targets a different
+                // function, so rebuilding it as a call to `func` would be wrong.
+                if (call->getCallee() == func)
+                    callsToUpdate.add(call);
+            });
+
+        for (auto call : callsToUpdate)
+        {
+            builder.setInsertBefore(call);
+
+            // Walk parameters and arguments in lock-step.
+            List<IRInst*> newArgs;
+            UInt argIndex = 0;
+            for (auto param = func->getFirstParam(); param;
+                 param = param->getNextParam(), ++argIndex)
+            {
+                auto arg = call->getArg(argIndex);
+                if (!updatedParams.contains(param))
+                {
+                    newArgs.add(arg);
+                    continue;
+                }
+
+                // Spill the value into a fresh variable and pass the variable
+                // (an address) in place of the value.
+                auto tempVar = builder.emitVar(arg->getFullType());
+                builder.emitStore(tempVar, arg);
+                newArgs.add(tempVar);
+            }
+
+            // Swap the old call for one that uses the new argument list.
+            auto newCall = builder.emitCallInst(call->getFullType(), func, newArgs);
+            call->replaceUsesWith(newCall);
+            call->removeAndDeallocate();
+        }
+    }
+
+    // Returns true if `func` is eligible for the transformation. A function is
+    // rejected when it has no definition, carries a target-intrinsic or
+    // (pseudo-)entry-point decoration, or contains a `GenericAsm` instruction.
+    bool shouldProcessFunction(IRFunc* func)
+    {
+        // Skip functions without definitions
+        if (!func->isDefinition())
+            return false;
+
+        for (auto decoration : func->getDecorations())
+        {
+            // Functions with target intrinsic decorations cannot be properly
+            // legalized after transformation.
+            if (as<IRTargetIntrinsicDecoration>(decoration))
+                return false;
+
+            // Entry-point and pseudo-entry-point functions are skipped since we
+            // cannot legalize their input parameters.
+            if (as<IREntryPointDecoration>(decoration) || as<IRCudaKernelDecoration>(decoration) ||
+                as<IRAutoPyBindCudaDecoration>(decoration))
+                return false;
+        }
+
+        // `kIROp_GenericAsm` injects target specific code that uses parameters in
+        // an unpredictable way, relying on the assumption that parameters do not
+        // change type, so any function containing one is skipped.
+        for (auto block : func->getBlocks())
+        {
+            for (auto inst : block->getChildren())
+            {
+                if (as<IRGenericAsm>(inst))
+                    return false;
+            }
+        }
+
+        return true;
+    }
+
+    // Transforms a single function: retypes its eligible parameters to
+    // `constref`, then fixes up the body and all call sites. Does nothing when
+    // no parameter qualifies.
+    void processFunc(IRFunc* func)
+    {
+        HashSet<IRParam*> updatedParams;
+        bool hasTransformedParams = false;
+
+        // First pass: Transform parameter types
+        for (auto param = func->getFirstParam(); param; param = param->getNextParam())
+        {
+            if (!shouldTransformParam(param))
+                continue;
+
+            // Our goal here is to transform `in T` parameters to const-ref.
+            // We are selective about what we will transform for a few reasons:
+            // 1. no reason to transform simple primitives like `int`.
+            // 2. not every type makes sense as constref. For example, `ParameterBlock`.
+            // 3. constref is not 100% stable, so we need to be selective on what we let
+            //    transform into constref.
+            //
+            // This allows us to pass the address of variables directly into a function,
+            // giving us the choice to remove copies into a parameter.
+            auto valueType = param->getDataType();
+            auto constRefType = builder.getConstRefType(valueType, AddressSpace::ThreadLocal);
+            param->setFullType(constRefType);
+
+            updatedParams.add(param);
+            hasTransformedParams = true;
+            changed = true;
+        }
+
+        if (!hasTransformedParams)
+            return;
+
+        // Second pass: Update function body according to the new `constref` parameters
+        rewriteParamUseSitesToSupportConstRefUsage(updatedParams);
+
+        // Third pass: Update call sites
+        updateCallSites(func, updatedParams);
+    }
+
+    // Depth-first walk over the direct calls made by `root`, appending functions
+    // to `functionsToProcess` in post-order so that callees come before their
+    // callers. `visitedCandidates` records every function already walked, which
+    // also stops recursion on call cycles. Functions rejected by
+    // `shouldProcessFunction` are still walked but not appended.
+    void addFuncsToCallListInTopologicalOrder(
+        IRFunc* root,
+        List<IRFunc*>& functionsToProcess,
+        HashSet<IRFunc*>& visitedCandidates)
+    {
+        // We added 'root' already, leave
+        if (visitedCandidates.contains(root))
+            return;
+
+        visitedCandidates.add(root);
+
+        for (auto block : root->getBlocks())
+        {
+            for (auto blockInst : block->getChildren())
+            {
+                auto call = as<IRCall>(blockInst);
+                if (!call)
+                    continue;
+
+                // Only direct calls to a known function are followed.
+                auto callee = as<IRFunc>(call->getCallee());
+                if (!callee)
+                    continue;
+
+                addFuncsToCallListInTopologicalOrder(callee, functionsToProcess, visitedCandidates);
+            }
+        }
+
+        if (shouldProcessFunction(root))
+            functionsToProcess.add(root);
+    }
+
+    // Runs the pass over every function that is a direct child of the module
+    // instruction. Always returns SLANG_OK.
+    SlangResult processModule()
+    {
+        // Collect all functions that need processing.
+        // Process all callees before callers; otherwise we introduce bugs
+
+        HashSet<IRFunc*> visitedCandidates;
+        List<IRFunc*> functionsToProcess;
+        for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst())
+        {
+            if (auto func = as<IRFunc>(inst))
+                addFuncsToCallListInTopologicalOrder(func, functionsToProcess, visitedCandidates);
+        }
+
+        // Process each function
+        for (auto func : functionsToProcess)
+            processFunc(func);
+
+        return SLANG_OK;
+    }
+};
+
+// Entry point of the pass: builds a `TransformParamsToConstRefContext` for
+// `module` and `sink`, runs it over the module, and returns its result.
+SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink)
+{
+    return TransformParamsToConstRefContext(module, sink).processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-transform-params-to-constref.h b/source/slang/slang-ir-transform-params-to-constref.h
new file mode 100644
index 000000000..5bdf8a275
--- /dev/null
+++ b/source/slang/slang-ir-transform-params-to-constref.h
@@ -0,0 +1,12 @@
+// source\slang\slang-ir-transform-params-to-constref.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+class DiagnosticSink;
+
+SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b48dcc7e6..6b9273c15 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2943,6 +2943,11 @@ IRConstRefType* IRBuilder::getConstRefType(IRType* valueType)
return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType);
}
+// Overload of `getConstRefType` that also takes an address space: requests a
+// `kIROp_ConstRefType` pointer type to `valueType` in `addrSpace` from
+// `getPtrType` and returns it as an `IRConstRefType*`.
+IRConstRefType* IRBuilder::getConstRefType(IRType* valueType, AddressSpace addrSpace)
+{
+    auto ptrType = getPtrType(kIROp_ConstRefType, valueType, addrSpace);
+    return (IRConstRefType*)ptrType;
+}
+
IRSPIRVLiteralType* IRBuilder::getSPIRVLiteralType(IRType* type)
{
IRInst* operands[] = {type};
diff --git a/tests/cuda/copy-elision-this-1.slang b/tests/cuda/copy-elision-this-1.slang
new file mode 100644
index 000000000..295b45c73
--- /dev/null
+++ b/tests/cuda/copy-elision-this-1.slang
@@ -0,0 +1,28 @@
+//TEST:SIMPLE(filecheck=CUDA): -stage compute -entry computeMain -target cuda
+struct Data {
+ StructuredBuffer<float> input[2];
+ RWStructuredBuffer<float> output;
+ uint input_tensor_count;
+ StructuredBuffer<uint> index_buffer;
+ uint index_count;
+
+ // CUDA: fetch{{.*}}Data{{.*}}*{{.*}}this
+ float fetch(int buffer, int index)
+ {
+ return input[buffer][index];
+ }
+};
+
+ParameterBlock<Data> data;
+
+[shader("compute")]
+[numthreads(8, 8, 1)]
+void computeMain(uint3 tid: SV_DispatchThreadID)
+{
+ float result = 0.0;
+ for (int i = 0; i < data.index_count; ++i) {
+ uint buffer = data.index_buffer[i];
+ result += data.fetch(buffer, tid.x * 1024 + tid.y);
+ }
+ data.output[tid.x * 1024 + tid.y] = result;
+}
diff --git a/tests/cuda/copy-elision-this-2.slang b/tests/cuda/copy-elision-this-2.slang
new file mode 100644
index 000000000..60bb948c9
--- /dev/null
+++ b/tests/cuda/copy-elision-this-2.slang
@@ -0,0 +1,141 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=BUF): -cuda -compute
+//TEST:SIMPLE(filecheck=CUDA): -stage compute -entry computeMain -target cuda -O3
+
+struct Data
+{
+ int val;
+
+
+ __init(int val)
+ {
+ this.val = val;
+ }
+};
+
+struct DataWrapped
+{
+ Data field;
+ Data element[2];
+
+ __init(int val)
+ {
+ field.val = val;
+ element[0].val = val;
+ element[1].val = val;
+ }
+}
+
+//TEST_INPUT:uniform(data=[1]):name=globalData
+uniform Data globalData;
+
+//TEST_INPUT: set input = ubuffer(data=[1 2 3 4], stride=4)
+RWStructuredBuffer<int> input;
+
+//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4)
+RWStructuredBuffer<int> output;
+
+// CUDA: addCopyElision{{.*}}Data{{.*}}*{{.*}}data
+int addCopyElision(Data data, int val)
+{
+ // ensure we do not introduce a temporary
+ // CUDA-NOT: Data{{.*}};
+ return data.val + val;
+}
+
+// CUDA: nested{{.*}}Data{{.*}}*{{.*}}data
+int nested(Data data, int val)
+{
+
+ return addCopyElision(data, val);
+}
+
+// CUDA: addCopyElision{{.*}}FixedArray{{.*}}*{{.*}}data
+int addCopyElision(int data[10], int val)
+{
+// ensure we do not introduce a temporary
+// CUDA-NOT: {{.*}}FixedArray{{.*}};
+ return data[1] + val;
+}
+
+// CUDA: nested{{.*}}Array{{.*}}*{{.*}}data
+int nested(int data[10], int val)
+{
+ return addCopyElision(data, val);
+}
+
+void modify(inout int data[10])
+{
+ data[1] = input[0];
+}
+// CUDA: notDirectlyUsingParam{{.*}}Array{{.*}}*{{.*}}data
+int notDirectlyUsingParam(int data[10], int val)
+{
+// ensure we create a temporary for the array
+// CUDA: FixedArray{{.*}};
+ modify(data);
+ return data[1] + val;
+}
+
+
+// CUDA:computeMain
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+
+ // struct
+ Data data = Data(input[0]);
+ int structVal = addCopyElision(data, input[1]);
+
+ // struct which is globalParam
+ int globalParamStructVal = addCopyElision(globalData, input[1]);
+
+ // passing nested struct
+ int nestedStructVal = nested(data, input[1]);
+
+ // field
+ DataWrapped dataWrapped = DataWrapped(input[0]);
+ int fieldVal = addCopyElision(dataWrapped.field, input[1]);
+
+ // element
+ int elementVal = addCopyElision(dataWrapped.element[0], input[1]);
+
+ // A non-variable
+ int nonVariableVal = addCopyElision(Data(input[0]), input[1]);
+
+ // array
+ int val[10];
+ val[1] = input[0];
+ int arrayVal = addCopyElision(val, input[1]);
+
+ // passing nested array
+ int nestedArrayVal = nested(val, input[1]);
+
+ // not directly using param
+ int notDirectlyUsingParamVal = notDirectlyUsingParam(val, input[1]);
+
+ output[0] =
+ structVal == 3 &&
+ globalParamStructVal == 3 &&
+ nestedStructVal == 3 &&
+ fieldVal == 3 &&
+ elementVal == 3 &&
+ nonVariableVal == 3 &&
+ arrayVal == 3 &&
+ nestedArrayVal == 3 &&
+ notDirectlyUsingParamVal == 3
+ ? 1 : 0;
+
+ // For debugging
+ //output[1] = structVal;
+ //output[2] = globalParamStructVal;
+ //output[3] = nestedStructVal;
+ //output[4] = fieldVal;
+ //output[5] = elementVal;
+ //output[6] = nonVariableVal;
+ //output[7] = arrayVal;
+ //output[8] = nestedArrayVal;
+ //output[9] = notDirectlyUsingParamVal;
+}
+
+//BUF: 1 \ No newline at end of file
diff --git a/tests/language-feature/pointer/const-ref.slang b/tests/language-feature/pointer/const-ref.slang
index f62fda697..06bb9dc07 100644
--- a/tests/language-feature/pointer/const-ref.slang
+++ b/tests/language-feature/pointer/const-ref.slang
@@ -3,7 +3,7 @@
//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry computeMain -stage compute
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER): -slang -compute -output-using-type -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER): -vk -compute -output-using-type -shaderobj
-
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER): -cuda -compute -output-using-type -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;
@@ -14,8 +14,8 @@ struct Thing
int bigArray[128];
// Check that we are not inserting local variables that are copies of `this` parameter.
-
- // CHECK: __device__ int Thing_getSum{{.*}}(Thing{{.*}} * this{{.*}})
+
+ // CHECK: __device__ int Thing_getSum{{.*}}Thing{{.*}}*{{.*}}this{{.*}})
// CHECK-NOT: Thing{{[a-zA-Z0-9_]*}} {{[a-zA-Z0-9_]+}}
// CHECK: }
[constref]
@@ -32,7 +32,7 @@ struct Thing
// Check that we are not inserting local variables that are copies of `thing` parameter.
-// CHECK: __device__ int test{{.*}}(Thing{{.*}} * thing{{.*}})
+// CHECK: __device__ int test{{.*}}Thing{{.*}}*{{.*}}thing{{.*}})
// CHECK-NOT: Thing{{[a-zA-Z0-9_]*}} {{[a-zA-Z0-9_]+}}
// CHECK: }