summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-transform-params-to-constref.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-transform-params-to-constref.cpp')
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.cpp142
1 files changed, 116 insertions, 26 deletions
diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp
index 9328a1de1..d34b3d25b 100644
--- a/source/slang/slang-ir-transform-params-to-constref.cpp
+++ b/source/slang/slang-ir-transform-params-to-constref.cpp
@@ -31,10 +31,7 @@ struct TransformParamsToConstRefContext
case kIROp_StructType:
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
- case kIROp_VectorType:
- case kIROp_MatrixType:
case kIROp_TupleType:
- case kIROp_CoopVectorType:
// valid type, continue to check
break;
default:
@@ -44,58 +41,121 @@ struct TransformParamsToConstRefContext
return true;
}
- void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ void rewriteValueUsesToAddrUses(IRInst* newAddrInst)
{
- // Traverse the uses of our updated params to rewrite them.
- // Assume a `in` parameter has been converted to a `constref` parameter.
- for (auto param : updatedParams)
+ HashSet<IRInst*> workListSet;
+ workListSet.add(newAddrInst);
+ List<IRInst*> workList;
+ workList.add(newAddrInst);
+ auto _addToWorkList = [&](IRInst* inst)
+ {
+ if (workListSet.add(inst))
+ workList.add(inst);
+ };
+ for (Index i = 0; i < workList.getCount(); i++)
{
+ auto inst = workList[i];
traverseUses(
- param,
+ inst,
[&](IRUse* use)
{
auto user = use->getUser();
+ if (workListSet.contains(user))
+ return;
switch (user->getOp())
{
case kIROp_FieldExtract:
{
// Transform the IRFieldExtract into a IRFieldAddress
+ if (!isUseBaseAddrOperand(use, user))
+ break;
auto fieldExtract = as<IRFieldExtract>(user);
builder.setInsertBefore(fieldExtract);
auto fieldAddr = builder.emitFieldAddress(
fieldExtract->getBase(),
fieldExtract->getField());
- auto loadInst = builder.emitLoad(fieldAddr);
- fieldExtract->replaceUsesWith(loadInst);
- fieldExtract->removeAndDeallocate();
- break;
+ fieldExtract->replaceUsesWith(fieldAddr);
+ _addToWorkList(fieldAddr);
+ return;
}
case kIROp_GetElement:
{
// Transform the IRGetElement into a IRGetElementPtr
+ if (!isUseBaseAddrOperand(use, user))
+ break;
auto getElement = as<IRGetElement>(user);
builder.setInsertBefore(getElement);
auto elemAddr = builder.emitElementAddress(
getElement->getBase(),
getElement->getIndex());
- auto loadInst = builder.emitLoad(elemAddr);
- getElement->replaceUsesWith(loadInst);
- getElement->removeAndDeallocate();
- break;
+ getElement->replaceUsesWith(elemAddr);
+ _addToWorkList(elemAddr);
+ return;
}
- default:
+ case kIROp_Store:
{
- // Insert a load before the user and replace the user with the load
- builder.setInsertBefore(user);
- auto loadInst = builder.emitLoad(param);
- use->set(loadInst);
+ // If the current value is being stored into a write-once temp var that
+ // is immediately passed into a constref location in a call, we can get
+ // rid of the temp var and replace it with `inst` directly.
+ // (such temp var can be introduced during `updateCallSites` when we
+ // were processing the callee.)
+ //
+ auto dest = as<IRStore>(user)->getPtr();
+ if (dest->findDecorationImpl(kIROp_TempCallArgImmutableVarDecoration))
+ {
+ user->removeAndDeallocate();
+ dest->replaceUsesWith(inst);
+ dest->removeAndDeallocate();
+ return;
+ }
break;
}
}
+ // Insert a load before the user and replace the user with the load
+ builder.setInsertBefore(user);
+ auto loadInst = builder.emitLoad(inst);
+ use->set(loadInst);
});
}
}
+ void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ {
+ // Traverse the uses of our updated params to rewrite them.
+ // Assume a `in` parameter has been converted to a `constref` parameter.
+ for (auto param : updatedParams)
+ {
+ rewriteValueUsesToAddrUses(param);
+ }
+ }
+
+ // Check if `load` is an `IRLoad(addr)` where `addr` is a immutable location.
+ IRInst* isLoadFromImmutableAddress(IRInst* load)
+ {
+ if (load->getOp() != kIROp_Load)
+ return nullptr;
+ auto addr = load->getOperand(0);
+ auto root = getRootAddr(addr);
+ if (!root)
+ return nullptr;
+ if (!root->getDataType())
+ return nullptr;
+ switch (root->getDataType()->getOp())
+ {
+ case kIROp_ConstantBufferType:
+ case kIROp_ConstRefType:
+ case kIROp_ParameterBlockType:
+ return addr;
+ default:
+ // Note that we should in general not assume a read-only StructuredBuffer or
+ // a pointer with read-only access as an immutable location due to potential aliasing.
+ // We could introduce a compiler flag to turn on optimizations on these buffer types
+ // assuming there is no aliasing.
+ break;
+ }
+ return nullptr;
+ }
+
// Update call sites to pass an address instead of value for each updated-param
void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams)
{
@@ -119,10 +179,19 @@ struct TransformParamsToConstRefContext
newArgs.add(arg);
continue;
}
-
- auto tempVar = builder.emitVar(arg->getFullType());
- builder.emitStore(tempVar, arg);
- newArgs.add(tempVar);
+ if (auto addr = isLoadFromImmutableAddress(arg))
+ {
+ // If existing argument is a load from an immutable buffer address,
+ // we can pass in the address as is, without making a temporary copy.
+ newArgs.add(addr);
+ }
+ else
+ {
+ auto tempVar = builder.emitVar(arg->getFullType());
+ builder.addDecoration(tempVar, kIROp_TempCallArgImmutableVarDecoration);
+ builder.emitStore(tempVar, arg);
+ newArgs.add(tempVar);
+ }
}
// Create new call with updated arguments
@@ -177,6 +246,25 @@ struct TransformParamsToConstRefContext
{
HashSet<IRParam*> updatedParams;
+ // If the function is used in any way that is not understood by the
+ // compiler, do not modify it.
+ // For example, if the function is used as callback, we must preserve
+ // its signature.
+ for (auto use = func->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (as<IRDecoration>(user))
+ continue;
+ if (auto call = as<IRCall>(user))
+ {
+ if (call->getCalleeUse() == use)
+ continue;
+ }
+ // If we reach here, we encountered a non-call use of the func,
+ // we will stop processing.
+ return;
+ }
+
// First pass: Transform parameter types
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
{
@@ -192,7 +280,7 @@ struct TransformParamsToConstRefContext
// This allows us to pass the address of variables directly into a function,
// giving us the choice to remove copies into a parameter.
auto paramType = param->getDataType();
- auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
+ auto constRefType = builder.getConstRefType(paramType, AddressSpace::Generic);
param->setFullType(constRefType);
changed = true;
@@ -205,6 +293,8 @@ struct TransformParamsToConstRefContext
return;
}
+ fixUpFuncType(func);
+
// Second pass: Update function body according to the new `constref` parameters
rewriteParamUseSitesToSupportConstRefUsage(updatedParams);