summary | refs | log | tree | commit | diff | stats
path: root/source
diff options
context:
space:
mode:
author: Yong He <yonghe@outlook.com> 2023-03-29 17:05:07 -0700
committer: GitHub <noreply@github.com> 2023-03-29 17:05:07 -0700
commit: 082c48d96c5f8f6b4f560d705fe731da14409cb4 (patch)
tree: fe9860aea3326cd321365bc5530a917fcef94718 /source
parent: a862f5b7007ef50b5def30506f0cea138b73c710 (diff)
Update checkpoint policy to make obvious recompute decisions. (#2753)
* Update checkpoint policy to make obvious recompute decisions.
  Also adds an optimization to fold updateElement chains on the same array or struct into a single makeArray or makeStruct.
* Bug fixes around array types whose element counts have different integer types.
* Change test.
* Fix.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r-- source/slang/slang-check-expr.cpp 7
-rw-r--r-- source/slang/slang-ir-autodiff-primal-hoist.cpp 208
-rw-r--r-- source/slang/slang-ir-autodiff-primal-hoist.h 10
-rw-r--r-- source/slang/slang-ir-autodiff-transpose.h 14
-rw-r--r-- source/slang/slang-ir-autodiff-unzip.cpp 168
-rw-r--r-- source/slang/slang-ir-autodiff.cpp 34
-rw-r--r-- source/slang/slang-ir-autodiff.h 1
-rw-r--r-- source/slang/slang-ir-peephole.cpp 110
-rw-r--r-- source/slang/slang-ir.cpp 53
-rw-r--r-- source/slang/slang-lower-to-ir.cpp 2
10 files changed, 410 insertions, 197 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index cfcb15269..a14ed38d8 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1546,7 +1546,12 @@ namespace Slang
// it is possible that we are referring to a generic value param
if (auto declRefExpr = expr.as<DeclRefExpr>())
{
- auto declRef = getDeclRef(m_astBuilder, declRefExpr);
+ auto checkedExpr = as<DeclRefExpr>(CheckTerm(expr.getExpr()));
+ if (!checkedExpr)
+ return nullptr;
+
+ SubstExpr<DeclRefExpr> substExpr(checkedExpr, expr.getSubsts());
+ auto declRef = getDeclRef(m_astBuilder, substExpr);
if (auto genericValParamRef = declRef.as<GenericValueParamDecl>())
{
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 793a8ff07..04d5560d9 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -23,7 +23,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
HashSet<IRUse*> processedUses;
HashSet<IRUse*> usesToReplace;
-
+
auto addPrimalOperandsToWorkList = [&](IRInst* inst)
{
UIndex opIndex = 0;
@@ -144,7 +144,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
if (auto var = as<IRVar>(result.instToRecompute))
{
IRUse* storeUse = findUniqueStoredVal(var);
- if (!storeUse)
+ if (storeUse)
workList.add(storeUse);
}
else
@@ -635,40 +635,216 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return hoistInfo;
}
-void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode*)
+void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
{
- // Do nothing.. This is an (almost) always-store policy.
+ domTree = computeDominatorTree(func);
return;
}
-HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
+static bool doesInstHaveDiffUse(IRInst* inst)
{
- // Store all that we can.. by default, classify will only be called on relevant differential
- // uses (or on uses in a 'recompute' inst)
- //
- if (auto var = as<IRVar>(use->get()))
+ bool hasDiffUser = false;
+
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (isDiffInst(user))
+ {
+ // Ignore uses that is a return or MakeDiffPair
+ switch (user->getOp())
+ {
+ case kIROp_Return:
+ continue;
+ case kIROp_MakeDifferentialPair:
+ if (!user->hasMoreThanOneUse() && user->firstUse &&
+ user->firstUse->getUser()->getOp() == kIROp_Return)
+ continue;
+ break;
+ default:
+ break;
+ }
+ hasDiffUser = true;
+ break;
+ }
+ }
+
+ return hasDiffUser;
+}
+
+static bool doesInstHaveStore(IRInst* inst)
+{
+ SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType()));
+
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRStore>(use->getUser()))
+ return true;
+
+ if (as<IRPtrTypeBase>(use->getUser()->getDataType()))
+ {
+ if (doesInstHaveStore(use->getUser()))
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static bool isIntermediateContextType(IRType* type)
+{
+ switch (type->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
+ return true;
+ case kIROp_PtrType:
+ return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
+ case kIROp_ArrayType:
+ return isIntermediateContextType(as<IRArrayType>(type)->getElementType());
+ }
+
+ return false;
+}
+
+static bool shouldStoreVar(IRVar* var)
+{
+ // Always store intermediate context var.
+ if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
{
+ // If we are specializing a callee's intermediate context with types that can't be stored,
+ // we can't store the entire context.
if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType()))
{
for (UInt i = 0; i < spec->getArgCount(); i++)
{
if (!canTypeBeStored(spec->getArg(i)->getDataType()))
- return HoistResult::recompute(use->get());
+ return false;
}
- return HoistResult::store(use->get());
}
- else // if (canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
+ return true;
+ }
+
+ if (isIntermediateContextType(var->getDataType()))
+ {
+ return true;
+ }
+
+ // For now the store policy is simple, we use two conditions:
+ // 1. Is the var used in a differential block and,
+ // 2. Does the var have a store
+ //
+
+ return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
+}
+
+static bool shouldStoreInst(IRInst* inst)
+{
+ if (!inst->getDataType())
+ {
+ return false;
+ }
+
+ if (!canTypeBeStored(inst->getDataType()))
+ return false;
+
+ // Never store certain opcodes.
+ switch (inst->getOp())
+ {
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_Reinterpret:
+ case kIROp_BitCast:
+ case kIROp_DefaultConstruct:
+ case kIROp_MakeStruct:
+ case kIROp_MakeTuple:
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ case kIROp_MakeDifferentialPair:
+ case kIROp_MakeOptionalNone:
+ case kIROp_MakeOptionalValue:
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_undefined:
+ return false;
+ case kIROp_GetElement:
+ case kIROp_FieldExtract:
+ case kIROp_swizzle:
+ case kIROp_UpdateElement:
+ case kIROp_OptionalHasValue:
+ case kIROp_GetOptionalValue:
+ case kIROp_MatrixReshape:
+ case kIROp_VectorReshape:
+ // If the operand is already stored, don't store the result of these insts.
+ if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>())
{
- return HoistResult::store(use->get());
+ return false;
}
+ break;
+ default:
+ break;
+ }
+
+ // Only store if the inst has differential inst user.
+ bool hasDiffUser = doesInstHaveDiffUse(inst);
+ if (!hasDiffUser)
+ return false;
+
+ return true;
+}
+
+bool canRecompute(IRDominatorTree* domTree, IRUse* use)
+{
+ auto param = as<IRParam>(use->get());
+ if (!param)
+ return true;
+ auto paramBlock = as<IRBlock>(param->getParent());
+ for (auto predecessor : paramBlock->getPredecessors())
+ {
+ // If we hit this, the checkpoint policy is trying to recompute
+ // values across a loop region boundary (we don't currently support this,
+ // and in general this is quite inefficient in both compute & memory)
+ //
+ if (domTree->dominates(paramBlock, predecessor))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
+{
+ // Store all that we can.. by default, classify will only be called on relevant differential
+ // uses (or on uses in a 'recompute' inst)
+ //
+ if (auto var = as<IRVar>(use->get()))
+ {
+ if (shouldStoreVar(var))
+ return HoistResult::store(var);
+ else
+ return HoistResult::recompute(var);
}
else
{
- if (canTypeBeStored(use->get()->getDataType()))
+ if (shouldStoreInst(use->get()))
return HoistResult::store(use->get());
else
- return HoistResult::recompute(use->get());
+ {
+ // We may not be able to recompute due to limitations of
+ // the unzip pass. If so we will store the result.
+ if (canRecompute(domTree, use))
+ return HoistResult::recompute(use->get());
+
+ // The fallback is to store.
+ return HoistResult::store(use->get());
+ }
}
}
-}; \ No newline at end of file
+};
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index dc85942f6..bd2575172 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -218,7 +218,7 @@ namespace Slang
class AutodiffCheckpointPolicyBase : public RefObject
{
- public:
+ public:
AutodiffCheckpointPolicyBase(IRModule* module) : module(module)
{ }
@@ -233,14 +233,14 @@ namespace Slang
virtual HoistResult classify(IRUse* diffBlockUse) = 0;
- protected:
+ protected:
IRModule* module;
};
class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase
{
- public:
+ public:
DefaultCheckpointPolicy(IRModule* module)
: AutodiffCheckpointPolicyBase(module)
@@ -248,6 +248,8 @@ namespace Slang
virtual void preparePolicy(IRGlobalValueWithCode* func);
virtual HoistResult classify(IRUse* use);
+
+ RefPtr<IRDominatorTree> domTree;
};
RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
@@ -261,4 +263,4 @@ namespace Slang
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo);
-}; \ No newline at end of file
+};
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 91a1601fb..8c005a5c6 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1621,10 +1621,18 @@ struct DiffTransposePass
if (auto varToHoist = as<IRVar>(inst))
{
varToHoist->insertBefore(varBlock->getFirstOrdinaryInst());
- inst = findUniqueStoredVal(varToHoist)->getUser();
- SLANG_ASSERT(inst);
+ auto uniqueStoreUse = findUniqueStoredVal(varToHoist);
+ if (uniqueStoreUse)
+ {
+ inst = uniqueStoreUse->getUser();
+ SLANG_ASSERT(inst);
- defBlock = getBlock(inst);
+ defBlock = getBlock(inst);
+ }
+ else
+ {
+ defBlock = getBlock(inst);
+ }
}
else
{
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index af7792748..53f0cbba2 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -91,18 +91,6 @@ struct ExtractPrimalFuncContext
return newFuncType;
}
- bool isDiffInst(IRInst* inst)
- {
- if (inst->findDecoration<IRDifferentialInstDecoration>() ||
- inst->findDecoration<IRMixedDifferentialInstDecoration>())
- return true;
-
- if (auto block = as<IRBlock>(inst->getParent()))
- return isDiffInst(block);
-
- return false;
- }
-
IRInst* insertIntoReturnBlock(IRBuilder& builder, IRInst* inst)
{
if (!isDiffInst(inst))
@@ -130,162 +118,6 @@ struct ExtractPrimalFuncContext
}
}
- bool doesInstHaveDiffUse(IRInst* inst)
- {
- bool hasDiffUser = false;
-
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- auto user = use->getUser();
- if (isDiffInst(user))
- {
- // Ignore uses that is a return or MakeDiffPair
- switch (user->getOp())
- {
- case kIROp_Return:
- continue;
- case kIROp_MakeDifferentialPair:
- if (!user->hasMoreThanOneUse() && user->firstUse &&
- user->firstUse->getUser()->getOp() == kIROp_Return)
- continue;
- break;
- default:
- break;
- }
- hasDiffUser = true;
- break;
- }
- }
-
- return hasDiffUser;
- }
-
- bool doesInstHaveStore(IRInst* inst)
- {
- SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType()));
-
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (as<IRStore>(use->getUser()))
- return true;
-
- if (as<IRPtrTypeBase>(use->getUser()->getDataType()))
- {
- if (doesInstHaveStore(use->getUser()))
- return true;
- }
- }
-
- return false;
- }
-
- bool isIntermediateContextType(IRType* type)
- {
- switch (type->getOp())
- {
- case kIROp_BackwardDiffIntermediateContextType:
- return true;
- case kIROp_PtrType:
- return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
- case kIROp_ArrayType:
- return isIntermediateContextType(as<IRArrayType>(type)->getElementType());
- }
-
- return false;
- }
-
- bool shouldStoreVar(IRVar* var)
- {
- // Always store intermediate context var.
- if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
- {
- // If we are specializing a callee's intermediate context with types that can't be stored,
- // we can't store the entire context.
- if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType()))
- {
- for (UInt i = 0; i < spec->getArgCount(); i++)
- {
- if (!canTypeBeStored(spec->getArg(i)->getDataType()))
- return false;
- }
- }
- return true;
- }
-
- if (isIntermediateContextType(var->getDataType()))
- {
- return true;
- }
-
- // For now the store policy is simple, we use two conditions:
- // 1. Is the var used in a differential block and,
- // 2. Does the var have a store
- //
-
- return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
- }
-
- bool shouldStoreInst(IRInst* inst)
- {
- if (!inst->getDataType())
- {
- return false;
- }
-
- if (!canTypeBeStored(inst->getDataType()))
- return false;
-
- // Never store certain opcodes.
- switch (inst->getOp())
- {
- case kIROp_CastFloatToInt:
- case kIROp_CastIntToFloat:
- case kIROp_IntCast:
- case kIROp_FloatCast:
- case kIROp_MakeVectorFromScalar:
- case kIROp_MakeMatrixFromScalar:
- case kIROp_Reinterpret:
- case kIROp_BitCast:
- case kIROp_DefaultConstruct:
- case kIROp_MakeStruct:
- case kIROp_MakeTuple:
- case kIROp_MakeArray:
- case kIROp_MakeArrayFromElement:
- case kIROp_MakeDifferentialPair:
- case kIROp_MakeOptionalNone:
- case kIROp_MakeOptionalValue:
- case kIROp_DifferentialPairGetDifferential:
- case kIROp_DifferentialPairGetPrimal:
- case kIROp_ExtractExistentialValue:
- case kIROp_ExtractExistentialType:
- case kIROp_ExtractExistentialWitnessTable:
- return false;
- case kIROp_GetElement:
- case kIROp_FieldExtract:
- case kIROp_swizzle:
- case kIROp_UpdateElement:
- case kIROp_OptionalHasValue:
- case kIROp_GetOptionalValue:
- case kIROp_MatrixReshape:
- case kIROp_VectorReshape:
- // If the operand is already stored, don't store the result of these insts.
- if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>())
- {
- return false;
- }
- break;
- default:
- break;
- }
-
- // Only store if the inst has differential inst user.
- bool hasDiffUser = doesInstHaveDiffUse(inst);
- if (!hasDiffUser)
- return false;
-
- return true;
- }
-
IRStructField* addIntermediateContextField(IRInst* type, IRInst* intermediateOutput)
{
IRBuilder genTypeBuilder(module);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 10c751d52..024d31fd8 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -282,15 +282,29 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
IRBuilder* builder, IRType* originalPairType)
{
IRInst* result = nullptr;
- if (pairTypeCache.TryGetValue(originalPairType, result))
- return result;
auto pairType = as<IRDifferentialPairTypeBase>(originalPairType);
if (!pairType)
+ return originalPairType;
+
+ // We make our type cache keyed on the primal type, not the pair type.
+ // This is because there may be duplicate pair types for the same
+ // primal type but different witness tables, and we don't want to treat
+ // them as distinct.
+ // We might want to consider making witness tables part of IR
+ // deduplication (make them HOISTABLE insts), but that is a bigger
+ // change. Another alternative is to make the witness operand of
+ // `IRDifferentialPairTypeBase` be child instead of an operand
+ // so that it is not considered part of the type for deduplication
+ // purposes.
+
+ auto primalType = pairType->getValueType();
+ if (pairTypeCache.TryGetValue(primalType, result))
+ return result;
+ if (!pairType)
{
result = originalPairType;
return result;
}
- auto primalType = pairType->getValueType();
if (as<IRParam>(primalType))
{
result = nullptr;
@@ -301,7 +315,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
if (!diffType)
return result;
result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- pairTypeCache.Add(originalPairType, result);
+ pairTypeCache.Add(primalType, result);
return result;
}
@@ -1820,4 +1834,16 @@ bool isDerivativeContextVar(IRVar* var)
return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>();
}
+bool isDiffInst(IRInst* inst)
+{
+ if (inst->findDecoration<IRDifferentialInstDecoration>() ||
+ inst->findDecoration<IRMixedDifferentialInstDecoration>())
+ return true;
+
+ if (auto block = as<IRBlock>(inst->getParent()))
+ return isDiffInst(block);
+
+ return false;
+}
+
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index da0cdc755..167aa2357 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -346,5 +346,6 @@ IRUse* findUniqueStoredVal(IRVar* var);
bool isDerivativeContextVar(IRVar* var);
+bool isDiffInst(IRInst* inst);
};
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 65b5d2f45..ab3f0ceab 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -43,6 +43,9 @@ struct PeepholeContext : InstPassBase
chainKey.reverse();
if (auto updateInst = as<IRUpdateElement>(chainNode))
{
+ // If we see an extract(updateElement(x, accessChain, val), accessChain), then
+ // we can replace the inst with val.
+
if (updateInst->getAccessKeyCount() > (UInt)chainKey.getCount())
return false;
@@ -96,6 +99,8 @@ struct PeepholeContext : InstPassBase
}
else if (isAccessChainNotEqual)
{
+ // If we see an extract(updateElement(x, accessChain, val), accessChain2), where accessChain!=accessChain2,
+ // then we can replace the inst with extract(x, accessChain2).
IRBuilder builder(module);
builder.setInsertBefore(inst);
auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView());
@@ -445,12 +450,65 @@ struct PeepholeContext : InstPassBase
changed = true;
}
}
+ else
+ {
+ // Check if the updated value is a chain of `updateElement` instructions that
+ // updates every element in the same array, and if so we can replace the
+ // whole chain with a single `makeArray` instruction.
+ auto arrayType = as<IRArrayType>(inst->getDataType());
+ if (!arrayType) break;
+ auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+ if (!arraySize) break;
+
+ List<IRInst*> args;
+ args.setCount((UInt)arraySize->getValue());
+ for (Index i = 0; i < args.getCount(); i++)
+ args[i] = nullptr;
+
+ for (auto updateElement = updateInst; updateElement;
+ updateElement = as<IRUpdateElement>(updateElement->getOldValue()))
+ {
+ auto subKey = updateElement->getAccessKey(0);
+ auto subConstIndex = as<IRIntLit>(subKey);
+ if (!subConstIndex)
+ break;
+ auto index = (Index)subConstIndex->getValue();
+ if (index >= args.getCount())
+ break;
+ // If we have already seen an update for this index, then we can't
+ // override it with an earlier update.
+ if (args[index])
+ continue;
+ args[index] = updateElement->getElementValue();
+ }
+
+ bool isComplete = true;
+ for (auto arg : args)
+ {
+ if (!arg)
+ {
+ isComplete = false;
+ break;
+ }
+ }
+ if (isComplete)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
+ inst->replaceUsesWith(makeArray);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ }
}
else if (auto structKey = as<IRStructKey>(key))
{
auto oldVal = inst->getOperand(0);
if (oldVal->getOp() == kIROp_MakeStruct)
{
+ // If we see updateElement(makeStruct(...), structKey, ...), we can
+ // replace it with a makeStruct that has the updated value.
auto structType = as<IRStructType>(inst->getDataType());
if (!structType) break;
List<IRInst*> args;
@@ -484,6 +542,58 @@ struct PeepholeContext : InstPassBase
changed = true;
}
}
+ else
+ {
+ // Check if the updated `oldVal` is a chain of updateElement insts that assigns
+ // values to every field of the struct, if so, we can just emit a makeStruct instead.
+ Dictionary<IRStructKey*, IRInst*> mapFieldKeyToVal;
+ for (auto updateElement = as<IRUpdateElement>(inst); updateElement;
+ updateElement = as<IRUpdateElement>(updateElement->getOldValue()))
+ {
+ if (updateElement->getAccessKeyCount() != 1)
+ break;
+ auto subStructKey = as<IRStructKey>(updateElement->getAccessKey(0));
+ if (!subStructKey)
+ break;
+
+ // If the key already exists, it means there is already a later update at this key.
+ // We need to be careful not to override it with an earlier value.
+ // AddIfNotExists will ensure this does not happen.
+ mapFieldKeyToVal.AddIfNotExists(
+ subStructKey, updateElement->getElementValue());
+ }
+
+ // Check if every field of the struct has a value assigned to it,
+ // while build up arguments for makeStruct inst at the same time.
+ auto structType = as<IRStructType>(inst->getDataType());
+ if (!structType) break;
+ List<IRInst*> args;
+ bool isComplete = true;
+ for (auto field : structType->getFields())
+ {
+ IRInst* arg = nullptr;
+ if (mapFieldKeyToVal.TryGetValue(field->getKey(), arg))
+ {
+ args.add(arg);
+ }
+ else
+ {
+ isComplete = false;
+ break;
+ }
+ }
+
+ if (!isComplete) break;
+
+ // Create a makeStruct inst using args.
+
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
+ inst->replaceUsesWith(makeStruct);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
}
}
break;
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a4d7840d3..16926a742 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2389,6 +2389,25 @@ namespace Slang
capabilitySetType, kIROp_CapabilitySet, args.getCount(), args.getBuffer());
}
+ static void canonicalizeInstOperands(IRBuilder& builder, IRInst* inst)
+ {
+ // For Array types, we always want to make sure its element count
+ // has an int32_t type. We will convert all other int types to int32_t
+ // to avoid things like float[8] and float[8U] being distinct types.
+ if (inst->getOp() == kIROp_ArrayType)
+ {
+ IRInst* elementCount = inst->getOperand(1);
+ if (auto intLit = as<IRIntLit>(elementCount))
+ {
+ if (intLit->getDataType()->getOp() != kIROp_IntType)
+ {
+ IRInst* newElementCount = builder.getIntValue(builder.getIntType(), intLit->getValue());
+ inst->getOperands()[1].usedValue = newElementCount;
+ }
+ }
+ }
+ }
+
IRInst* IRBuilder::_findOrEmitHoistableInst(
IRType* type,
IROp op,
@@ -2448,6 +2467,8 @@ namespace Slang
}
}
+ canonicalizeInstOperands(*this, inst);
+
// Find or add the key/inst
{
IRInstKey key = { inst };
@@ -2515,6 +2536,38 @@ namespace Slang
UInt operandCount,
IRInst* const* operands)
{
+ switch (op)
+ {
+ case kIROp_ArrayType:
+ {
+ ShortList<IRInst*, 2> newOperands;
+ newOperands.addRange(operands, operandCount);
+
+ // If elementCount does not have int type, then we always cast
+ // it to an int type, to avoid having to deal with the
+ // possibility that an array<int, 2> and an array<int, 2U> are
+ // treated as distinct types.
+ if (operandCount < 2) break;
+ auto elementCount = operands[1];
+ if (elementCount->getFullType() && elementCount->getFullType()->getOp() != kIROp_IntType)
+ {
+ auto intLit = as<IRIntLit>(elementCount);
+ if (intLit)
+ elementCount = getIntValue(getIntType(), intLit->getValue());
+ else
+ elementCount = emitIntrinsicInst(getIntType(), kIROp_IntCast, 1, &elementCount);
+ }
+ newOperands[1] = elementCount;
+ return (IRType*)createIntrinsicInst(
+ nullptr,
+ op,
+ operandCount,
+ newOperands.getArrayView().getBuffer());
+ }
+ default:
+ break;
+ }
+
return (IRType*)createIntrinsicInst(
nullptr,
op,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 6d4e50463..3affcff44 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6446,7 +6446,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
auto irSubType = lowerType(subContext, subType);
- irWitnessTable->setOperand(0, irSubType);
+ irWitnessTable->setConcreteType(irSubType);
// TODO(JS):
// Should the mangled name take part in obfuscation if enabled?