summary | refs | log | tree | commit | diff
path: root/source
diff options
context:
space:
mode:
author Yong He <yonghe@outlook.com> 2023-01-24 22:16:21 -0800
committer GitHub <noreply@github.com> 2023-01-24 22:16:21 -0800
commit 951ad25e0a9c3b0089c6b996b8e821ac93cf5766 (patch)
tree 7bed99484204611a4669d7c2c11019795e37f7cb /source
parent a3b0eff62e59f3a05461bf3edee5e100e804e4d5 (diff)
Reimplement address elimination. (#2605)
* Reimplement address elimination pass. * Fix error. * Update test references. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r-- source/slang/slang-ast-builder.h 34
-rw-r--r-- source/slang/slang-ast-synthesis.cpp 6
-rw-r--r-- source/slang/slang-check-decl.cpp 2
-rw-r--r-- source/slang/slang-check-expr.cpp 75
-rw-r--r-- source/slang/slang-check-impl.h 4
-rw-r--r-- source/slang/slang-emit-c-like.cpp 68
-rw-r--r-- source/slang/slang-ir-addr-inst-elimination.cpp 216
-rw-r--r-- source/slang/slang-ir-addr-inst-elimination.h 21
-rw-r--r-- source/slang/slang-ir-autodiff-addr-inst-elimination.cpp 476
-rw-r--r-- source/slang/slang-ir-autodiff-fwd.cpp 109
-rw-r--r-- source/slang/slang-ir-autodiff-fwd.h 6
-rw-r--r-- source/slang/slang-ir-autodiff-rev.cpp 10
-rw-r--r-- source/slang/slang-ir-autodiff-transcriber-base.cpp 14
-rw-r--r-- source/slang/slang-ir-autodiff-transpose.h 229
-rw-r--r-- source/slang/slang-ir-autodiff-unzip.cpp 3
-rw-r--r-- source/slang/slang-ir-autodiff.cpp 40
-rw-r--r-- source/slang/slang-ir-autodiff.h 34
-rw-r--r-- source/slang/slang-ir-check-differentiability.cpp 97
-rw-r--r-- source/slang/slang-ir-constexpr.cpp 3
-rw-r--r-- source/slang/slang-ir-inst-defs.h 4
-rw-r--r-- source/slang/slang-ir-insts.h 37
-rw-r--r-- source/slang/slang-ir-peephole.cpp 153
-rw-r--r-- source/slang/slang-ir-redundancy-removal.cpp 5
-rw-r--r-- source/slang/slang-ir-util.cpp 25
-rw-r--r-- source/slang/slang-ir-util.h 3
-rw-r--r-- source/slang/slang-ir.cpp 62
26 files changed, 1158 insertions, 578 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 72d8ec50a..bc95ed63d 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -419,6 +419,40 @@ protected:
};
+struct ValSet // Set of Val* compared through Val::equalsVal instead of raw pointer identity.
+{
+ struct ValItem // Key wrapper supplying the getHashCode/operator== pair used by HashSet<ValItem>.
+ {
+ Val* val = nullptr;
+ ValItem() = default;
+ ValItem(Val* v) : val(v) {}
+
+ HashCode getHashCode() const
+ {
+ return val ? val->getHashCode() : 0; // a null item hashes to 0
+ }
+ bool operator==(ValItem other) const
+ {
+ if (val == other.val)
+ return true;
+ // The pointers differ, so if either side is null the items cannot be
+ // equal; bail out here so equalsVal is never handed a null pointer.
+ if (!val || !other.val)
+ return false;
+ return val->equalsVal(other.val);
+ }
+ };
+ HashSet<ValItem> set;
+ bool add(Val* val) // returns whatever HashSet::Add reports for the wrapped item
+ {
+ return set.Add(ValItem(val));
+ }
+ bool contains(Val* val)
+ {
+ return set.Contains(ValItem(val));
+ }
+};
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp
index e5a9ff75f..bab651ed9 100644
--- a/source/slang/slang-ast-synthesis.cpp
+++ b/source/slang/slang-ast-synthesis.cpp
@@ -29,8 +29,11 @@ Expr* ASTSynthesizer::emitPostfixExpr(UnownedStringSlice operatorToken, Expr* ba
ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outIndexVar)
{
+ auto parentStmt = getCurrentScope().m_parentSeqStmt;
+ auto seqStmt = m_builder->create<SeqStmt>();
auto scopeDecl = pushVarScope()->containerDecl;
auto stmt = m_builder->create<ForStmt>();
+ stmt->statement = seqStmt;
stmt->scopeDecl = (ScopeDecl*)scopeDecl;
auto declStmt = emitVarDeclStmt(nullptr, m_namePool->getName("S_synth_loop_index"), initVal);
stmt->initialStatement = declStmt;
@@ -38,7 +41,8 @@ ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outInd
auto predicateExpr = emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal);
stmt->predicateExpression = predicateExpr;
stmt->sideEffectExpression = emitPrefixExpr(UnownedStringSlice("++"), emitVarExpr(outIndexVar));
- getCurrentScope().m_parentSeqStmt->stmts.add(stmt);
+ parentStmt->stmts.add(stmt);
+ m_scopeStack.getLast().m_parentSeqStmt = seqStmt;
return stmt;
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7b5f85b60..5add89312 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3308,7 +3308,6 @@ namespace Slang
synth.popScope();
if (!assignStmt)
return nullptr;
- forStmt->statement = assignStmt;
return forStmt;
}
@@ -4962,6 +4961,7 @@ namespace Slang
auto thisType = calcThisType(parentDeclRef);
maybeRegisterDifferentiableType(m_astBuilder, thisType);
}
+ completeDifferentiableTypeDictionary();
m_parentDifferentiableAttr = oldAttr;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 43124b535..2853c1eb9 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -922,7 +922,6 @@ namespace Slang
}
}
-
void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
{
if (!builder->isDifferentiableInterfaceAvailable())
@@ -962,6 +961,80 @@ namespace Slang
}
}
+ void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet)
+ { // Registers `type` when it has a conformance witness for the differentiable interface, then recurses into the types reachable from it.
+ if (workingSet.contains(type)) // already expanded during this traversal
+ return;
+
+ if (!builder->isDifferentiableInterfaceAvailable())
+ {
+ return;
+ }
+
+ if (!m_parentDifferentiableAttr)
+ {
+ return;
+ }
+
+ workingSet.add(type); // mark before recursing so a type that reaches itself is not expanded again
+
+ // Check for special cases such as PtrTypeBase<T> or Array<T>
+ // This could potentially be handled later by simply defining extensions
+ // for Ptr<T:IDifferentiable> etc..
+ //
+ if (auto ptrType = as<PtrTypeBase>(type)) // pointer-like type: only the pointee type is examined
+ {
+ maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet);
+ return;
+ }
+
+ if (auto arrayType = as<ArrayExpressionType>(type)) // array type: only its baseType is examined
+ {
+ maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet);
+ return;
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto subtypeWitness = as<SubtypeWitness>( // nothing is registered, and no member is visited, unless a witness is found
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
+ {
+ registerDifferentiableType((DeclRefType*)type, subtypeWitness);
+ if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ {
+ foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
+ {
+ auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); // NOTE(review): built from the InheritanceDecl itself with null substitutions — confirm this is the intended type
+ maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet);
+ });
+ foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member)
+ {
+ auto fieldType = getType(m_astBuilder, member); // recurse into the type of each variable member
+ maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet);
+ });
+ }
+ }
+ return;
+ }
+ }
+
+ void SemanticsVisitor::completeDifferentiableTypeDictionary()
+ { // Re-walks each aggregate type already keyed in the witness map through the recursive registration above.
+ ValSet workingSet; // shared by every root below, so a type is expanded at most once per call
+ for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness) // assumes m_parentDifferentiableAttr is non-null — TODO confirm at call sites
+ { // NOTE(review): the loop body can reach registerDifferentiableType — confirm it cannot insert into the map being iterated
+ if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>()) // keys that are not aggregate type decl-refs are skipped
+ {
+ maybeRegisterDifferentiableTypeRecursive(
+ m_astBuilder,
+ m_astBuilder->getOrCreateDeclRefType(
+ aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions),
+ workingSet);
+ }
+ }
+ }
+
+
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
auto checkedTerm = _CheckTerm(term);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 6099febb5..1b59094e2 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -298,7 +298,7 @@ namespace Slang
bool isDifferentiableFunc(FunctionDeclBase* func);
bool isBackwardDifferentiableFunc(FunctionDeclBase* func);
FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func);
-
+
private:
/// Mapping from type declarations to the known extensiosn that apply to them
Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> m_mapTypeDeclToCandidateExtensions;
@@ -760,6 +760,8 @@ namespace Slang
/// describing the relationship.
///
void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness);
+ void maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet);
+ void completeDifferentiableTypeDictionary();
// Construct the differential for 'type', if it exists.
Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index fbcd0dbbd..ffb469b9d 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1087,6 +1087,8 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
// Never fold these, because their result cannot be computed
// as a sub-expression (they must be emitted as a declaration
// or statement).
+ case kIROp_UpdateField:
+ case kIROp_UpdateElement:
case kIROp_DefaultConstruct:
return false;
@@ -1135,6 +1137,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_swizzleSet:
+ case kIROp_MakeArrayFromElement:
return false;
}
@@ -2190,7 +2193,24 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
m_writer->emit(" }");
}
break;
+ case kIROp_MakeArrayFromElement:
+ {
+ // TODO: initializer-list syntax may not always
+ // be appropriate, depending on the context
+ // of the expression.
+ m_writer->emit("{ ");
+ UInt argCount =
+ (UInt)cast<IRIntLit>(cast<IRArrayType>(inst->getDataType())->getElementCount())
+ ->getValue();
+ for (UInt aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) m_writer->emit(", ");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ }
+ m_writer->emit(" }");
+ }
+ break;
case kIROp_BitCast:
{
// Note: we are currently emitting casts as plain old
@@ -2461,6 +2481,54 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst)
m_writer->emit(";\n");
}
break;
+
+ case kIROp_UpdateElement:
+ {
+ auto ii = (IRUpdateElement*)inst;
+ auto subscriptOuter = getInfo(EmitOp::General);
+ auto subscriptPrec = getInfo(EmitOp::Postfix);
+ auto arraySize = as<IRIntLit>(as<IRArrayType>(inst->getDataType())->getElementCount());
+ SLANG_RELEASE_ASSERT(arraySize);
+ emitInstResultDecl(inst);
+ m_writer->emit("{");
+ for (UInt i = 0; i < (UInt)arraySize->getValue(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ emitOperand(ii->getOldValue(), leftSide(subscriptOuter, subscriptPrec));
+ m_writer->emit("[");
+ m_writer->emit(i);
+ m_writer->emit("]");
+ }
+
+ m_writer->emit("}");
+ m_writer->emit(";\n");
+
+ emitOperand(ii, leftSide(subscriptOuter, subscriptPrec));
+ m_writer->emit("[");
+ emitOperand(ii->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit("] = ");
+ emitOperand(ii->getElementValue(), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ }
+ break;
+ case kIROp_UpdateField:
+ {
+ auto ii = (IRUpdateField*)inst;
+ emitInstResultDecl(inst);
+ emitOperand(ii->getOldValue(), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+
+ auto subscriptOuter = getInfo(EmitOp::General);
+ auto subscriptPrec = getInfo(EmitOp::Postfix);
+ emitOperand(ii, leftSide(subscriptOuter, subscriptPrec));
+ m_writer->emit(".");
+ m_writer->emit(getName(ii->getFieldKey()));
+ m_writer->emit(" = ");
+ emitOperand(ii->getElementValue(), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ }
+ break;
}
}
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
new file mode 100644
index 000000000..877be1406
--- /dev/null
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -0,0 +1,216 @@
+#include "slang-ir-addr-inst-elimination.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+// Rewrites address load/store into value extract/updates to allow SSA transform to apply to struct and array elements.
+// For example,
+// load(elementPtr(arr, 1)) ==> elementExtract(load(arr), 1)
+// store(fieldAddr(s, field_key), val) ==> store(s, updateField(load(s), fieldKey, val))
+// After this transform, all address operands of `load` and `store` insts will be either a var or a param.
+
+struct AddressInstEliminationContext
+{
+ SharedIRBuilder* sharedBuilder;
+ DiagnosticSink* sink;
+ // Builds the value stored at `addr`: one load of the root address, then one extract per GetElementPtr/FieldAddress link.
+ IRInst* getValue(IRBuilder& builder, IRInst* addr)
+ {
+ switch (addr->getOp())
+ {
+ default: // not an element/field address: treat it as the root and load it directly
+ return builder.emitLoad(addr);
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ IRInst* args[] = {getValue(builder, addr->getOperand(0)), addr->getOperand(1)}; // { value of the base, index or field key }
+ return builder.emitIntrinsicInst(
+ cast<IRPtrTypeBase>(addr->getFullType())->getValueType(),
+ (addr->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract),
+ 2,
+ args);
+ }
+ }
+ }
+ // Writes `val` through `addr`: loads the root once, rebuilds every enclosing aggregate with UpdateElement/UpdateField, then stores the root once.
+ void storeValue(IRBuilder& builder, IRInst* addr, IRInst* val)
+ {
+ List<IRInst*> baseAddrs; // access chain: [0] is `addr`, the last entry is the root address
+
+ for (auto inst = addr; inst;)
+ {
+ switch (inst->getOp())
+ {
+ default: // reached the root of the chain
+ baseAddrs.add(inst);
+ goto endLoop;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ baseAddrs.add(inst);
+ inst = inst->getOperand(0);
+ break;
+ }
+ }
+ endLoop:;
+ List<IRInst*> values; // values[i] is the value located at baseAddrs[i]
+ values.setCount(baseAddrs.getCount());
+ if (values.getCount() > 1)
+ {
+ IRInst* currentVal = builder.emitLoad(baseAddrs.getLast());
+ values.getLast() = currentVal;
+ for (Index i = baseAddrs.getCount() - 2; i >= 1; i--) // root towards leaf: extract the current value of each intermediate aggregate
+ {
+ auto inst = baseAddrs[i];
+ switch (inst->getOp())
+ {
+ default:
+ sink->diagnose(inst->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff);
+ return;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ IRInst* args[] = { currentVal, inst->getOperand(1) };
+ currentVal = builder.emitIntrinsicInst(
+ cast<IRPtrTypeBase>(inst->getFullType())->getValueType(),
+ (inst->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract),
+ 2,
+ args);
+ values[i] = currentVal;
+ }
+ break;
+ }
+ }
+ }
+ values[0] = val; // the innermost slot receives the value being stored
+ for (Index i = 1; i < values.getCount(); i++) // leaf towards root: fold each updated child into its parent value
+ {
+ auto inst = baseAddrs[i - 1];
+ switch (inst->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ IRInst* args[] = {values[i], inst->getOperand(1), values[i - 1]}; // { old aggregate, index or field key, new child value }
+ values[i] = builder.emitIntrinsicInst(
+ values[i]->getFullType(),
+ (inst->getOp() == kIROp_GetElementPtr ? kIROp_UpdateElement : kIROp_UpdateField),
+ 3,
+ args);
+ }
+ break;
+ }
+ }
+ builder.emitStore(baseAddrs.getLast(), values.getLast());
+ }
+ // Replaces the load that owns `use` with the value computed by getValue, then deletes the load.
+ void transformLoadAddr(IRUse* use)
+ {
+ auto addr = use->get();
+ auto load = as<IRLoad>(use->getUser());
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(use->getUser());
+ auto value = getValue(builder, addr);
+ load->replaceUsesWith(value);
+ load->removeAndDeallocate();
+ }
+ // Replaces the store that owns `use` with the sequence emitted by storeValue, then deletes the store.
+ void transformStoreAddr(IRUse* use)
+ {
+ auto addr = use->get();
+ auto store = as<IRStore>(use->getUser());
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(use->getUser());
+ storeValue(builder, addr, store->getVal());
+ store->removeAndDeallocate();
+ }
+ // Routes an address argument of a call through a temporary: copy in before the call, copy back through `addr` after it.
+ void transformCallAddr(IRUse* use)
+ {
+ auto addr = use->get();
+ auto call = as<IRCall>(use->getUser());
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(call);
+ auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
+ builder.emitStore(tempVar, getValue(builder, addr));
+ builder.setInsertAfter(call);
+ storeValue(builder, addr, builder.emitLoad(tempVar));
+ use->set(tempVar); // the call now receives the temporary instead of the original address
+ }
+ // Pass driver: gathers every pointer-typed inst in `func`, then rewrites the load/store/call uses of those `policy` accepts.
+ SlangResult eliminateAddressInstsImpl(
+ SharedIRBuilder* inSharedBuilder,
+ AddressConversionPolicy* policy,
+ IRFunc* func,
+ DiagnosticSink* inSink)
+ {
+ sharedBuilder = inSharedBuilder;
+ sink = inSink;
+
+ IRBuilder builder(sharedBuilder);
+
+ List<IRInst*> workList; // snapshot taken up front; the rewrites below add and remove insts
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (as<IRPtrTypeBase>(inst->getDataType()))
+ {
+ workList.add(inst);
+ }
+ }
+ }
+
+ for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++)
+ {
+ auto addrInst = workList[workListIndex];
+
+ if (!policy->shouldConvertAddrInst(addrInst))
+ continue;
+
+ for (auto use = addrInst->firstUse; use; )
+ {
+ // Advance the cursor before anything else: `continue` must never revisit the
+ // same use, and the transforms below may unlink or retarget `currentUse`.
+ auto currentUse = use;
+ use = use->nextUse;
+ if (as<IRDecoration>(currentUse->getUser()))
+ continue;
+ switch (currentUse->getUser()->getOp())
+ {
+ case kIROp_Load:
+ transformLoadAddr(currentUse);
+ break;
+ case kIROp_Store:
+ transformStoreAddr(currentUse);
+ break;
+ case kIROp_Call:
+ transformCallAddr(currentUse);
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress: // derived addresses are pointer-typed too, so they get their own work-list entry
+ break;
+ default:
+ sink->diagnose(currentUse->getUser()->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff);
+ break;
+ }
+ }
+ }
+
+ return SLANG_OK; // unsupported uses are reported through `sink`; the result code itself is always SLANG_OK
+ }
+};
+
+SlangResult eliminateAddressInsts(
+ SharedIRBuilder* sharedBuilder,
+ AddressConversionPolicy* policy,
+ IRFunc* func,
+ DiagnosticSink* sink)
+{
+ // Public entry point: run the pass through a context created just for this call.
+ return AddressInstEliminationContext().eliminateAddressInstsImpl(sharedBuilder, policy, func, sink);
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-addr-inst-elimination.h b/source/slang/slang-ir-addr-inst-elimination.h
new file mode 100644
index 000000000..53e5628f3
--- /dev/null
+++ b/source/slang/slang-ir-addr-inst-elimination.h
@@ -0,0 +1,21 @@
+// slang-ir-addr-inst-elimination.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct SharedIRBuilder;
+class DiagnosticSink;
+
+struct AddressConversionPolicy // Caller-supplied filter passed to eliminateAddressInsts.
+{
+ virtual bool shouldConvertAddrInst(IRInst* addrInst) = 0; // returning false makes the pass leave the uses of `addrInst` untouched
+};
+SlangResult eliminateAddressInsts(
+ SharedIRBuilder* sharedBuilder,
+ AddressConversionPolicy* policy,
+ IRFunc* func,
+ DiagnosticSink* sink);
+
+}
diff --git a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp
deleted file mode 100644
index c60995595..000000000
--- a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp
+++ /dev/null
@@ -1,476 +0,0 @@
-#include "slang-ir-address-analysis.h"
-#include "slang-ir-autodiff-fwd.h"
-#include "slang-ir-autodiff-pairs.h"
-#include "slang-ir-autodiff-rev.h"
-#include "slang-ir-autodiff.h"
-#include "slang-ir-single-return.h"
-#include "slang-ir-ssa-simplification.h"
-#include "slang-ir-validate.h"
-
-namespace Slang
-{
-bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
-
-struct AddressInstEliminationContext
-{
- OrderedDictionary<IRInst*, IRInst*> mapAddrInstToTempVar;
-
- IRInst* _reconstructStruct(
- IRBuilder& builder, IRStructType* type, IRInst* tempVar, List<AddressInfo*>& childAddrs)
- {
- List<IRInst*> args;
- IRInst* loadedTempVar = nullptr;
- for (auto child : type->getChildren())
- {
- if (auto field = as<IRStructField>(child))
- {
- IRInst* childVar = nullptr;
- for (auto subAddr : childAddrs)
- {
- auto fieldAddrInst = cast<IRFieldAddress>(subAddr->addrInst);
- if (fieldAddrInst->getField() == field->getKey())
- {
- mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar);
- break;
- }
- }
- if (childVar)
- {
- args.add(builder.emitLoad(childVar));
- }
- else
- {
- if (!loadedTempVar)
- loadedTempVar = builder.emitLoad(tempVar);
- args.add(builder.emitFieldExtract(
- field->getFieldType(), loadedTempVar, field->getKey()));
- }
- }
- }
- return builder.emitMakeStruct(type, args);
- }
-
- IRInst* _reconstructArray(
- IRBuilder& builder,
- IRArrayType* type,
- IRIntegerValue arraySize,
- IRInst* tempVar,
- List<AddressInfo*>& childAddrs)
- {
- IRInst* loadedTempVar = nullptr;
- List<IRInst*> args;
- for (IRIntegerValue index = 0; index < arraySize; index++)
- {
- IRInst* childVar = nullptr;
- for (auto subAddr : childAddrs)
- {
- auto elementPtrInst = cast<IRGetElementPtr>(subAddr->addrInst);
- auto elementIndex = as<IRIntLit>(elementPtrInst->getIndex());
- if (elementIndex && elementIndex->getValue() == index)
- {
- mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar);
- break;
- }
- }
- if (childVar)
- {
- args.add(builder.emitLoad(childVar));
- }
- else
- {
- if (!loadedTempVar)
- loadedTempVar = builder.emitLoad(tempVar);
- args.add(builder.emitElementExtract(
- type->getElementType(),
- loadedTempVar,
- builder.getIntValue(builder.getIntType(), index)));
- }
- }
- return builder.emitMakeArray(type, args.getCount(), args.getBuffer());
- }
-
- void updateChildTempVarRecursive(
- IRBuilder& builder,
- AddressInfo* addr,
- IRInst* val)
- {
- for (auto child : addr->children)
- {
- IRInst* childVar = nullptr;
- if (mapAddrInstToTempVar.TryGetValue(child->addrInst, childVar))
- {
- switch (child->addrInst->getOp())
- {
- case kIROp_FieldAddress:
- {
- auto subVal = builder.emitFieldExtract(
- cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(),
- val,
- child->addrInst->getOperand(1));
- builder.emitStore(childVar, subVal);
- updateChildTempVarRecursive(builder, child, subVal);
- }
- break;
- case kIROp_GetElementPtr:
- {
- auto subVal = builder.emitElementExtract(
- cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(),
- val,
- child->addrInst->getOperand(1));
- builder.emitStore(childVar, subVal);
- updateChildTempVarRecursive(builder, child, subVal);
- }
- break;
- default:
- {
- }
- break;
- }
- }
- }
- }
-
- IRInst* getLoadedValue(
- IRBuilder& builder,
- AddressInfo* addr,
- IRInst* tempVar)
- {
- if (addr->children.getCount())
- {
- // Reconstruct val.
- auto type =
- cast<IRPtrTypeBase>(unwrapAttributedType(tempVar->getFullType()))->getValueType();
- switch (type->getOp())
- {
- case kIROp_StructType:
- return _reconstructStruct(
- builder, as<IRStructType>(type), tempVar, addr->children);
- case kIROp_ArrayType:
- {
- auto arrayType = as<IRArrayType>(type);
- auto size = as<IRIntLit>(arrayType->getElementCount());
- if (!size || size->getValue() < 0)
- {
- // Unsupported array type.
- }
- else
- {
- return _reconstructArray(
- builder,
- arrayType,
- size->getValue(),
- tempVar,
- addr->children);
- }
- }
- break;
- default:
- // Unsupported address type.
- break;
- }
- }
- return builder.emitLoad(tempVar);
- };
-
- void updateParentTempVarRecursive(
- IRBuilder& builder,
- AddressInfo* addr)
- {
- for (auto parent = addr->parentAddress; parent; parent = parent->parentAddress)
- {
- IRInst* parentVar = nullptr;
- if (mapAddrInstToTempVar.TryGetValue(parent->addrInst, parentVar))
- {
- auto val = getLoadedValue(builder, parent, parentVar);
- builder.emitStore(parentVar, val);
- }
- }
- }
-
- String getAddrName(IRInst* addrInst)
- {
- StringBuilder sb;
- List<IRInst*> bases;
- bases.add(addrInst);
- for (; addrInst;)
- {
- if (auto fieldAddr = as<IRFieldAddress>(addrInst))
- bases.add(fieldAddr->getBase());
- else if (auto index = as<IRGetElementPtr>(addrInst))
- bases.add(index->getBase());
- else
- break;
- }
- for (Index i = bases.getCount() - 1; i >= 0; i--)
- {
- if (bases[i]->getOp() == kIROp_FieldAddress)
- {
- sb << ".";
- auto field = bases[i]->getOperand(1);
- auto nameDecor = field->findDecoration<IRNameHintDecoration>();
- sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>"));
- }
- else if (bases[i]->getOp() == kIROp_FieldAddress)
- {
- sb << "[";
- auto index = bases[i]->getOperand(1);
- auto nameDecor = index->findDecoration<IRNameHintDecoration>();
- if (nameDecor)
- {
- sb << nameDecor->getName();
- }
- else if (auto intLit = as<IRIntLit>(index))
- {
- sb << intLit->getValue();
- }
- else
- {
- sb << "...";
- }
- sb << "]";
- }
- else
- {
- auto nameDecor = bases[i]->findDecoration<IRNameHintDecoration>();
- sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>"));
- }
- }
- return sb.ProduceString();
- }
-
- SlangResult eliminateAddressInstsImpl(
- SharedIRBuilder* sharedBuilder,
- DifferentiableTypeConformanceContext& diffContext,
- IRFunc* func,
- DiagnosticSink* sink)
- {
- bool hasError = false;
-
- if (!isSingleReturnFunc(func))
- {
- convertFuncToSingleReturnForm(func->getModule(), func);
- }
-
- IRBuilder builder(sharedBuilder);
-
- auto dom = computeDominatorTree(func);
- auto addrUse = analyzeAddressUse(dom, func);
- List<AddressInfo*> workList;
- HashSet<AddressInfo*> workListSet;
-
- // Process leaf addresses first.
- for (auto addr : addrUse.addressInfos)
- {
- if (addr.Value->children.getCount() == 0)
- workList.add(addr.Value);
- }
-
- auto createTempVarForAddr = [&](IRInst* addrInst)
- {
- if (as<IRParam>(addrInst))
- builder.setInsertAfter(as<IRBlock>(addrInst->getParent())->getLastParam());
- else
- builder.setInsertAfter(addrInst);
- auto ptrType = as<IRPtrTypeBase>(addrInst->getFullType());
- SLANG_RELEASE_ASSERT(ptrType);
- auto tempVar = builder.emitVar(ptrType->getValueType());
- mapAddrInstToTempVar[addrInst] = tempVar;
- };
-
- // In the first pass, we create temp vars for addresses with non-trivial access pattern.
- for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++)
- {
- auto addr = workList[workListIndex];
-
- if (!isDifferentiableType(diffContext, addr->addrInst->getDataType()))
- continue;
-
- List<IRUse*> readUses, writeUses, callUses, subAddrUses, unknownUses;
-
- for (auto node = addr; node; node = node->parentAddress)
- {
- auto addrInst = node->addrInst;
-
- for (auto use = addrInst->firstUse; use; use = use->nextUse)
- {
- if (as<IRDecoration>(use->getUser()))
- continue;
- switch (use->getUser()->getOp())
- {
- case kIROp_Load:
- readUses.add(use);
- break;
- case kIROp_Store:
- writeUses.add(use);
- break;
- case kIROp_Call:
- callUses.add(use);
- break;
- case kIROp_GetElementPtr:
- case kIROp_FieldAddress:
- if (node == addr)
- subAddrUses.add(use);
- break;
- default:
- unknownUses.add(use);
- break;
- }
- }
- }
-
- if (unknownUses.getCount() != 0)
- {
- // Diagnose about unknown use.
- sink->diagnose(
- unknownUses.getFirst()->getUser(),
- Diagnostics::unsupportedUseOfLValueForAutoDiff);
- hasError = true;
- continue;
- }
-
- if (addr->isConstant)
- {
- // Otherwise, the address must be a constant, and we need to create a temp var for
- // it. The exception is when the variable is a temp var for a call.
- if (callUses.getCount() == 1 && writeUses.getCount() <= 1 &&
- readUses.getCount() <= 1)
- {
- if (writeUses.getCount() == 0)
- continue;
-
- // The uses must be in write->call->read order.
- auto callUse = callUses.getFirst();
- auto writeUse = writeUses.getFirst();
- auto readUse = readUses.getCount() ? readUses.getFirst() : writeUse;
- if (dom->dominates(writeUse->getUser(), callUse->getUser()) &&
- dom->dominates(callUse->getUser(), readUse->getUser()))
- {
- continue;
- }
- }
-
- // Create a temp var for the address and replace all uses of the address to the temp
- // var.
- createTempVarForAddr(addr->addrInst);
- }
- else
- {
- // This is a dynamic address. We can only allow at most one write access to it.
- bool hasNonTrivialAccess = false;
- if (readUses.getCount() + callUses.getCount() != 0 &&
- writeUses.getCount() + callUses.getCount() > 1)
- hasNonTrivialAccess = true;
-
- if (hasNonTrivialAccess)
- {
- // Mixed use of a non-constant address is unsupported right now.
- sink->diagnose(
- addr->addrInst,
- Diagnostics::cannotDifferentiateDynamicallyIndexedData,
- getAddrName(addr->addrInst));
- }
- }
- if (addr->parentAddress && workListSet.Add(addr->parentAddress))
- workList.add(addr->parentAddress);
- }
-
- if (hasError)
- return SLANG_FAIL;
-
- // Actually replace addresses with temp vars.
- for (auto addr : workList)
- {
- IRInst* tempVar = nullptr;
- if (!mapAddrInstToTempVar.TryGetValue(addr->addrInst, tempVar))
- continue;
- for (auto use = addr->addrInst->firstUse; use;)
- {
- auto nextUse = use->nextUse;
- auto user = use->getUser();
-
- builder.setInsertBefore(user);
- switch (user->getOp())
- {
- case kIROp_Load:
- use->set(tempVar);
- break;
- case kIROp_Store:
- use->set(tempVar);
- updateChildTempVarRecursive(
- builder, addr, as<IRStore>(user)->getVal());
- updateParentTempVarRecursive(builder, addr);
- case kIROp_Call:
- {
- use->set(tempVar);
- builder.setInsertAfter(user);
- auto newVal = builder.emitLoad(tempVar);
- updateChildTempVarRecursive(builder, addr, newVal);
- updateParentTempVarRecursive(builder, addr);
- }
- break;
- default:
- use->set(tempVar);
- break;
- }
- use = nextUse;
- }
- }
-
- // Assign initial values to tempVar.
- for (auto tempVar : mapAddrInstToTempVar)
- {
- builder.setInsertAfter(tempVar.Value);
- IRInst* initVal = nullptr;
- if (tempVar.Key->getOp() == kIROp_Var ||
- tempVar.Key->getOp() == kIROp_Param && as<IROutType>(tempVar.Key->getFullType()))
- {
- initVal = builder.emitDefaultConstruct(
- cast<IRPtrTypeBase>(tempVar.Key->getFullType())->getValueType());
- }
- else
- {
- initVal = builder.emitLoad(tempVar.Key);
- }
- builder.emitStore(tempVar.Value, initVal);
- }
-
- // Store final values to out parameters before exiting function.
- IRInst* returnInst = nullptr;
- for (auto block : func->getBlocks())
- {
- for (auto inst : block->getChildren())
- {
- if (inst->getOp() == kIROp_Return)
- {
- returnInst = inst;
- break;
- }
- }
- }
- SLANG_RELEASE_ASSERT(returnInst);
- builder.setInsertBefore(returnInst);
- for (auto param : func->getParams())
- {
- IRInst* tempVar = nullptr;
- if (mapAddrInstToTempVar.TryGetValue(param, tempVar))
- {
- auto val = builder.emitLoad(tempVar);
- builder.emitStore(param, val);
- }
- }
- if (hasError)
- return SLANG_FAIL;
- return SLANG_OK;
- }
-};
-
-SlangResult eliminateAddressInsts(
- SharedIRBuilder* sharedBuilder,
- DifferentiableTypeConformanceContext& diffContext,
- IRFunc* func,
- DiagnosticSink* sink)
-{
- AddressInstEliminationContext ctx;
- return ctx.eliminateAddressInstsImpl(sharedBuilder, diffContext, func, sink);
-}
-} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 3f3618b44..f5fa17fae 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -53,6 +53,21 @@ String ForwardDiffTranscriber::getJVPVarName(IRInst* origVar)
return String("");
}
+InstPair ForwardDiffTranscriber::transcribeUndefined(IRBuilder* builder, IRInst* origInst)
+{
+ // The primal half is always produced, whether or not a differential exists.
+ auto primalVal = maybeCloneForPrimalInst(builder, origInst);
+ auto origType = origInst->getFullType();
+
+ // No differential type for this value: hand back the primal half alone.
+ if (!differentiateType(builder, origType))
+ return InstPair(primalVal, nullptr);
+
+ // Pair the primal with the zero differential; a null zero leaves the diff half empty.
+ auto zeroDiff = getDifferentialZeroOfType(builder, origType);
+ return InstPair(primalVal, zeroDiff);
+}
+
InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
@@ -745,6 +760,91 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst
return InstPair(primalGetElementPtr, diffGetElementPtr);
}
+InstPair ForwardDiffTranscriber::transcribeUpdateField(IRBuilder* builder, IRInst* originalInst)
+{ // Emits the primal UpdateField and, when the field has a derivative member, the matching update on the differential value.
+ auto updateInst = as<IRUpdateField>(originalInst); // dereferenced below without a null check — assumes callers pass an UpdateField inst
+
+ IRInst* origBase = updateInst->getOldValue();
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto field = updateInst->getFieldKey();
+ auto primalVal = findOrTranscribePrimalInst(builder, updateInst->getElementValue());
+ auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); // supplies the struct key used for the differential update below
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
+
+ IRInst* primalOperands[] = { primalBase, field, primalVal };
+ IRInst* primalUpdateField = builder->emitIntrinsicInst(
+ primalType,
+ originalInst->getOp(),
+ 3,
+ primalOperands);
+
+ if (!derivativeRefDecor) // field key carries no derivative-member decoration: only the primal update is returned
+ {
+ return InstPair(primalUpdateField, nullptr);
+ }
+
+ IRInst* diffUpdateField = nullptr; // stays null unless the diff type, diff base and diff value are all available
+
+ if (auto diffType = differentiateType(builder, originalInst->getDataType()))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ if (auto diffVal = findOrTranscribeDiffInst(builder, updateInst->getElementValue()))
+ {
+ auto primalElementType = primalVal->getDataType();
+
+ IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey(), diffVal, primalElementType };
+ diffUpdateField = builder->emitIntrinsicInst( // NOTE(review): 4 operands here (primal element type appended) versus 3 on the primal inst — confirm consumers expect the extra operand
+ diffType,
+ originalInst->getOp(),
+ 4,
+ diffOperands);
+ }
+ }
+ }
+ return InstPair(primalUpdateField, diffUpdateField);
+}
+
+InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst)
+{ // Emits the primal UpdateElement and, when differentials are available, the same update on the differential value.
+ auto updateInst = as<IRUpdateElement>(originalInst); // dereferenced below without a null check — assumes callers pass an UpdateElement inst
+
+ IRInst* origBase = updateInst->getOldValue();
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalIndex = findOrTranscribePrimalInst(builder, updateInst->getIndex()); // the same primal index is reused for the differential update
+ auto origVal = updateInst->getElementValue();
+ auto primalVal = findOrTranscribePrimalInst(builder, origVal);
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
+
+ IRInst* primalOperands[] = { primalBase, primalIndex, primalVal };
+ IRInst* primalUpdateField = builder->emitIntrinsicInst( // despite the name, this holds the primal UpdateElement (op comes from originalInst)
+ primalType,
+ originalInst->getOp(),
+ 3,
+ primalOperands);
+
+ IRInst* diffUpdateElement = nullptr; // stays null unless the diff type, diff base and diff value are all available
+
+ if (auto diffType = differentiateType(builder, originalInst->getDataType()))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ if (auto diffVal = findOrTranscribeDiffInst(builder, origVal))
+ {
+ auto primalElementType = primalVal->getDataType();
+
+ IRInst* diffOperands[] = { diffBase, primalIndex, diffVal, primalElementType };
+ diffUpdateElement = builder->emitIntrinsicInst( // NOTE(review): 4 operands here (primal element type appended) versus 3 on the primal inst — confirm consumers expect the extra operand
+ diffType,
+ originalInst->getOp(),
+ 4,
+ diffOperands);
+ }
+ }
+ }
+ return InstPair(primalUpdateField, diffUpdateElement);
+}
+
InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
{
// The loop comes with three blocks.. we just need to transcribe each one
@@ -1132,6 +1232,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_FloatCast:
case kIROp_MakeStruct:
case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
return transcribeConstruct(builder, origInst);
case kIROp_LookupWitness:
@@ -1146,7 +1247,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeVectorFromScalar:
case kIROp_MakeTuple:
return transcribeByPassthrough(builder, origInst);
-
+ case kIROp_UpdateElement:
+ return transcribeUpdateElement(builder, origInst);
+ case kIROp_UpdateField:
+ return transcribeUpdateField(builder, origInst);
case kIROp_unconditionalBranch:
return transcribeControlFlow(builder, origInst);
@@ -1194,6 +1298,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return transcribeExtractExistentialWitnessTable(builder, origInst);
case kIROp_WrapExistential:
return transcribeWrapExistential(builder, origInst);
+ case kIROp_undefined:
+ return transcribeUndefined(builder, origInst);
+
case kIROp_CreateExistentialObject:
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
// so we treat this inst as non differentiable.
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index b09b57974..f8186a96e 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -23,6 +23,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
//
String makeDiffPairName(IRInst* origVar);
+ InstPair transcribeUndefined(IRBuilder* builder, IRInst* origInst);
+
InstPair transcribeVar(IRBuilder* builder, IRVar* origVar);
InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith);
@@ -61,6 +63,10 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr);
+ InstPair transcribeUpdateField(IRBuilder* builder, IRInst* originalInst);
+
+ InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst);
+
InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop);
InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 779a4f1a3..000921c7e 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -501,15 +501,7 @@ namespace Slang
stripDerivativeDecorations(primalFunc);
eliminateDeadCode(primalOuterParent);
- // Perform preparation and simplification.
- differentiableTypeConformanceContext.setFunc(primalFunc);
- if (SLANG_FAILED(eliminateAddressInsts(
- builder->getSharedBuilder(),
- differentiableTypeConformanceContext,
- primalFunc,
- sink)))
- return nullptr;
-
+ // Perform simplification.
simplifyFunc(primalFunc);
// Forward transcribe the clone of the original func.
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index cfbc9638a..89adbe6a0 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -624,6 +624,20 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
}
+
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ auto diffElementType =
+ (IRType*)differentiableTypeConformanceContext.getDifferentialForType(
+ builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount());
+ auto diffElementZero = getDifferentialZeroOfType(builder, arrayType->getElementType());
+ auto result = builder->emitMakeArrayFromElement(diffArrayType, diffElementZero);
+ builder->markInstAsDifferential(result, primalType);
+ return result;
+ }
+
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index e799456bb..51dcd9f45 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -278,21 +278,12 @@ struct DiffTransposePass
auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
auto diffType = fwdInst->getDataType();
- auto zeroMethod = diffTypeContext.getZeroMethodForType(
- &tempVarBuilder,
- primalType);
-
- SLANG_ASSERT(zeroMethod);
+ auto zero = emitDZeroOfDiffInstType(&tempVarBuilder, primalType);
// Emit a var in the top-level differential block to hold the gradient,
// and initialize it.
auto tempRevVar = tempVarBuilder.emitVar(diffType);
- auto diffZero = tempVarBuilder.emitCallInst(
- diffType,
- zeroMethod,
- List<IRInst*>());
- tempVarBuilder.emitStore(tempRevVar, diffZero);
-
+ tempVarBuilder.emitStore(tempRevVar, zero);
revAccumulatorVarMap[fwdInst] = tempRevVar;
return tempRevVar;
@@ -1044,6 +1035,9 @@ struct DiffTransposePass
case kIROp_FieldExtract:
return transposeFieldExtract(builder, as<IRFieldExtract>(fwdInst), revValue);
+ case kIROp_GetElement:
+ return transposeGetElement(builder, as<IRGetElement>(fwdInst), revValue);
+
case kIROp_Return:
return transposeReturn(builder, as<IRReturn>(fwdInst), revValue);
@@ -1065,6 +1059,14 @@ struct DiffTransposePass
return transposeMakeStruct(builder, fwdInst, revValue);
case kIROp_MakeArray:
return transposeMakeArray(builder, fwdInst, revValue);
+ case kIROp_MakeArrayFromElement:
+ return transposeMakeArrayFromElement(builder, fwdInst, revValue);
+
+ case kIROp_UpdateElement:
+ return transposeUpdateElement(builder, fwdInst, revValue);
+
+ case kIROp_UpdateField:
+ return transposeUpdateField(builder, fwdInst, revValue);
case kIROp_Specialize:
case kIROp_unconditionalBranch:
@@ -1170,6 +1172,17 @@ struct DiffTransposePass
fwdExtract)));
}
+    // Transposes `GetElement(base, index)`. The incoming gradient is routed to
+    // the base as a single `GetElement`-flavored gradient; the forward inst is
+    // recorded alongside it.
+    TranspositionResult transposeGetElement(IRBuilder*, IRGetElement* fwdGetElement, IRInst* revValue)
+    {
+        auto baseInst = fwdGetElement->getBase();
+        return TranspositionResult(
+            List<RevGradient>(
+                RevGradient(RevGradient::Flavor::GetElement, baseInst, revValue, fwdGetElement)));
+    }
+
TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
{
// Even though makePair returns a pair of (primal, differential)
@@ -1250,27 +1263,108 @@ struct DiffTransposePass
{
List<RevGradient> gradients;
auto arrayType = cast<IRArrayType>(fwdMakeArray->getFullType());
- auto arraySize = cast<IRIntLit>(arrayType->getElementCount());
- for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++)
+ for (UInt ii = 0; ii < fwdMakeArray->getOperandCount(); ii++)
{
auto gradAtField = builder->emitElementExtract(
arrayType->getElementType(),
revValue,
builder->getIntValue(builder->getIntType(), ii));
- SLANG_RELEASE_ASSERT(ii < fwdMakeArray->getOperandCount());
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
fwdMakeArray->getOperand(ii),
gradAtField,
fwdMakeArray));
- ii++;
}
// (A = MakeArray(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
return TranspositionResult(gradients);
}
+    // Transposes `A = MakeArrayFromElement(E)`: every element of the incoming
+    // array gradient is propagated back to the single source element E.
+    //
+    // Requires the array's element count to be an integer literal, because the
+    // gradients are expanded statically, one per element.
+    TranspositionResult transposeMakeArrayFromElement(IRBuilder* builder, IRInst* fwdMakeArrayFromElement, IRInst* revValue)
+    {
+        List<RevGradient> gradients;
+        auto arrayType = cast<IRArrayType>(fwdMakeArrayFromElement->getFullType());
+
+        // Use the checked `as<>` conversion so that the release assert below is a
+        // real test of "element count is a literal": a null result is exactly
+        // what the assert looks for.
+        auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+        SLANG_RELEASE_ASSERT(arraySize);
+
+        // TODO: if arraySize is a generic value, we can't statically expand things here.
+        // In that case we probably need another opcode e.g. `Sum(arrayValue)` that can be
+        // expanded later in the pipeline when `arrayValue` becomes a known value.
+        auto elementType = arrayType->getElementType();
+        auto sourceElement = fwdMakeArrayFromElement->getOperand(0);
+        for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++)
+        {
+            auto gradAtIndex = builder->emitElementExtract(
+                elementType,
+                revValue,
+                builder->getIntValue(builder->getIntType(), ii));
+            gradients.add(RevGradient(
+                RevGradient::Flavor::Simple,
+                sourceElement,
+                gradAtIndex,
+                fwdMakeArrayFromElement));
+        }
+
+        // (A = MakeArrayFromElement(E)) -> [(dE += dA[0]), (dE += dA[1]), ..., (dE += dA[N-1])]
+        return TranspositionResult(gradients);
+    }
+
+
+    // Transposes `A = UpdateElement(arr, index, V)`.
+    //
+    // Produces two gradients:
+    //   - for V:   the element of the incoming gradient at `index`;
+    //   - for arr: the incoming gradient with the element at `index` replaced
+    //              by a zero differential.
+    //
+    // The inst must carry the primal element type as its fourth operand, since
+    // that type is what the zero differential is built from.
+    TranspositionResult transposeUpdateElement(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue)
+    {
+        auto updateInst = as<IRUpdateElement>(fwdUpdate);
+        SLANG_RELEASE_ASSERT(updateInst);
+
+        List<RevGradient> gradients;
+        auto arrayType = cast<IRArrayType>(fwdUpdate->getFullType());
+        auto revElement = builder->emitElementExtract(arrayType->getElementType(), revValue, updateInst->getIndex());
+        gradients.add(RevGradient(
+            RevGradient::Flavor::Simple,
+            updateInst->getElementValue(),
+            revElement,
+            fwdUpdate));
+
+        // `getPrimalElementType()` yields null when the inst does not have exactly
+        // four operands; fail here rather than hand a null type to the zero lookup.
+        auto primalElementType = updateInst->getPrimalElementType();
+        SLANG_RELEASE_ASSERT(primalElementType);
+        auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType);
+        SLANG_ASSERT(diffZero);
+        auto revRest = builder->emitUpdateElement(
+            revValue,
+            updateInst->getIndex(),
+            diffZero);
+        gradients.add(RevGradient(
+            RevGradient::Flavor::Simple,
+            updateInst->getOldValue(),
+            revRest,
+            fwdUpdate));
+        // (A = UpdateElement(arr, index, V)) -> [(dV += dA[index]), (d_arr += UpdateElement(dA, index, 0))]
+        return TranspositionResult(gradients);
+    }
+
+    // Transposes `A = UpdateField(s, fieldKey, V)`.
+    //
+    // Produces two gradients:
+    //   - for V: the `fieldKey` field of the incoming gradient;
+    //   - for s: the incoming gradient with the `fieldKey` field replaced by a
+    //            zero differential.
+    //
+    // The inst must carry the primal element type as its fourth operand, since
+    // that type is what the zero differential is built from.
+    TranspositionResult transposeUpdateField(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue)
+    {
+        auto updateInst = as<IRUpdateField>(fwdUpdate);
+        SLANG_RELEASE_ASSERT(updateInst);
+
+        List<RevGradient> gradients;
+        IRType* fieldType = updateInst->getElementValue()->getFullType();
+        auto revElement = builder->emitFieldExtract(fieldType, revValue, updateInst->getFieldKey());
+        gradients.add(RevGradient(
+            RevGradient::Flavor::Simple,
+            updateInst->getElementValue(),
+            revElement,
+            fwdUpdate));
+
+        // `getPrimalElementType()` yields null when the inst does not have exactly
+        // four operands; fail here rather than hand a null type to the zero lookup.
+        auto primalElementType = updateInst->getPrimalElementType();
+        SLANG_RELEASE_ASSERT(primalElementType);
+        auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType);
+        SLANG_ASSERT(diffZero);
+        auto revRest = builder->emitUpdateField(
+            revValue,
+            updateInst->getFieldKey(),
+            diffZero);
+        gradients.add(RevGradient(
+            RevGradient::Flavor::Simple,
+            updateInst->getOldValue(),
+            revRest,
+            fwdUpdate));
+        // (A = UpdateField(s, fieldKey, V)) -> [(dV += dA.fieldKey), (d_s += UpdateField(dA, fieldKey, 0))]
+        return TranspositionResult(gradients);
+    }
+
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
@@ -1548,11 +1642,80 @@ struct DiffTransposePass
case RevGradient::Flavor::FieldExtract:
return materializeFieldExtractGradients(builder, aggPrimalType, gradients);
+ case RevGradient::Flavor::GetElement:
+ return materializeGetElementGradients(builder, aggPrimalType, gradients);
+
default:
SLANG_ASSERT_FAILURE("Unhandled gradient flavor for materialization");
}
}
+    // Combines a list of `GetElement`-flavored gradients into one `Simple`
+    // gradient for the whole aggregate.
+    //
+    // The gradients are grouped by the index operand of their forward
+    // `GetElement` inst. Each group is aggregated and stored into the matching
+    // element of a zero-initialized temporary, which is then loaded and returned
+    // as the gradient for the target of the first entry in `gradients`.
+    RevGradient materializeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
+    {
+        // Temporary that accumulates the per-element gradients.
+        // TODO: We can extend this later to grab an existing ptr to allow aggregation of
+        // gradients across blocks without constructing new variables.
+        // Looking up an existing pointer could also allow chained accesses like x.a.b[1] to directly
+        // write into the specific sub-field that is affected without constructing intermediate vars.
+        //
+        auto aggDiffType = (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType);
+        auto accumVar = builder->emitVar(aggDiffType);
+
+        // Start from the zero differential of the aggregate.
+        auto zeroOfAggregate = emitDZeroOfDiffInstType(builder, aggPrimalType);
+        builder->emitStore(accumVar, zeroOfAggregate);
+
+        // Group the gradients by index inst, re-flavored as `Simple` so they can
+        // be aggregated per element. Insertion order is preserved by the container.
+        OrderedDictionary<IRInst*, List<RevGradient>> gradientsByIndex;
+        for (auto gradient : gradients)
+        {
+            auto fwdGetElement = as<IRGetElement>(gradient.fwdGradInst);
+            SLANG_ASSERT(fwdGetElement);
+
+            auto indexInst = fwdGetElement->getIndex();
+            SLANG_ASSERT(indexInst);
+
+            if (!gradientsByIndex.ContainsKey(indexInst))
+                gradientsByIndex[indexInst] = List<RevGradient>();
+
+            RevGradient simpleGradient(
+                RevGradient::Flavor::Simple,
+                gradient.targetInst,
+                gradient.revGradInst,
+                gradient.fwdGradInst);
+            gradientsByIndex[indexInst].GetValue().add(simpleGradient);
+        }
+
+        // Aggregate each group and write it to its element of the temporary.
+        for (auto bucket : gradientsByIndex)
+        {
+            auto elementGradients = bucket.Value;
+
+            auto elementPrimalType = tryGetPrimalTypeFromDiffInst(elementGradients[0].fwdGradInst);
+            SLANG_ASSERT(elementPrimalType);
+
+            // Address of this element inside the temporary.
+            auto elementPtrType = builder->getPtrType(elementGradients[0].revGradInst->getDataType());
+            auto elementAddress = builder->emitElementAddress(elementPtrType, accumVar, bucket.Key);
+
+            builder->emitStore(elementAddress, emitAggregateValue(builder, elementPrimalType, elementGradients));
+        }
+
+        // The loaded temporary is the gradient of the whole aggregate.
+        return RevGradient(
+            RevGradient::Flavor::Simple,
+            gradients[0].targetInst,
+            builder->emitLoad(accumVar),
+            nullptr);
+    }
+
+
RevGradient materializeFieldExtractGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
{
// Setup a temporary variable to aggregate gradients.
@@ -1569,7 +1732,7 @@ struct DiffTransposePass
builder->emitStore(revGradVar, zeroValueInst);
- Dictionary<IRStructKey*, List<RevGradient>> bucketedGradients;
+ OrderedDictionary<IRStructKey*, List<RevGradient>> bucketedGradients;
for (auto gradient : gradients)
{
// Grab the field affected by this gradient.
@@ -1601,7 +1764,7 @@ struct DiffTransposePass
SLANG_ASSERT(primalType);
- // Consruct address to this field in revGradVar.
+ // Construct address to this field in revGradVar.
auto revGradTargetAddress = builder->emitFieldAddress(
builder->getPtrType(subGrads[0].revGradInst->getDataType()),
revGradVar,
@@ -1734,6 +1897,14 @@ struct DiffTransposePass
IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType)
{
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount());
+ auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType());
+ return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero);
+ }
auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
// Should exist.
@@ -1747,6 +1918,30 @@ struct DiffTransposePass
IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
{
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto arraySize = arrayType->getElementCount();
+ if (auto constArraySize = as<IRIntLit>(arraySize))
+ {
+ List<IRInst*> args;
+ for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++)
+ {
+ auto index = builder->getIntValue(builder->getIntType(), i);
+ auto op1Val = builder->emitElementExtract(diffElementType, op1, index);
+ auto op2Val = builder->emitElementExtract(diffElementType, op2, index);
+ args.add(emitDAddOfDiffInstType(builder, arrayType->getElementType(), op1Val, op2Val));
+ }
+ auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount());
+ return builder->emitMakeArray(diffArrayType, (UInt)args.getCount(), args.getBuffer());
+ }
+ else
+ {
+ // TODO: insert a runtime loop here.
+ SLANG_UNIMPLEMENTED_X("dadd of dynamic array.");
+ }
+ }
auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType);
// Should exist.
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 640041ecf..8d9a01b75 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -185,6 +185,7 @@ struct ExtractPrimalFuncContext
case kIROp_MakeStruct:
case kIROp_MakeTuple:
case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
case kIROp_MakeDifferentialPair:
case kIROp_MakeOptionalNone:
case kIROp_MakeOptionalValue:
@@ -194,6 +195,8 @@ struct ExtractPrimalFuncContext
case kIROp_GetElement:
case kIROp_FieldExtract:
case kIROp_swizzle:
+ case kIROp_UpdateElement:
+ case kIROp_UpdateField:
case kIROp_OptionalHasValue:
case kIROp_GetOptionalValue:
case kIROp_MatrixReshape:
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4d33d3743..ae8359251 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -532,6 +532,46 @@ void stripNoDiffTypeAttribute(IRModule* module)
pass.processModule();
}
+// Returns whether `typeInst` is a differentiable type.
+//
+// Array and pointer types are unwrapped to their operand 0 first. The
+// remaining type is differentiable if it is a float or differential-pair type,
+// if `context` has a conformance for it, or if it is type-equal to a key
+// already present in the context's witness dictionary. In that last case the
+// witness is also registered under `typeInst`.
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+{
+    // Strip array / pointer wrappers, bailing out if a type is seen twice.
+    HashSet<IRInst*> visitedTypes;
+    while (typeInst && (as<IRArrayTypeBase>(typeInst) || as<IRPtrTypeBase>(typeInst)))
+    {
+        typeInst = typeInst->getOperand(0);
+        if (!visitedTypes.Add(typeInst))
+            return false;
+    }
+    if (!typeInst)
+        return false;
+
+    const auto op = typeInst->getOp();
+    if (op == kIROp_FloatType || op == kIROp_DifferentialPairType)
+        return true;
+
+    if (context.lookUpConformanceForType(typeInst))
+        return true;
+
+    // Look for equivalent types.
+    // The entry is taken by value, and we return right after inserting, so the
+    // iteration never continues over the modified dictionary.
+    for (auto entry : context.differentiableWitnessDictionary)
+    {
+        if (!isTypeEqual(entry.Key, (IRType*)typeInst))
+            continue;
+        context.differentiableWitnessDictionary[(IRType*)typeInst] = entry.Value;
+        return true;
+    }
+    return false;
+}
+
struct AutoDiffPass : public InstPassBase
{
DiagnosticSink* getSink()
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index cb767c20a..6da4ea6a6 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -175,20 +175,40 @@ struct DifferentiableTypeConformanceContext
case kIROp_DoubleType:
case kIROp_VectorType:
return origType;
+ case kIROp_ArrayType:
+ {
+ auto diffElementType = (IRType*)getDifferentialForType(
+ builder, as<IRArrayType>(origType)->getElementType());
+ if (!diffElementType)
+ return nullptr;
+ return builder->getArrayType(
+ diffElementType,
+ as<IRArrayType>(origType)->getElementCount());
+ }
+ default:
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
}
- return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
}
IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
{
- return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ if (result && !result->findDecoration<IRNoSideEffectDecoration>())
+ {
+ builder->addDecoration(result, kIROp_NoSideEffectDecoration);
+ }
+ return result;
}
IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
{
- return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
+ auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
+ if (result && !result->findDecoration<IRNoSideEffectDecoration>())
+ {
+ builder->addDecoration(result, kIROp_NoSideEffectDecoration);
+ }
+ return result;
}
-
};
struct DifferentialPairTypeBuilder
@@ -261,10 +281,6 @@ void stripDerivativeDecorations(IRInst* inst);
bool isBackwardDifferentiableFunc(IRInst* func);
-SlangResult eliminateAddressInsts(
- SharedIRBuilder* sharedBuilder,
- DifferentiableTypeConformanceContext& diffContext,
- IRFunc* func,
- DiagnosticSink* sink);
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
};
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 67b7e92b0..f8d70c8ed 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -2,40 +2,32 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-addr-inst-elimination.h"
namespace Slang
{
-bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+IRInst* getSpecializedVal(IRInst* inst)
{
- HashSet<IRInst*> processedSet;
- while (auto ptrType = as<IRPtrTypeBase>(typeInst))
+ int loopLimit = 1024;
+ while (inst && inst->getOp() == kIROp_Specialize)
{
- typeInst = ptrType->getValueType();
- if (!processedSet.Add(typeInst))
- return false;
- }
- if (!typeInst)
- return false;
- switch (typeInst->getOp())
- {
- case kIROp_FloatType:
- case kIROp_DifferentialPairType:
- return true;
- default:
- break;
- }
- if (context.lookUpConformanceForType(typeInst))
- return true;
- // Look for equivalent types.
- for (auto type : context.differentiableWitnessDictionary)
- {
- if (isTypeEqual(type.Key, (IRType*)typeInst))
- {
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
- return true;
- }
+ inst = as<IRSpecialize>(inst)->getBase();
+ loopLimit--;
+ if (loopLimit == 0)
+ return inst;
}
- return false;
+ return inst;
+}
+
+// Resolves `func` to the inst it ultimately denotes: `Specialize` wrappers are
+// stripped via `getSpecializedVal`, and if the result is a generic, the
+// innermost value it returns is used. Returns null if stripping yields null.
+IRInst* getLeafFunc(IRInst* func)
+{
+    IRInst* resolved = getSpecializedVal(func);
+    if (!resolved)
+        return nullptr;
+
+    auto genericFunc = as<IRGeneric>(resolved);
+    if (!genericFunc)
+        return resolved;
+    return findInnerMostGenericReturnVal(genericFunc);
 }
struct CheckDifferentiabilityPassContext : public InstPassBase
@@ -55,29 +47,6 @@ public:
: InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst())
{}
- IRInst* getSpecializedVal(IRInst* inst)
- {
- int loopLimit = 1024;
- while (inst && inst->getOp() == kIROp_Specialize)
- {
- inst = as<IRSpecialize>(inst)->getBase();
- loopLimit--;
- if (loopLimit == 0)
- return inst;
- }
- return inst;
- }
-
- IRInst* getLeafFunc(IRInst* func)
- {
- func = getSpecializedVal(func);
- if (!func)
- return nullptr;
- if (auto genericFunc = as<IRGeneric>(func))
- return findInnerMostGenericReturnVal(genericFunc);
- return func;
- }
-
bool _isFuncMarkedForAutoDiff(IRInst* func)
{
func = getLeafFunc(func);
@@ -219,6 +188,30 @@ public:
}
return false;
}
+
+    // Address-conversion policy used when preparing a function for autodiff:
+    // an address inst is converted only if its data type is differentiable.
+    struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
+    {
+        // Context used for the differentiability query. The owner must set it
+        // before the policy is used; it defaults to null so that a forgotten
+        // assignment is a detectable null rather than an indeterminate pointer.
+        DifferentiableTypeConformanceContext* diffTypeContext = nullptr;
+
+        virtual bool shouldConvertAddrInst(IRInst* addrInst) override
+        {
+            SLANG_ASSERT(diffTypeContext);
+            return isDifferentiableType(*diffTypeContext, addrInst->getDataType());
+        }
+    };
+
+    // Prepares `func` for differentiation: converts it to single-return form if
+    // it is not already, then runs address-inst elimination with a policy bound
+    // to `diffTypeContext`. Returns the result of the elimination pass.
+    SlangResult prepareFuncForAutoDiff(DifferentiableTypeConformanceContext& diffTypeContext, IRFunc* func)
+    {
+        if (!isSingleReturnFunc(func))
+            convertFuncToSingleReturnForm(func->getModule(), func);
+
+        AutoDiffAddressConversionPolicy conversionPolicy;
+        conversionPolicy.diffTypeContext = &diffTypeContext;
+        return eliminateAddressInsts(sharedBuilder, &conversionPolicy, func, sink);
+    }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
if (!_isFuncMarkedForAutoDiff(funcInst))
@@ -232,7 +225,7 @@ public:
{
if (auto func = as<IRFunc>(funcInst))
{
- if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink)))
+ if (SLANG_FAILED(prepareFuncForAutoDiff(diffTypeContext, func)))
return;
}
}
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 0fedf695d..a8cdb5cca 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -115,9 +115,12 @@ bool opCanBeConstExpr(IROp op)
case kIROp_MakeString:
case kIROp_MakeUInt64:
case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
case kIROp_swizzle:
case kIROp_GetElement:
case kIROp_FieldExtract:
+ case kIROp_UpdateField:
+ case kIROp_UpdateElement:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialWitnessTable:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 134a45bf5..c2a1886fb 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -293,6 +293,7 @@ INST(MakeMatrixFromScalar, makeMatrixFromScalar, 1, 0)
INST(MatrixReshape, matrixReshape, 1, 0)
INST(VectorReshape, vectorReshape, 1, 0)
INST(MakeArray, makeArray, 0, 0)
+INST(MakeArrayFromElement, makeArrayFromElement, 1, 0)
INST(MakeStruct, makeStruct, 0, 0)
INST(MakeTuple, makeTuple, 0, 0)
INST(GetTupleElement, getTupleElement, 2, 0)
@@ -310,6 +311,9 @@ INST(Call, call, 1, 0)
INST(RTTIObject, rtti_object, 0, 0)
INST(Alloca, alloca, 1, 0)
+INST(UpdateElement, updateElement, 3, 0)
+INST(UpdateField, updateField, 3, 0)
+
INST(PackAnyValue, packAnyValue, 1, 0)
INST(UnpackAnyValue, unpackAnyValue, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 22da763b3..4887b1c79 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2185,6 +2185,36 @@ struct IRDifferentialPairGetPrimal : IRInst
IRInst* getBase() { return getOperand(0); }
};
+// `UpdateElement(oldValue, index, elementValue [, primalElementType])`.
+// Operand accessors for the `UpdateElement` opcode.
+struct IRUpdateElement : IRInst
+{
+    IR_LEAF_ISA(UpdateElement)
+
+    // Operand 0: the value being updated.
+    IRInst* getOldValue() { return getOperand(0); }
+    // Operand 1: the index of the element to replace.
+    IRInst* getIndex() { return getOperand(1); }
+    // Operand 2: the new element value.
+    IRInst* getElementValue() { return getOperand(2); }
+    // Operand 3, present only when the inst has exactly four operands;
+    // null otherwise.
+    IRInst* getPrimalElementType()
+    {
+        return getOperandCount() == 4 ? getOperand(3) : nullptr;
+    }
+};
+
+// `UpdateField(oldValue, fieldKey, elementValue [, primalElementType])`.
+// Operand accessors for the `UpdateField` opcode.
+struct IRUpdateField : IRInst
+{
+    IR_LEAF_ISA(UpdateField)
+
+    // Operand 0: the value being updated.
+    IRInst* getOldValue() { return getOperand(0); }
+    // Operand 1: the key of the field to replace.
+    IRInst* getFieldKey() { return getOperand(1); }
+    // Operand 2: the new field value.
+    IRInst* getElementValue() { return getOperand(2); }
+    // Operand 3, present only when the inst has exactly four operands;
+    // null otherwise.
+    IRInst* getPrimalElementType()
+    {
+        return getOperandCount() == 4 ? getOperand(3) : nullptr;
+    }
+};
+
// Constructs an `Result<T,E>` value from an error code.
struct IRMakeResultError : IRInst
{
@@ -2925,6 +2955,10 @@ public:
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeArrayFromElement(
+ IRType* type,
+ IRInst* element);
+
IRInst* emitMakeStruct(
IRType* type,
UInt argCount,
@@ -3139,6 +3173,9 @@ public:
IRInst* basePtr,
IRInst* index);
+ IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement);
+ IRInst* emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal);
+
IRInst* emitGetAddress(
IRType* type,
IRInst* value);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 21e17b546..16f6cd9b9 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -69,7 +69,7 @@ struct PeepholeContext : InstPassBase
Index i = 0;
for (auto sfield : structType->getFields())
{
- if (sfield == field)
+ if (sfield->getKey() == field)
{
fieldIndex = i;
break;
@@ -84,6 +84,142 @@ struct PeepholeContext : InstPassBase
}
}
}
+ else if (auto updateField = as<IRUpdateField>(inst->getOperand(0)))
+ {
+ if (inst->getOperand(1) == updateField->getFieldKey())
+ {
+ inst->replaceUsesWith(updateField->getElementValue());
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ else
+ {
+ inst->setOperand(0, updateField->getOldValue());
+ changed = true;
+ }
+ }
+ break;
+ case kIROp_GetElement:
+ if (inst->getOperand(0)->getOp() == kIROp_MakeArray)
+ {
+ auto index = as<IRIntLit>(as<IRGetElement>(inst)->getIndex());
+ if (!index)
+ break;
+ auto opCount = inst->getOperand(0)->getOperandCount();
+ if ((UInt)index->getValue() < opCount)
+ {
+ inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)index->getValue()));
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ else if (inst->getOperand(0)->getOp() == kIROp_MakeArrayFromElement)
+ {
+ inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ else if (auto updateElement = as<IRUpdateElement>(inst->getOperand(0)))
+ {
+ if (inst->getOperand(1) == updateElement->getIndex())
+ {
+ inst->replaceUsesWith(updateElement->getElementValue());
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ else if (auto constIndex1 = as<IRIntLit>(inst->getOperand(1)))
+ {
+ if (auto constIndex2 = as<IRIntLit>(updateElement->getIndex()))
+ {
+ // If we can determine that the indices do not match,
+ // then reduce the original value operand to the value before the update.
+ if (constIndex1->getValue() != constIndex2->getValue())
+ {
+ inst->setOperand(0, updateElement->getOldValue());
+ changed = true;
+ }
+ }
+ }
+ }
+ break;
+ case kIROp_UpdateElement:
+ {
+ if (auto constIndex = as<IRIntLit>(inst->getOperand(1)))
+ {
+ auto oldVal = inst->getOperand(0);
+ if (oldVal->getOp() == kIROp_MakeArray ||
+ oldVal->getOp() == kIROp_MakeArrayFromElement)
+ {
+ auto arrayType = as<IRArrayType>(inst->getDataType());
+ if (!arrayType) break;
+ auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+ if (!arraySize) break;
+ List<IRInst*> args;
+ for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
+ {
+ IRInst* arg = nullptr;
+ if (i < (IRIntegerValue)oldVal->getOperandCount())
+ arg = oldVal->getOperand((UInt)i);
+ else if (oldVal->getOperandCount() != 0)
+ arg = oldVal->getOperand(0);
+ else
+ break;
+ if (i == (IRIntegerValue)constIndex->getValue())
+ arg = inst->getOperand(2);
+ args.add(arg);
+ }
+ if (args.getCount() == arraySize->getValue())
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
+ inst->replaceUsesWith(makeArray);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ }
+ }
+ break;
+ case kIROp_UpdateField:
+ {
+ auto oldVal = inst->getOperand(0);
+ if (oldVal->getOp() == kIROp_MakeStruct)
+ {
+ auto structType = as<IRStructType>(inst->getDataType());
+ if (!structType) break;
+ List<IRInst*> args;
+ UInt i = 0;
+ bool isValid = true;
+ for (auto field : structType->getFields())
+ {
+ IRInst* arg = nullptr;
+ if (i < oldVal->getOperandCount())
+ arg = oldVal->getOperand(i);
+ if (field->getKey() == inst->getOperand(1))
+ arg = inst->getOperand(2);
+ if (arg)
+ {
+ args.add(arg);
+ }
+ else
+ {
+ isValid = false;
+ break;
+ }
+ i++;
+ }
+ if (isValid)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ }
break;
case kIROp_CastPtrToBool:
{
@@ -246,10 +382,17 @@ struct PeepholeContext : InstPassBase
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
-
- changed = false;
- processChildInsts(func, [this](IRInst* inst) { processInst(inst); });
- return changed;
+ bool result = false;
+ for (;;)
+ {
+ changed = false;
+ processChildInsts(func, [this](IRInst* inst) { processInst(inst); });
+ if (changed)
+ result = true;
+ else
+ break;
+ }
+ return result;
}
bool processModule()
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index a57bfce3e..bcf0907df 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -37,6 +37,8 @@ struct RedundancyRemovalContext
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
+ case kIROp_UpdateElement:
+ case kIROp_UpdateField:
case kIROp_LookupWitness:
case kIROp_Specialize:
case kIROp_OptionalHasValue:
@@ -46,6 +48,7 @@ struct RedundancyRemovalContext
case kIROp_GetTupleElement:
case kIROp_MakeStruct:
case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
@@ -75,6 +78,8 @@ struct RedundancyRemovalContext
case kIROp_Neq:
case kIROp_Eql:
return true;
+ case kIROp_Call:
+ return isPureFunctionalCall(as<IRCall>(inst));
default:
return false;
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 319a23989..3ffbb75f7 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -227,6 +227,31 @@ String dumpIRToString(IRInst* root)
return sb.ToString();
}
+bool isPureFunctionalCall(IRCall* call)
+{
+    // Purity is read from decorations on the resolved callee:
+    // `IRReadNoneDecoration` is sufficient by itself, whereas
+    // `IRNoSideEffectDecoration` only qualifies when the call passes no
+    // pointer-typed argument (such an argument may be an output).
+    auto target = getResolvedInstForDecorations(call->getCallee());
+    if (target->findDecoration<IRReadNoneDecoration>())
+        return true;
+    if (!target->findDecoration<IRNoSideEffectDecoration>())
+        return false;
+
+    // The callee has no side effect; the call is still rejected if any
+    // argument is pointer-typed.
+    for (UInt argIndex = 0; argIndex < call->getArgCount(); argIndex++)
+    {
+        auto argType = call->getArg(argIndex)->getDataType();
+        if (as<IRPtrTypeBase>(argType))
+            return false;
+    }
+
+    // No pointer-typed argument was found.
+    return true;
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index a250fc6a6..26ad4bc68 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -155,6 +155,9 @@ inline IRInst* unwrapAttributedType(IRInst* type)
String dumpIRToString(IRInst* root);
+// Returns whether a call inst can be treated as a pure functional inst
+// (no writes to memory, no side effects).
+bool isPureFunctionalCall(IRCall* callInst);
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b79221900..2960d942c 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1810,6 +1810,24 @@ namespace Slang
}
template<typename T>
+    static T* createInst(
+        IRBuilder* builder,
+        IROp op,
+        IRType* type,
+        IRInst* arg1,
+        IRInst* arg2,
+        IRInst* arg3)
+    {
+        // Convenience overload for instructions with exactly three
+        // operands: gather them into one contiguous array so they can
+        // be handed to `createInstImpl` in a single call.
+        IRInst* operands[] = { arg1, arg2, arg3 };
+
+        // The literal count must match the number of array entries.
+        return createInstImpl<T>(builder, op, type, 3, &operands[0]);
+    }
+
+ template<typename T>
static T* createInstWithTrailingArgs(
IRBuilder* builder,
IROp op,
@@ -3768,6 +3786,13 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args);
}
+    IRInst* IRBuilder::emitMakeArrayFromElement(IRType* type, IRInst* element)
+    {
+        // The instruction takes `element` as its only operand.
+        IRInst* operands[] = { element };
+        return emitIntrinsicInst(type, kIROp_MakeArrayFromElement, 1, operands);
+    }
+
IRInst* IRBuilder::emitMakeStruct(
IRType* type,
UInt argCount,
@@ -4346,6 +4371,34 @@ namespace Slang
return inst;
}
+    IRInst* IRBuilder::emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement)
+    {
+        // The result is given the same full type as `base`; the operands
+        // are, in order, `base`, `index` and `newElement`.
+        auto resultType = base->getFullType();
+        auto updateInst = createInst<IRUpdateElement>(
+            this, kIROp_UpdateElement, resultType, base, index, newElement);
+
+        // Pass the new instruction to `addInst` before returning it to
+        // the caller.
+        addInst(updateInst);
+        return updateInst;
+    }
+
+    IRInst* IRBuilder::emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal)
+    {
+        // The result is given the same full type as `base`; the operands
+        // are, in order, `base`, `fieldKey` and `newFieldVal`.
+        auto resultType = base->getFullType();
+        auto updateInst = createInst<IRUpdateField>(
+            this, kIROp_UpdateField, resultType, base, fieldKey, newFieldVal);
+
+        // Pass the new instruction to `addInst` before returning it to
+        // the caller.
+        addInst(updateInst);
+        return updateInst;
+    }
+
IRInst* IRBuilder::emitGetAddress(
IRType* type,
IRInst* value)
@@ -6545,11 +6598,7 @@ namespace Slang
// common subexpression elimination, etc.
//
auto call = cast<IRCall>(this);
- auto callee = getResolvedInstForDecorations(call->getCallee());
- if(callee->findDecoration<IRReadNoneDecoration>())
- {
- return false;
- }
+ return !isPureFunctionalCall(call);
}
break;
@@ -6592,6 +6641,7 @@ namespace Slang
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
case kIROp_MakeStruct:
case kIROp_MakeString:
case kIROp_getNativeStr:
@@ -6612,6 +6662,8 @@ namespace Slang
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
+ case kIROp_UpdateElement:
+ case kIROp_UpdateField:
case kIROp_MeshOutputRef:
case kIROp_MakeVectorFromScalar:
case kIROp_swizzle: