summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-25 17:27:40 -0800
committerGitHub <noreply@github.com>2023-01-25 17:27:40 -0800
commit1f4c7cab13341c2e9d48df2b01ed2c048c17c152 (patch)
treeed85dda63e1c939cf474961b965b7cc1883940bb /source
parentaa6814be1f7dea20597ae34d477e79e53d4a543f (diff)
Unify UpdateField and UpdateElement with access chain. (#2611)
* Unify UpdateField and UpdateElement with access chain. * Fix warnings. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-c-like.cpp65
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp61
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp88
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp42
-rw-r--r--source/slang/slang-ir-autodiff-rev.h2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h44
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp1
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp33
-rw-r--r--source/slang/slang-ir-constexpr.cpp1
-rw-r--r--source/slang/slang-ir-inst-defs.h7
-rw-r--r--source/slang/slang-ir-insts.h52
-rw-r--r--source/slang/slang-ir-peephole.cpp190
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp1
-rw-r--r--source/slang/slang-ir.cpp85
15 files changed, 351 insertions, 323 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index ffb469b9d..160585e26 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1087,7 +1087,6 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
// Never fold these, because their result cannot be computed
// as a sub-expression (they must be emitted as a declaration
// or statement).
- case kIROp_UpdateField:
case kIROp_UpdateElement:
case kIROp_DefaultConstruct:
return false;
@@ -2487,43 +2486,45 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst)
auto ii = (IRUpdateElement*)inst;
auto subscriptOuter = getInfo(EmitOp::General);
auto subscriptPrec = getInfo(EmitOp::Postfix);
- auto arraySize = as<IRIntLit>(as<IRArrayType>(inst->getDataType())->getElementCount());
- SLANG_RELEASE_ASSERT(arraySize);
emitInstResultDecl(inst);
- m_writer->emit("{");
- for (UInt i = 0; i < (UInt)arraySize->getValue(); i++)
+ if (auto arrayType = as<IRArrayType>(inst->getDataType()))
{
- if (i > 0)
- m_writer->emit(", ");
- emitOperand(ii->getOldValue(), leftSide(subscriptOuter, subscriptPrec));
- m_writer->emit("[");
- m_writer->emit(i);
- m_writer->emit("]");
+ auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+ SLANG_RELEASE_ASSERT(arraySize);
+ m_writer->emit("{");
+ for (UInt i = 0; i < (UInt)arraySize->getValue(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ emitOperand(ii->getOldValue(), leftSide(subscriptOuter, subscriptPrec));
+ m_writer->emit("[");
+ m_writer->emit(i);
+ m_writer->emit("]");
+ }
+ m_writer->emit("}");
+ }
+ else
+ {
+ emitOperand(ii->getOldValue(), getInfo(EmitOp::General));
}
-
- m_writer->emit("}");
- m_writer->emit(";\n");
-
- emitOperand(ii, leftSide(subscriptOuter, subscriptPrec));
- m_writer->emit("[");
- emitOperand(ii->getIndex(), getInfo(EmitOp::General));
- m_writer->emit("] = ");
- emitOperand(ii->getElementValue(), getInfo(EmitOp::General));
- m_writer->emit(";\n");
- }
- break;
- case kIROp_UpdateField:
- {
- auto ii = (IRUpdateField*)inst;
- emitInstResultDecl(inst);
- emitOperand(ii->getOldValue(), getInfo(EmitOp::General));
m_writer->emit(";\n");
- auto subscriptOuter = getInfo(EmitOp::General);
- auto subscriptPrec = getInfo(EmitOp::Postfix);
emitOperand(ii, leftSide(subscriptOuter, subscriptPrec));
- m_writer->emit(".");
- m_writer->emit(getName(ii->getFieldKey()));
+ for (UInt i = 0; i < ii->getAccessKeyCount(); i++)
+ {
+ auto key = ii->getAccessKey(i);
+ if (as<IRStructKey>(key))
+ {
+ m_writer->emit(".");
+ m_writer->emit(getName(key));
+ }
+ else
+ {
+ m_writer->emit("[");
+ emitOperand(key, getInfo(EmitOp::General));
+ m_writer->emit("]");
+ }
+ }
m_writer->emit(" = ");
emitOperand(ii->getElementValue(), getInfo(EmitOp::General));
m_writer->emit(";\n");
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 877be1406..a5e0e0a4e 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -36,72 +36,29 @@ struct AddressInstEliminationContext
void storeValue(IRBuilder& builder, IRInst* addr, IRInst* val)
{
- List<IRInst*> baseAddrs;
+ List<IRInst*> accessChain;
for (auto inst = addr; inst;)
{
switch (inst->getOp())
{
default:
- baseAddrs.add(inst);
+ accessChain.add(inst);
goto endLoop;
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
- baseAddrs.add(inst);
+ accessChain.add(inst->getOperand(1));
inst = inst->getOperand(0);
break;
}
}
endLoop:;
- List<IRInst*> values;
- values.setCount(baseAddrs.getCount());
- if (values.getCount() > 1)
- {
- IRInst* currentVal = builder.emitLoad(baseAddrs.getLast());
- values.getLast() = currentVal;
- for (Index i = baseAddrs.getCount() - 2; i >= 1; i--)
- {
- auto inst = baseAddrs[i];
- switch (inst->getOp())
- {
- default:
- sink->diagnose(inst->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff);
- return;
- case kIROp_GetElementPtr:
- case kIROp_FieldAddress:
- {
- IRInst* args[] = { currentVal, inst->getOperand(1) };
- currentVal = builder.emitIntrinsicInst(
- cast<IRPtrTypeBase>(inst->getFullType())->getValueType(),
- (inst->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract),
- 2,
- args);
- values[i] = currentVal;
- }
- break;
- }
- }
- }
- values[0] = val;
- for (Index i = 1; i < values.getCount(); i++)
- {
- auto inst = baseAddrs[i - 1];
- switch (inst->getOp())
- {
- case kIROp_GetElementPtr:
- case kIROp_FieldAddress:
- {
- IRInst* args[] = {values[i], inst->getOperand(1), values[i - 1]};
- values[i] = builder.emitIntrinsicInst(
- values[i]->getFullType(),
- (inst->getOp() == kIROp_GetElementPtr ? kIROp_UpdateElement : kIROp_UpdateField),
- 3,
- args);
- }
- break;
- }
- }
- builder.emitStore(baseAddrs.getLast(), values.getLast());
+ auto lastAddr = accessChain.getLast();
+ auto lastVal = builder.emitLoad(lastAddr);
+ accessChain.removeLast();
+ accessChain.reverse();
+ auto update = builder.emitUpdateElement(lastVal, accessChain, val);
+ builder.emitStore(lastAddr, update);
}
void transformLoadAddr(IRUse* use)
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index f5fa17fae..58c8aae93 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -760,71 +760,42 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst
return InstPair(primalGetElementPtr, diffGetElementPtr);
}
-InstPair ForwardDiffTranscriber::transcribeUpdateField(IRBuilder* builder, IRInst* originalInst)
+InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst)
{
- auto updateInst = as<IRUpdateField>(originalInst);
+ auto updateInst = as<IRUpdateElement>(originalInst);
IRInst* origBase = updateInst->getOldValue();
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto field = updateInst->getFieldKey();
- auto primalVal = findOrTranscribePrimalInst(builder, updateInst->getElementValue());
- auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
- auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
-
- IRInst* primalOperands[] = { primalBase, field, primalVal };
- IRInst* primalUpdateField = builder->emitIntrinsicInst(
- primalType,
- originalInst->getOp(),
- 3,
- primalOperands);
-
- if (!derivativeRefDecor)
+ List<IRInst*> primalAccessChain;
+ for (UInt i = 0; i < updateInst->getAccessKeyCount(); i++)
{
- return InstPair(primalUpdateField, nullptr);
+ auto originalKey = updateInst->getAccessKey(i);
+ auto primalKey = findOrTranscribePrimalInst(builder, originalKey);
+ primalAccessChain.add(primalKey);
}
+ auto origVal = updateInst->getElementValue();
+ auto primalVal = findOrTranscribePrimalInst(builder, origVal);
- IRInst* diffUpdateField = nullptr;
+ IRInst* primalUpdateField =
+ builder->emitUpdateElement(primalBase, primalAccessChain, primalVal);
- if (auto diffType = differentiateType(builder, originalInst->getDataType()))
+ IRInst* diffUpdateElement = nullptr;
+ List<IRInst*> diffAccessChain;
+ for (auto key : primalAccessChain)
{
- if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ if (as<IRStructKey>(key))
{
- if (auto diffVal = findOrTranscribeDiffInst(builder, updateInst->getElementValue()))
- {
- auto primalElementType = primalVal->getDataType();
-
- IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey(), diffVal, primalElementType };
- diffUpdateField = builder->emitIntrinsicInst(
- diffType,
- originalInst->getOp(),
- 4,
- diffOperands);
- }
+ auto decor = key->findDecoration<IRDerivativeMemberDecoration>();
+ if (decor)
+ diffAccessChain.add(decor->getDerivativeMemberStructKey());
+ else
+ return InstPair(primalUpdateField, nullptr);
+ }
+ else
+ {
+ diffAccessChain.add(key);
}
}
- return InstPair(primalUpdateField, diffUpdateField);
-}
-
-InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst)
-{
- auto updateInst = as<IRUpdateElement>(originalInst);
-
- IRInst* origBase = updateInst->getOldValue();
- auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto primalIndex = findOrTranscribePrimalInst(builder, updateInst->getIndex());
- auto origVal = updateInst->getElementValue();
- auto primalVal = findOrTranscribePrimalInst(builder, origVal);
- auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
-
- IRInst* primalOperands[] = { primalBase, primalIndex, primalVal };
- IRInst* primalUpdateField = builder->emitIntrinsicInst(
- primalType,
- originalInst->getOp(),
- 3,
- primalOperands);
-
- IRInst* diffUpdateElement = nullptr;
-
if (auto diffType = differentiateType(builder, originalInst->getDataType()))
{
if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
@@ -833,12 +804,9 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
{
auto primalElementType = primalVal->getDataType();
- IRInst* diffOperands[] = { diffBase, primalIndex, diffVal, primalElementType };
- diffUpdateElement = builder->emitIntrinsicInst(
- diffType,
- originalInst->getOp(),
- 4,
- diffOperands);
+ diffUpdateElement = builder->emitUpdateElement(
+ diffBase, diffAccessChain, diffVal);
+ builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
}
}
}
@@ -1249,8 +1217,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return transcribeByPassthrough(builder, origInst);
case kIROp_UpdateElement:
return transcribeUpdateElement(builder, origInst);
- case kIROp_UpdateField:
- return transcribeUpdateField(builder, origInst);
case kIROp_unconditionalBranch:
return transcribeControlFlow(builder, origInst);
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 53577f40e..e595191a3 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -63,8 +63,6 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr);
- InstPair transcribeUpdateField(IRBuilder* builder, IRInst* originalInst);
-
InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst);
InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 000921c7e..fce2043eb 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -7,6 +7,9 @@
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-addr-inst-elimination.h"
+#include "slang-ir-eliminate-multilevel-break.h"
namespace Slang
{
@@ -483,6 +486,39 @@ namespace Slang
builder.emitBranch(firstBlock);
}
+ struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
+ {
+ DifferentiableTypeConformanceContext* diffTypeContext;
+
+ virtual bool shouldConvertAddrInst(IRInst* addrInst) override
+ {
+ if (isDifferentiableType(*diffTypeContext, addrInst->getDataType()))
+ return true;
+ return false;
+ }
+ };
+
+ SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
+ {
+ DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
+ diffTypeContext.setFunc(func);
+
+ if (!isSingleReturnFunc(func))
+ {
+ convertFuncToSingleReturnForm(func->getModule(), func);
+ }
+ eliminateMultiLevelBreakForFunc(func->getModule(), func);
+
+ AutoDiffAddressConversionPolicy cvtPolicty;
+ cvtPolicty.diffTypeContext = &diffTypeContext;
+ auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
+ if (SLANG_SUCCEEDED(result))
+ {
+ simplifyFunc(func);
+ }
+ return result;
+ }
+
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
@@ -501,8 +537,10 @@ namespace Slang
stripDerivativeDecorations(primalFunc);
eliminateDeadCode(primalOuterParent);
- // Perform simplification.
- simplifyFunc(primalFunc);
+ // Perform required transformations and simplifications on the original func to make it
+ // reversible.
+ if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
+ return diffPropagateFunc;
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 7aa6c2441..f789089b0 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -85,6 +85,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
+ SlangResult prepareFuncForBackwardDiff(IRFunc* func);
+
IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc);
void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 51dcd9f45..0d45c6a84 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1065,9 +1065,6 @@ struct DiffTransposePass
case kIROp_UpdateElement:
return transposeUpdateElement(builder, fwdInst, revValue);
- case kIROp_UpdateField:
- return transposeUpdateField(builder, fwdInst, revValue);
-
case kIROp_Specialize:
case kIROp_unconditionalBranch:
case kIROp_conditionalBranch:
@@ -1312,20 +1309,22 @@ struct DiffTransposePass
auto updateInst = as<IRUpdateElement>(fwdUpdate);
List<RevGradient> gradients;
- auto arrayType = cast<IRArrayType>(fwdUpdate->getFullType());
- auto revElement = builder->emitElementExtract(arrayType->getElementType(), revValue, updateInst->getIndex());
+ auto accessChain = updateInst->getAccessChain();
+ auto revElement = builder->emitElementExtract(revValue, accessChain.getArrayView());
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
updateInst->getElementValue(),
revElement,
fwdUpdate));
- auto primalElementType = updateInst->getPrimalElementType();
- auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType);
+ auto primalElementTypeDecor = updateInst->findDecoration<IRPrimalElementTypeDecoration>();
+ SLANG_RELEASE_ASSERT(primalElementTypeDecor);
+
+ auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementTypeDecor->getPrimalElementType());
SLANG_ASSERT(diffZero);
auto revRest = builder->emitUpdateElement(
revValue,
- updateInst->getIndex(),
+ accessChain,
diffZero);
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
@@ -1336,35 +1335,6 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
- TranspositionResult transposeUpdateField(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue)
- {
- auto updateInst = as<IRUpdateField>(fwdUpdate);
-
- List<RevGradient> gradients;
- IRType* fieldType = updateInst->getElementValue()->getFullType();
- auto revElement = builder->emitFieldExtract(fieldType, revValue, updateInst->getFieldKey());
- gradients.add(RevGradient(
- RevGradient::Flavor::Simple,
- updateInst->getElementValue(),
- revElement,
- fwdUpdate));
-
- auto primalElementType = updateInst->getPrimalElementType();
- auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType);
- SLANG_ASSERT(diffZero);
- auto revRest = builder->emitUpdateField(
- revValue,
- updateInst->getFieldKey(),
- diffZero);
- gradients.add(RevGradient(
- RevGradient::Flavor::Simple,
- updateInst->getOldValue(),
- revRest,
- fwdUpdate));
- // (A = UpdateField(s, fieldKey, V)) -> [(dV += dA.fieldKey, d_s += UpdateField(revValue, fieldKey, 0)]
- return TranspositionResult(gradients);
- }
-
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 8d9a01b75..44cb2aa09 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -196,7 +196,6 @@ struct ExtractPrimalFuncContext
case kIROp_FieldExtract:
case kIROp_swizzle:
case kIROp_UpdateElement:
- case kIROp_UpdateField:
case kIROp_OptionalHasValue:
case kIROp_GetOptionalValue:
case kIROp_MatrixReshape:
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 8cefa6a04..ce3e563f5 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -2,8 +2,6 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-inst-pass-base.h"
-#include "slang-ir-single-return.h"
-#include "slang-ir-addr-inst-elimination.h"
namespace Slang
{
@@ -177,29 +175,6 @@ public:
return false;
}
- struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
- {
- DifferentiableTypeConformanceContext* diffTypeContext;
-
- virtual bool shouldConvertAddrInst(IRInst* addrInst) override
- {
- if (isDifferentiableType(*diffTypeContext, addrInst->getDataType()))
- return true;
- return false;
- }
- };
-
- SlangResult prepareFuncForAutoDiff(DifferentiableTypeConformanceContext& diffTypeContext, IRFunc* func)
- {
- if (!isSingleReturnFunc(func))
- {
- convertFuncToSingleReturnForm(func->getModule(), func);
- }
- AutoDiffAddressConversionPolicy cvtPolicty;
- cvtPolicty.diffTypeContext = &diffTypeContext;
- return eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
- }
-
void processFunc(IRGlobalValueWithCode* funcInst)
{
if (!_isFuncMarkedForAutoDiff(funcInst))
@@ -209,14 +184,6 @@ public:
DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
diffTypeContext.setFunc(funcInst);
- if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
- {
- if (auto func = as<IRFunc>(funcInst))
- {
- if (SLANG_FAILED(prepareFuncForAutoDiff(diffTypeContext, func)))
- return;
- }
- }
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index a8cdb5cca..2bbb9618c 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -119,7 +119,6 @@ bool opCanBeConstExpr(IROp op)
case kIROp_swizzle:
case kIROp_GetElement:
case kIROp_FieldExtract:
- case kIROp_UpdateField:
case kIROp_UpdateElement:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 817edaa83..6b6b3924a 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -311,8 +311,7 @@ INST(Call, call, 1, 0)
INST(RTTIObject, rtti_object, 0, 0)
INST(Alloca, alloca, 1, 0)
-INST(UpdateElement, updateElement, 3, 0)
-INST(UpdateField, updateField, 3, 0)
+INST(UpdateElement, updateElement, 2, 0)
INST(PackAnyValue, packAnyValue, 1, 0)
INST(UnpackAnyValue, unpackAnyValue, 1, 0)
@@ -762,6 +761,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// in an intermediary struct for reuse in backward propagation phase.
INST(PrimalValueStructKeyDecoration, primalValueKey, 1, 0)
+ /// Used by the auto-diff pass to mark the primal element type of an
+ /// forward-differentiated updateElement inst.
+ INST(PrimalElementTypeDecoration, primalElementType, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 10c490f3c..405df4073 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -689,6 +689,18 @@ struct IRPrimalValueStructKeyDecoration : IRDecoration
IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); }
};
+struct IRPrimalElementTypeDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_PrimalElementTypeDecoration
+ };
+
+ IR_LEAF_ISA(PrimalElementTypeDecoration)
+
+ IRInst* getPrimalElementType() { return getOperand(0); }
+};
+
struct IRMixedDifferentialInstDecoration : IRDecoration
{
enum
@@ -2170,28 +2182,15 @@ struct IRUpdateElement : IRInst
IR_LEAF_ISA(UpdateElement)
IRInst* getOldValue() { return getOperand(0); }
- IRInst* getIndex() { return getOperand(1); }
- IRInst* getElementValue() { return getOperand(2); }
- IRInst* getPrimalElementType()
- {
- if (getOperandCount() != 4)
- return nullptr;
- return getOperand(3);
- }
-};
-
-struct IRUpdateField : IRInst
-{
- IR_LEAF_ISA(UpdateField)
-
- IRInst* getOldValue() { return getOperand(0); }
- IRInst* getFieldKey() { return getOperand(1); }
- IRInst* getElementValue() { return getOperand(2); }
- IRInst* getPrimalElementType()
+ IRInst* getElementValue() { return getOperand(1); }
+ IRInst* getAccessKey(UInt index) { return getOperand(2 + index); }
+ UInt getAccessKeyCount() { return getOperandCount() - 2; }
+ List<IRInst*> getAccessChain()
{
- if (getOperandCount() != 4)
- return nullptr;
- return getOperand(3);
+ List<IRInst*> result;
+ for (UInt i = 0; i < getAccessKeyCount(); i++)
+ result.add(getAccessKey(i));
+ return result;
}
};
@@ -2798,6 +2797,7 @@ public:
IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target);
IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key);
+ IRInst* addPrimalElementTypeDecoration(IRInst* target, IRInst* type);
// Add a differentiable type entry to the appropriate dictionary.
IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness);
@@ -3148,13 +3148,21 @@ public:
IRInst* base,
IRInst* index);
+ IRInst* emitElementExtract(
+ IRInst* base,
+ IRInst* index);
+
+ IRInst* emitElementExtract(
+ IRInst* base,
+ const ArrayView<IRInst*>& accessChain);
+
IRInst* emitElementAddress(
IRType* type,
IRInst* basePtr,
IRInst* index);
IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement);
- IRInst* emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal);
+ IRInst* emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement);
IRInst* emitGetAddress(
IRType* type,
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 16f6cd9b9..fd0b4577a 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -11,6 +11,91 @@ struct PeepholeContext : InstPassBase
bool changed = false;
+ bool tryFoldElementExtractFromUpdateInst(IRInst* inst)
+ {
+ bool isAccessChainEqual = false;
+ bool isAccessChainNotEqual = false;
+ List<IRInst*> chainKey;
+ IRInst* chainNode = inst;
+ for (;;)
+ {
+ switch (chainNode->getOp())
+ {
+ case kIROp_FieldExtract:
+ case kIROp_GetElement:
+ chainKey.add(chainNode->getOperand(1));
+ chainNode = chainNode->getOperand(0);
+ continue;
+ }
+ break;
+ }
+ chainKey.reverse();
+ if (auto updateInst = as<IRUpdateElement>(chainNode))
+ {
+ if (updateInst->getAccessKeyCount() > (UInt)chainKey.getCount())
+ return false;
+
+ isAccessChainEqual = true;
+ for (UInt i = 0; i < (UInt)chainKey.getCount(); i++)
+ {
+ if (updateInst->getAccessKey(i) != chainKey[i])
+ {
+ isAccessChainEqual = false;
+ if (as<IRStructKey>(chainKey[i]))
+ {
+ isAccessChainNotEqual = true;
+ break;
+ }
+ else
+ {
+ if (auto constIndex1 = as<IRIntLit>(updateInst->getAccessKey(i)))
+ {
+ if (auto constIndex2 = as<IRIntLit>(chainKey[i]))
+ {
+ if (constIndex1->getValue() != constIndex2->getValue())
+ {
+ isAccessChainNotEqual = true;
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+ if (isAccessChainEqual)
+ {
+ auto remainingKeys = chainKey.getArrayView(
+ updateInst->getAccessKeyCount(),
+ chainKey.getCount() - updateInst->getAccessKeyCount());
+ if (remainingKeys.getCount() == 0)
+ {
+ inst->replaceUsesWith(updateInst->getElementValue());
+ inst->removeAndDeallocate();
+ return true;
+ }
+ else if (remainingKeys.getCount() > 0)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto newValue = builder.emitElementExtract(updateInst->getElementValue(), remainingKeys);
+ inst->replaceUsesWith(newValue);
+ inst->removeAndDeallocate();
+ return true;
+ }
+ }
+ else if (isAccessChainNotEqual)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ return true;
+ }
+ }
+ return false;
+ }
+
void processInst(IRInst* inst)
{
switch (inst->getOp())
@@ -84,19 +169,9 @@ struct PeepholeContext : InstPassBase
}
}
}
- else if (auto updateField = as<IRUpdateField>(inst->getOperand(0)))
+ else
{
- if (inst->getOperand(1) == updateField->getFieldKey())
- {
- inst->replaceUsesWith(updateField->getElementValue());
- inst->removeAndDeallocate();
- changed = true;
- }
- else
- {
- inst->setOperand(0, updateField->getOldValue());
- changed = true;
- }
+ changed = tryFoldElementExtractFromUpdateInst(inst);
}
break;
case kIROp_GetElement:
@@ -119,32 +194,18 @@ struct PeepholeContext : InstPassBase
inst->removeAndDeallocate();
changed = true;
}
- else if (auto updateElement = as<IRUpdateElement>(inst->getOperand(0)))
+ else
{
- if (inst->getOperand(1) == updateElement->getIndex())
- {
- inst->replaceUsesWith(updateElement->getElementValue());
- inst->removeAndDeallocate();
- changed = true;
- }
- else if (auto constIndex1 = as<IRIntLit>(inst->getOperand(1)))
- {
- if (auto constIndex2 = as<IRIntLit>(updateElement->getIndex()))
- {
- // If we can determine that the indices does not match,
- // then reduce the original value operand to before the update.
- if (constIndex1->getValue() != constIndex2->getValue())
- {
- inst->setOperand(0, updateElement->getOldValue());
- changed = true;
- }
- }
- }
+ changed = tryFoldElementExtractFromUpdateInst(inst);
}
break;
case kIROp_UpdateElement:
{
- if (auto constIndex = as<IRIntLit>(inst->getOperand(1)))
+ auto updateInst = as<IRUpdateElement>(inst);
+ if (updateInst->getAccessKeyCount() != 1)
+ break;
+ auto key = updateInst->getAccessKey(0);
+ if (auto constIndex = as<IRIntLit>(key))
{
auto oldVal = inst->getOperand(0);
if (oldVal->getOp() == kIROp_MakeArray ||
@@ -179,44 +240,43 @@ struct PeepholeContext : InstPassBase
}
}
}
- }
- break;
- case kIROp_UpdateField:
- {
- auto oldVal = inst->getOperand(0);
- if (oldVal->getOp() == kIROp_MakeStruct)
+ else if (auto structKey = as<IRStructKey>(key))
{
- auto structType = as<IRStructType>(inst->getDataType());
- if (!structType) break;
- List<IRInst*> args;
- UInt i = 0;
- bool isValid = true;
- for (auto field : structType->getFields())
+ auto oldVal = inst->getOperand(0);
+ if (oldVal->getOp() == kIROp_MakeStruct)
{
- IRInst* arg = nullptr;
- if (i < oldVal->getOperandCount())
- arg = oldVal->getOperand(i);
- if (field->getKey() == inst->getOperand(1))
- arg = inst->getOperand(2);
- if (arg)
+ auto structType = as<IRStructType>(inst->getDataType());
+ if (!structType) break;
+ List<IRInst*> args;
+ UInt i = 0;
+ bool isValid = true;
+ for (auto field : structType->getFields())
{
- args.add(arg);
+ IRInst* arg = nullptr;
+ if (i < oldVal->getOperandCount())
+ arg = oldVal->getOperand(i);
+ if (field->getKey() == inst->getOperand(1))
+ arg = inst->getOperand(2);
+ if (arg)
+ {
+ args.add(arg);
+ }
+ else
+ {
+ isValid = false;
+ break;
+ }
+ i++;
}
- else
+ if (isValid)
{
- isValid = false;
- break;
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ changed = true;
}
- i++;
- }
- if (isValid)
- {
- IRBuilder builder(&sharedBuilderStorage);
- builder.setInsertBefore(inst);
- auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
- inst->replaceUsesWith(makeStruct);
- inst->removeAndDeallocate();
- changed = true;
}
}
}
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index bcf0907df..9bd681115 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -38,7 +38,6 @@ struct RedundancyRemovalContext
case kIROp_GetElement:
case kIROp_GetElementPtr:
case kIROp_UpdateElement:
- case kIROp_UpdateField:
case kIROp_LookupWitness:
case kIROp_Specialize:
case kIROp_OptionalHasValue:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2960d942c..0434ff682 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3872,6 +3872,11 @@ namespace Slang
return addDecoration(target, kIROp_PrimalValueStructKeyDecoration, key);
}
+ IRInst* IRBuilder::addPrimalElementTypeDecoration(IRInst* target, IRInst* type)
+ {
+ return addDecoration(target, kIROp_PrimalElementTypeDecoration, type);
+ }
+
RefPtr<IRModule> IRModule::create(Session* session)
{
RefPtr<IRModule> module = new IRModule(session);
@@ -4355,6 +4360,65 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitElementExtract(
+ IRInst* base,
+ IRInst* index)
+ {
+ IRType* type = nullptr;
+ if (auto arrayType = as<IRArrayType>(base->getDataType()))
+ {
+ type = arrayType->getElementType();
+ }
+ else if (auto vectorType = as<IRVectorType>(base->getDataType()))
+ {
+ type = vectorType->getElementType();
+ }
+ else if (auto matrixType = as<IRMatrixType>(base->getDataType()))
+ {
+ type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
+ }
+ SLANG_RELEASE_ASSERT(type);
+ auto inst = createInst<IRFieldAddress>(
+ this,
+ kIROp_GetElement,
+ type,
+ base,
+ index);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitElementExtract(
+ IRInst* base,
+ const ArrayView<IRInst*>& accessChain)
+ {
+ for (auto access : accessChain)
+ {
+ IRType* resultType = nullptr;
+ if (auto structKey = as<IRStructKey>(access))
+ {
+ auto structType = as<IRStructType>(base->getDataType());
+ SLANG_RELEASE_ASSERT(structType);
+ for (auto field : structType->getFields())
+ {
+ if (field->getKey() == structKey)
+ {
+ resultType = field->getFieldType();
+ break;
+ }
+ }
+ SLANG_RELEASE_ASSERT(resultType);
+ base = emitFieldExtract(resultType, base, structKey);
+ }
+ else
+ {
+ base = emitElementExtract(base, access);
+ }
+ }
+ return base;
+ }
+
IRInst* IRBuilder::emitElementAddress(
IRType* type,
IRInst* basePtr,
@@ -4378,23 +4442,21 @@ namespace Slang
kIROp_UpdateElement,
base->getFullType(),
base,
- index,
- newElement);
+ newElement,
+ index);
addInst(inst);
return inst;
}
- IRInst* IRBuilder::emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal)
+ IRInst* IRBuilder::emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement)
{
- auto inst = createInst<IRUpdateField>(
- this,
- kIROp_UpdateField,
- base->getFullType(),
- base,
- fieldKey,
- newFieldVal);
-
+ List<IRInst*> args;
+ args.add(base);
+ args.add(newElement);
+ args.addRange(accessChain);
+ auto inst = createInst<IRUpdateElement>(
+ this, kIROp_UpdateElement, base->getFullType(), (Int)args.getCount(), args.getBuffer());
addInst(inst);
return inst;
}
@@ -6663,7 +6725,6 @@ namespace Slang
case kIROp_GetElement:
case kIROp_GetElementPtr:
case kIROp_UpdateElement:
- case kIROp_UpdateField:
case kIROp_MeshOutputRef:
case kIROp_MakeVectorFromScalar:
case kIROp_swizzle: