summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-address-analysis.cpp173
-rw-r--r--source/slang/slang-ir-address-analysis.h24
-rw-r--r--source/slang/slang-ir-autodiff-addr-inst-elimination.cpp476
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp18
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h53
-rw-r--r--source/slang/slang-ir-autodiff.cpp3
-rw-r--r--source/slang/slang-ir-autodiff.h7
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp82
-rw-r--r--source/slang/slang-ir-check-differentiability.h3
-rw-r--r--source/slang/slang-ir-dominators.cpp25
-rw-r--r--source/slang/slang-ir-dominators.h3
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp125
-rw-r--r--source/slang/slang-ir-redundancy-removal.h11
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp4
-rw-r--r--source/slang/slang-ir-single-return.cpp16
-rw-r--r--source/slang/slang-ir-single-return.h1
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp3
-rw-r--r--source/slang/slang-ir-util.cpp62
-rw-r--r--source/slang/slang-ir-util.h29
-rw-r--r--source/slang/slang-ir.cpp10
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
24 files changed, 1051 insertions, 84 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 4820c430f..0d4088d75 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -581,6 +581,8 @@ DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable fun
DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal")
+DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.")
+DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.")
//
// 5xxxx - Target code generation.
//
diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp
new file mode 100644
index 000000000..aba59e1de
--- /dev/null
+++ b/source/slang/slang-ir-address-analysis.cpp
@@ -0,0 +1,173 @@
+#include "slang-ir-address-analysis.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+ void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst)
+ {
+ if (!as<IRBlock>(inst->getParent()))
+ return;
+ if (domTree->isUnreachable(as<IRBlock>(inst->getParent())))
+ return;
+
+ List<IRBlock*> blocks;
+ HashSet<IRInst*> operandInsts;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ operandInsts.Add(inst->getOperand(i));
+ auto parentBlock = as<IRBlock>(inst->getOperand(i)->getParent());
+ if (parentBlock)
+ {
+ if (!domTree->isUnreachable(parentBlock))
+ blocks.add(parentBlock);
+ }
+ }
+ {
+ operandInsts.Add(inst->getFullType());
+ auto parentBlock = as<IRBlock>(inst->getFullType()->getParent());
+ if (parentBlock)
+ {
+ if (!domTree->isUnreachable(parentBlock))
+ blocks.add(parentBlock);
+ }
+ }
+ // Find earliest block that is dominated by all operand blocks.
+ IRBlock* earliestBlock = as<IRBlock>(inst->getParent());
+ for (auto block : func->getBlocks())
+ {
+ bool dominated = true;
+ for (auto opBlock : blocks)
+ {
+ if (!domTree->dominates(opBlock, block))
+ {
+ dominated = false;
+ break;
+ }
+ }
+ if (dominated)
+ {
+ earliestBlock = block;
+ break;
+ }
+ }
+
+ if (!earliestBlock)
+ return;
+
+ IRInst* latestOperand = nullptr;
+ for (auto childInst : earliestBlock->getChildren())
+ {
+ if (operandInsts.Contains(childInst))
+ {
+ latestOperand = childInst;
+ }
+ }
+
+ if (!latestOperand || as<IRParam>(latestOperand))
+ inst->insertBefore(earliestBlock->getFirstOrdinaryInst());
+ else
+ inst->insertAfter(latestOperand);
+ }
+
+ AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func)
+ {
+ DeduplicateContext deduplicateContext;
+
+ AddressAccessInfo info;
+
+ // Deduplicate and move known address insts.
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstChild(); inst;)
+ {
+ auto next = inst->getNextInst();
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ addrInfo->isConstant = true;
+ addrInfo->parentAddress = nullptr;
+ info.addressInfos[inst] = addrInfo;
+ }
+ break;
+ case kIROp_Param:
+ if (as<IRPtrTypeBase>(inst->getFullType()))
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ addrInfo->isConstant = (block == func->getFirstBlock() ? true : false);
+ addrInfo->parentAddress = nullptr;
+ info.addressInfos[inst] = addrInfo;
+ }
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ moveInstToEarliestPoint(dom, func, inst->getFullType());
+ moveInstToEarliestPoint(dom, func, inst);
+ auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* inst)
+ {
+ if (!inst->getParent())
+ return false;
+ if (inst->getParent()->getParent() != func)
+ return false;
+ switch (inst->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ return true;
+ default:
+ return false;
+ }
+ });
+
+ if (deduplicated != inst)
+ {
+ SLANG_RELEASE_ASSERT(dom->dominates(
+ as<IRBlock>(deduplicated->getParent()),
+ as<IRBlock>(inst->getParent())));
+
+ inst->replaceUsesWith(deduplicated);
+ inst->removeAndDeallocate();
+ }
+ else
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ if (inst->getOp() == kIROp_FieldAddress)
+ {
+ addrInfo->isConstant = true;
+ }
+ else
+ {
+ addrInfo->isConstant =
+ as<IRConstant>(inst->getOperand(1)) ? true : false;
+ }
+ info.addressInfos[inst] = addrInfo;
+ }
+ }
+ break;
+ }
+ inst = next;
+ }
+ }
+
+ // Construct address info tree.
+ for (auto& addr : info.addressInfos)
+ {
+ RefPtr<AddressInfo> parentInfo;
+ if (addr.Value->addrInst->getOperandCount() > 1 &&
+ info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo))
+ {
+ addr.Value->parentAddress = parentInfo;
+ parentInfo->children.add(addr.Value);
+ if (!parentInfo->isConstant)
+ addr.Value->isConstant = false;
+ }
+ }
+ return info;
+ }
+}
diff --git a/source/slang/slang-ir-address-analysis.h b/source/slang/slang-ir-address-analysis.h
new file mode 100644
index 000000000..450e8b9eb
--- /dev/null
+++ b/source/slang/slang-ir-address-analysis.h
@@ -0,0 +1,24 @@
+// slang-ir-address-analysis.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-dominators.h"
+
+namespace Slang
+{
+ struct AddressInfo : public RefObject
+ {
+ IRInst* addrInst = nullptr;
+ AddressInfo* parentAddress = nullptr;
+ bool isConstant = false;
+ List<AddressInfo*> children;
+ };
+
+ struct AddressAccessInfo
+ {
+ OrderedDictionary<IRInst*, RefPtr<AddressInfo>> addressInfos;
+ };
+
+ // Gather info on all addresses used by `func`.
+ AddressAccessInfo analyzeAddressUse(IRDominatorTree* domTree, IRGlobalValueWithCode* func);
+}
diff --git a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp
new file mode 100644
index 000000000..c60995595
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp
@@ -0,0 +1,476 @@
+#include "slang-ir-address-analysis.h"
+#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-pairs.h"
+#include "slang-ir-autodiff-rev.h"
+#include "slang-ir-autodiff.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-validate.h"
+
+namespace Slang
+{
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
+
+struct AddressInstEliminationContext
+{
+ OrderedDictionary<IRInst*, IRInst*> mapAddrInstToTempVar;
+
+ IRInst* _reconstructStruct(
+ IRBuilder& builder, IRStructType* type, IRInst* tempVar, List<AddressInfo*>& childAddrs)
+ {
+ List<IRInst*> args;
+ IRInst* loadedTempVar = nullptr;
+ for (auto child : type->getChildren())
+ {
+ if (auto field = as<IRStructField>(child))
+ {
+ IRInst* childVar = nullptr;
+ for (auto subAddr : childAddrs)
+ {
+ auto fieldAddrInst = cast<IRFieldAddress>(subAddr->addrInst);
+ if (fieldAddrInst->getField() == field->getKey())
+ {
+ mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar);
+ break;
+ }
+ }
+ if (childVar)
+ {
+ args.add(builder.emitLoad(childVar));
+ }
+ else
+ {
+ if (!loadedTempVar)
+ loadedTempVar = builder.emitLoad(tempVar);
+ args.add(builder.emitFieldExtract(
+ field->getFieldType(), loadedTempVar, field->getKey()));
+ }
+ }
+ }
+ return builder.emitMakeStruct(type, args);
+ }
+
+ IRInst* _reconstructArray(
+ IRBuilder& builder,
+ IRArrayType* type,
+ IRIntegerValue arraySize,
+ IRInst* tempVar,
+ List<AddressInfo*>& childAddrs)
+ {
+ IRInst* loadedTempVar = nullptr;
+ List<IRInst*> args;
+ for (IRIntegerValue index = 0; index < arraySize; index++)
+ {
+ IRInst* childVar = nullptr;
+ for (auto subAddr : childAddrs)
+ {
+ auto elementPtrInst = cast<IRGetElementPtr>(subAddr->addrInst);
+ auto elementIndex = as<IRIntLit>(elementPtrInst->getIndex());
+ if (elementIndex && elementIndex->getValue() == index)
+ {
+ mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar);
+ break;
+ }
+ }
+ if (childVar)
+ {
+ args.add(builder.emitLoad(childVar));
+ }
+ else
+ {
+ if (!loadedTempVar)
+ loadedTempVar = builder.emitLoad(tempVar);
+ args.add(builder.emitElementExtract(
+ type->getElementType(),
+ loadedTempVar,
+ builder.getIntValue(builder.getIntType(), index)));
+ }
+ }
+ return builder.emitMakeArray(type, args.getCount(), args.getBuffer());
+ }
+
+ void updateChildTempVarRecursive(
+ IRBuilder& builder,
+ AddressInfo* addr,
+ IRInst* val)
+ {
+ for (auto child : addr->children)
+ {
+ IRInst* childVar = nullptr;
+ if (mapAddrInstToTempVar.TryGetValue(child->addrInst, childVar))
+ {
+ switch (child->addrInst->getOp())
+ {
+ case kIROp_FieldAddress:
+ {
+ auto subVal = builder.emitFieldExtract(
+ cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(),
+ val,
+ child->addrInst->getOperand(1));
+ builder.emitStore(childVar, subVal);
+ updateChildTempVarRecursive(builder, child, subVal);
+ }
+ break;
+ case kIROp_GetElementPtr:
+ {
+ auto subVal = builder.emitElementExtract(
+ cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(),
+ val,
+ child->addrInst->getOperand(1));
+ builder.emitStore(childVar, subVal);
+ updateChildTempVarRecursive(builder, child, subVal);
+ }
+ break;
+ default:
+ {
+ }
+ break;
+ }
+ }
+ }
+ }
+
+ IRInst* getLoadedValue(
+ IRBuilder& builder,
+ AddressInfo* addr,
+ IRInst* tempVar)
+ {
+ if (addr->children.getCount())
+ {
+ // Reconstruct val.
+ auto type =
+ cast<IRPtrTypeBase>(unwrapAttributedType(tempVar->getFullType()))->getValueType();
+ switch (type->getOp())
+ {
+ case kIROp_StructType:
+ return _reconstructStruct(
+ builder, as<IRStructType>(type), tempVar, addr->children);
+ case kIROp_ArrayType:
+ {
+ auto arrayType = as<IRArrayType>(type);
+ auto size = as<IRIntLit>(arrayType->getElementCount());
+ if (!size || size->getValue() < 0)
+ {
+ // Unsupported array type.
+ }
+ else
+ {
+ return _reconstructArray(
+ builder,
+ arrayType,
+ size->getValue(),
+ tempVar,
+ addr->children);
+ }
+ }
+ break;
+ default:
+ // Unsupported address type.
+ break;
+ }
+ }
+ return builder.emitLoad(tempVar);
+ };
+
+ void updateParentTempVarRecursive(
+ IRBuilder& builder,
+ AddressInfo* addr)
+ {
+ for (auto parent = addr->parentAddress; parent; parent = parent->parentAddress)
+ {
+ IRInst* parentVar = nullptr;
+ if (mapAddrInstToTempVar.TryGetValue(parent->addrInst, parentVar))
+ {
+ auto val = getLoadedValue(builder, parent, parentVar);
+ builder.emitStore(parentVar, val);
+ }
+ }
+ }
+
+ String getAddrName(IRInst* addrInst)
+ {
+ StringBuilder sb;
+ List<IRInst*> bases;
+ bases.add(addrInst);
+ for (; addrInst;)
+ {
+ if (auto fieldAddr = as<IRFieldAddress>(addrInst))
+ bases.add(fieldAddr->getBase());
+ else if (auto index = as<IRGetElementPtr>(addrInst))
+ bases.add(index->getBase());
+ else
+ break;
+ }
+ for (Index i = bases.getCount() - 1; i >= 0; i--)
+ {
+ if (bases[i]->getOp() == kIROp_FieldAddress)
+ {
+ sb << ".";
+ auto field = bases[i]->getOperand(1);
+ auto nameDecor = field->findDecoration<IRNameHintDecoration>();
+ sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>"));
+ }
+ else if (bases[i]->getOp() == kIROp_FieldAddress)
+ {
+ sb << "[";
+ auto index = bases[i]->getOperand(1);
+ auto nameDecor = index->findDecoration<IRNameHintDecoration>();
+ if (nameDecor)
+ {
+ sb << nameDecor->getName();
+ }
+ else if (auto intLit = as<IRIntLit>(index))
+ {
+ sb << intLit->getValue();
+ }
+ else
+ {
+ sb << "...";
+ }
+ sb << "]";
+ }
+ else
+ {
+ auto nameDecor = bases[i]->findDecoration<IRNameHintDecoration>();
+ sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>"));
+ }
+ }
+ return sb.ProduceString();
+ }
+
+ SlangResult eliminateAddressInstsImpl(
+ SharedIRBuilder* sharedBuilder,
+ DifferentiableTypeConformanceContext& diffContext,
+ IRFunc* func,
+ DiagnosticSink* sink)
+ {
+ bool hasError = false;
+
+ if (!isSingleReturnFunc(func))
+ {
+ convertFuncToSingleReturnForm(func->getModule(), func);
+ }
+
+ IRBuilder builder(sharedBuilder);
+
+ auto dom = computeDominatorTree(func);
+ auto addrUse = analyzeAddressUse(dom, func);
+ List<AddressInfo*> workList;
+ HashSet<AddressInfo*> workListSet;
+
+ // Process leaf addresses first.
+ for (auto addr : addrUse.addressInfos)
+ {
+ if (addr.Value->children.getCount() == 0)
+ workList.add(addr.Value);
+ }
+
+ auto createTempVarForAddr = [&](IRInst* addrInst)
+ {
+ if (as<IRParam>(addrInst))
+ builder.setInsertAfter(as<IRBlock>(addrInst->getParent())->getLastParam());
+ else
+ builder.setInsertAfter(addrInst);
+ auto ptrType = as<IRPtrTypeBase>(addrInst->getFullType());
+ SLANG_RELEASE_ASSERT(ptrType);
+ auto tempVar = builder.emitVar(ptrType->getValueType());
+ mapAddrInstToTempVar[addrInst] = tempVar;
+ };
+
+ // In the first pass, we create temp vars for addresses with non-trivial access pattern.
+ for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++)
+ {
+ auto addr = workList[workListIndex];
+
+ if (!isDifferentiableType(diffContext, addr->addrInst->getDataType()))
+ continue;
+
+ List<IRUse*> readUses, writeUses, callUses, subAddrUses, unknownUses;
+
+ for (auto node = addr; node; node = node->parentAddress)
+ {
+ auto addrInst = node->addrInst;
+
+ for (auto use = addrInst->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRDecoration>(use->getUser()))
+ continue;
+ switch (use->getUser()->getOp())
+ {
+ case kIROp_Load:
+ readUses.add(use);
+ break;
+ case kIROp_Store:
+ writeUses.add(use);
+ break;
+ case kIROp_Call:
+ callUses.add(use);
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ if (node == addr)
+ subAddrUses.add(use);
+ break;
+ default:
+ unknownUses.add(use);
+ break;
+ }
+ }
+ }
+
+ if (unknownUses.getCount() != 0)
+ {
+ // Diagnose about unknown use.
+ sink->diagnose(
+ unknownUses.getFirst()->getUser(),
+ Diagnostics::unsupportedUseOfLValueForAutoDiff);
+ hasError = true;
+ continue;
+ }
+
+ if (addr->isConstant)
+ {
+ // Otherwise, the address must be a constant, and we need to create a temp var for
+ // it. The exception is when the variable is a temp var for a call.
+ if (callUses.getCount() == 1 && writeUses.getCount() <= 1 &&
+ readUses.getCount() <= 1)
+ {
+ if (writeUses.getCount() == 0)
+ continue;
+
+ // The uses must be in write->call->read order.
+ auto callUse = callUses.getFirst();
+ auto writeUse = writeUses.getFirst();
+ auto readUse = readUses.getCount() ? readUses.getFirst() : writeUse;
+ if (dom->dominates(writeUse->getUser(), callUse->getUser()) &&
+ dom->dominates(callUse->getUser(), readUse->getUser()))
+ {
+ continue;
+ }
+ }
+
+ // Create a temp var for the address and replace all uses of the address to the temp
+ // var.
+ createTempVarForAddr(addr->addrInst);
+ }
+ else
+ {
+ // This is a dynamic address. We can only allow at most one write access to it.
+ bool hasNonTrivialAccess = false;
+ if (readUses.getCount() + callUses.getCount() != 0 &&
+ writeUses.getCount() + callUses.getCount() > 1)
+ hasNonTrivialAccess = true;
+
+ if (hasNonTrivialAccess)
+ {
+ // Mixed use of a non-constant address is unsupported right now.
+ sink->diagnose(
+ addr->addrInst,
+ Diagnostics::cannotDifferentiateDynamicallyIndexedData,
+ getAddrName(addr->addrInst));
+ }
+ }
+ if (addr->parentAddress && workListSet.Add(addr->parentAddress))
+ workList.add(addr->parentAddress);
+ }
+
+ if (hasError)
+ return SLANG_FAIL;
+
+ // Actually replace addresses with temp vars.
+ for (auto addr : workList)
+ {
+ IRInst* tempVar = nullptr;
+ if (!mapAddrInstToTempVar.TryGetValue(addr->addrInst, tempVar))
+ continue;
+ for (auto use = addr->addrInst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ auto user = use->getUser();
+
+ builder.setInsertBefore(user);
+ switch (user->getOp())
+ {
+ case kIROp_Load:
+ use->set(tempVar);
+ break;
+ case kIROp_Store:
+ use->set(tempVar);
+ updateChildTempVarRecursive(
+ builder, addr, as<IRStore>(user)->getVal());
+ updateParentTempVarRecursive(builder, addr);
+ case kIROp_Call:
+ {
+ use->set(tempVar);
+ builder.setInsertAfter(user);
+ auto newVal = builder.emitLoad(tempVar);
+ updateChildTempVarRecursive(builder, addr, newVal);
+ updateParentTempVarRecursive(builder, addr);
+ }
+ break;
+ default:
+ use->set(tempVar);
+ break;
+ }
+ use = nextUse;
+ }
+ }
+
+ // Assign initial values to tempVar.
+ for (auto tempVar : mapAddrInstToTempVar)
+ {
+ builder.setInsertAfter(tempVar.Value);
+ IRInst* initVal = nullptr;
+ if (tempVar.Key->getOp() == kIROp_Var ||
+ tempVar.Key->getOp() == kIROp_Param && as<IROutType>(tempVar.Key->getFullType()))
+ {
+ initVal = builder.emitDefaultConstruct(
+ cast<IRPtrTypeBase>(tempVar.Key->getFullType())->getValueType());
+ }
+ else
+ {
+ initVal = builder.emitLoad(tempVar.Key);
+ }
+ builder.emitStore(tempVar.Value, initVal);
+ }
+
+ // Store final values to out parameters before exiting function.
+ IRInst* returnInst = nullptr;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ returnInst = inst;
+ break;
+ }
+ }
+ }
+ SLANG_RELEASE_ASSERT(returnInst);
+ builder.setInsertBefore(returnInst);
+ for (auto param : func->getParams())
+ {
+ IRInst* tempVar = nullptr;
+ if (mapAddrInstToTempVar.TryGetValue(param, tempVar))
+ {
+ auto val = builder.emitLoad(tempVar);
+ builder.emitStore(param, val);
+ }
+ }
+ if (hasError)
+ return SLANG_FAIL;
+ return SLANG_OK;
+ }
+};
+
+SlangResult eliminateAddressInsts(
+ SharedIRBuilder* sharedBuilder,
+ DifferentiableTypeConformanceContext& diffContext,
+ IRFunc* func,
+ DiagnosticSink* sink)
+{
+ AddressInstEliminationContext ctx;
+ return ctx.eliminateAddressInstsImpl(sharedBuilder, diffContext, func, sink);
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index ee159b80b..3f3618b44 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1130,6 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_VectorReshape:
case kIROp_IntCast:
case kIROp_FloatCast:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
return transcribeConstruct(builder, origInst);
case kIROp_LookupWitness:
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index d3a6137c1..779a4f1a3 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -5,10 +5,9 @@
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
-
+#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
-
namespace Slang
{
IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType)
@@ -502,6 +501,17 @@ namespace Slang
stripDerivativeDecorations(primalFunc);
eliminateDeadCode(primalOuterParent);
+ // Perform preparation and simplification.
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ if (SLANG_FAILED(eliminateAddressInsts(
+ builder->getSharedBuilder(),
+ differentiableTypeConformanceContext,
+ primalFunc,
+ sink)))
+ return nullptr;
+
+ simplifyFunc(primalFunc);
+
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
autoDiffSharedContext->transcriberSet.forwardTranscriber);
@@ -567,7 +577,9 @@ namespace Slang
}
auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc);
-
+ if (!fwdDiffFunc)
+ return;
+
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index ae1a5dd70..e799456bb 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1061,6 +1061,10 @@ struct DiffTransposePass
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
+ case kIROp_MakeStruct:
+ return transposeMakeStruct(builder, fwdInst, revValue);
+ case kIROp_MakeArray:
+ return transposeMakeArray(builder, fwdInst, revValue);
case kIROp_Specialize:
case kIROp_unconditionalBranch:
@@ -1218,6 +1222,55 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+ TranspositionResult transposeMakeStruct(IRBuilder* builder, IRInst* fwdMakeStruct, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto structType = cast<IRStructType>(fwdMakeStruct->getFullType());
+ UInt ii = 0;
+ for (auto field : structType->getFields())
+ {
+ auto gradAtField = builder->emitFieldExtract(
+ field->getFieldType(),
+ revValue,
+ field->getKey());
+ SLANG_RELEASE_ASSERT(ii < fwdMakeStruct->getOperandCount());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeStruct->getOperand(ii),
+ gradAtField,
+ fwdMakeStruct));
+ ii++;
+ }
+
+ // (A = MakeStruct(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeArray(IRBuilder* builder, IRInst* fwdMakeArray, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto arrayType = cast<IRArrayType>(fwdMakeArray->getFullType());
+ auto arraySize = cast<IRIntLit>(arrayType->getElementCount());
+
+ for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++)
+ {
+ auto gradAtField = builder->emitElementExtract(
+ arrayType->getElementType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+ SLANG_RELEASE_ASSERT(ii < fwdMakeArray->getOperandCount());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeArray->getOperand(ii),
+ gradAtField,
+ fwdMakeArray));
+ ii++;
+ }
+
+ // (A = MakeArray(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
+ return TranspositionResult(gradients);
+ }
+
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 363006f58..4d33d3743 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1,7 +1,10 @@
#include "slang-ir-autodiff.h"
+#include "slang-ir-address-analysis.h"
#include "slang-ir-autodiff-rev.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-pairs.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-ssa-simplification.h"
#include "slang-ir-validate.h"
namespace Slang
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 7479e4eee..cb767c20a 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -260,4 +260,11 @@ bool finalizeAutoDiffPass(IRModule* module);
void stripDerivativeDecorations(IRInst* inst);
bool isBackwardDifferentiableFunc(IRInst* func);
+
+SlangResult eliminateAddressInsts(
+ SharedIRBuilder* sharedBuilder,
+ DifferentiableTypeConformanceContext& diffContext,
+ IRFunc* func,
+ DiagnosticSink* sink);
+
};
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index cb7290036..67b7e92b0 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -5,12 +5,45 @@
namespace Slang
{
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+{
+ HashSet<IRInst*> processedSet;
+ while (auto ptrType = as<IRPtrTypeBase>(typeInst))
+ {
+ typeInst = ptrType->getValueType();
+ if (!processedSet.Add(typeInst))
+ return false;
+ }
+ if (!typeInst)
+ return false;
+ switch (typeInst->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_DifferentialPairType:
+ return true;
+ default:
+ break;
+ }
+ if (context.lookUpConformanceForType(typeInst))
+ return true;
+ // Look for equivalent types.
+ for (auto type : context.differentiableWitnessDictionary)
+ {
+ if (isTypeEqual(type.Key, (IRType*)typeInst))
+ {
+ context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
+ return true;
+ }
+ }
+ return false;
+}
struct CheckDifferentiabilityPassContext : public InstPassBase
{
public:
DiagnosticSink* sink;
AutoDiffSharedContext sharedContext;
+ SharedIRBuilder* sharedBuilder;
enum DifferentiableLevel
{
@@ -18,8 +51,8 @@ public:
};
Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
- CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
- : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
+ CheckDifferentiabilityPassContext(SharedIRBuilder* inSharedBuilder, IRModule* inModule, DiagnosticSink* inSink)
+ : InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst())
{}
IRInst* getSpecializedVal(IRInst* inst)
@@ -161,39 +194,6 @@ public:
return false;
}
- bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
- {
- HashSet<IRInst*> processedSet;
- while (auto ptrType = as<IRPtrTypeBase>(typeInst))
- {
- typeInst = ptrType->getValueType();
- if (!processedSet.Add(typeInst))
- return false;
- }
- if (!typeInst)
- return false;
- switch (typeInst->getOp())
- {
- case kIROp_FloatType:
- case kIROp_DifferentialPairType:
- return true;
- default:
- break;
- }
- if (context.lookUpConformanceForType(typeInst))
- return true;
- // Look for equivalent types.
- for (auto type : context.differentiableWitnessDictionary)
- {
- if (isTypeEqual(type.Key, (IRType*)typeInst))
- {
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
- return true;
- }
- }
- return false;
- }
-
int getParamIndexInBlock(IRParam* paramInst)
{
auto block = as<IRBlock>(paramInst->getParent());
@@ -228,6 +228,14 @@ public:
DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
diffTypeContext.setFunc(funcInst);
+ if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ {
+ if (auto func = as<IRFunc>(funcInst))
+ {
+ if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink)))
+ return;
+ }
+ }
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
@@ -468,9 +476,9 @@ public:
}
};
-void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink)
+void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink)
{
- CheckDifferentiabilityPassContext context(module, sink);
+ CheckDifferentiabilityPassContext context(sharedBuilder, module, sink);
context.processModule();
}
diff --git a/source/slang/slang-ir-check-differentiability.h b/source/slang/slang-ir-check-differentiability.h
index 735a918c9..16ae16b6f 100644
--- a/source/slang/slang-ir-check-differentiability.h
+++ b/source/slang/slang-ir-check-differentiability.h
@@ -7,8 +7,9 @@ namespace Slang
{
struct IRModule;
class DiagnosticSink;
+struct SharedIRBuilder;
// Check all auto diff usages are valid.
-void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink);
+void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink);
} // namespace Slang
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp
index 72b156228..1ffa7ba5d 100644
--- a/source/slang/slang-ir-dominators.cpp
+++ b/source/slang/slang-ir-dominators.cpp
@@ -86,6 +86,31 @@ bool IRDominatorTree::dominates(IRBlock* dominator, IRBlock* dominated)
return properlyDominates(dominator, dominated);
}
+bool IRDominatorTree::dominates(IRInst* dominator, IRInst* dominated)
+{
+ auto dominatorBlock = as<IRBlock>(dominator);
+ if (!dominatorBlock)
+ dominatorBlock = as<IRBlock>(dominator->getParent());
+
+ auto dominatedBlock = as<IRBlock>(dominated);
+ if (!dominatedBlock)
+ dominatedBlock = as<IRBlock>(dominated->getParent());
+
+ if (dominatorBlock == dominatedBlock)
+ {
+ for (auto inst = dominator; inst; inst = inst->getNextInst())
+ {
+ if (inst == dominated)
+ return true;
+ }
+ return false;
+ }
+ else
+ {
+ return dominates(dominatorBlock, dominatedBlock);
+ }
+}
+
IRBlock* IRDominatorTree::getImmediateDominator(IRBlock* block)
{
// An unreachable block has no immediate dominator.
diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h
index be01830b0..1fb12c89e 100644
--- a/source/slang/slang-ir-dominators.h
+++ b/source/slang/slang-ir-dominators.h
@@ -7,6 +7,7 @@ namespace Slang
{
struct IRBlock;
struct IRGlobalValueWithCode;
+ struct IRInst;
/// The computed dominator tree for an IR control flow graph.
struct IRDominatorTree : public RefObject
@@ -22,6 +23,8 @@ namespace Slang
///
bool dominates(IRBlock* dominator, IRBlock* dominated);
+ bool dominates(IRInst* dominator, IRInst* dominated);
+
/// Does the first block properly dominate the second?
///
/// Block A properly dominates block B iff A dominates B
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 68afbbb95..134a45bf5 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -305,7 +305,6 @@ INST(GetOptionalValue, getOptionalValue, 1, 0)
INST(OptionalHasValue, optionalHasValue, 1, 0)
INST(MakeOptionalValue, makeOptionalValue, 1, 0)
INST(MakeOptionalNone, makeOptionalNone, 1, 0)
-INST(DifferentialBottomValue, differentialBottomVal, 0, 0)
INST(Call, call, 1, 0)
INST(RTTIObject, rtti_object, 0, 0)
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
new file mode 100644
index 000000000..a57bfce3e
--- /dev/null
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -0,0 +1,125 @@
+#include "slang-ir-redundancy-removal.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+struct RedundancyRemovalContext
+{
+ RefPtr<IRDominatorTree> dom;
+ bool removeRedundancyInBlock(DeduplicateContext& deduplicateContext, IRBlock* block)
+ {
+ bool result = false;
+ for (auto instP : block->getChildren())
+ {
+ auto resultInst = deduplicateContext.deduplicate(instP, [&](IRInst* inst)
+ {
+ auto parentBlock = as<IRBlock>(inst->getParent());
+ if (!parentBlock)
+ return false;
+ if (dom->isUnreachable(parentBlock))
+ return false;
+
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Module:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ case kIROp_GetElement:
+ case kIROp_GetElementPtr:
+ case kIROp_LookupWitness:
+ case kIROp_Specialize:
+ case kIROp_OptionalHasValue:
+ case kIROp_GetOptionalValue:
+ case kIROp_MakeOptionalValue:
+ case kIROp_MakeTuple:
+ case kIROp_GetTupleElement:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_swizzle:
+ case kIROp_MatrixReshape:
+ case kIROp_MakeString:
+ case kIROp_MakeResultError:
+ case kIROp_MakeResultValue:
+ case kIROp_GetResultError:
+ case kIROp_GetResultValue:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastIntToPtr:
+ case kIROp_CastPtrToBool:
+ case kIROp_CastPtrToInt:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitCast:
+ case kIROp_Reinterpret:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ case kIROp_Neq:
+ case kIROp_Eql:
+ return true;
+ default:
+ return false;
+ }
+ });
+ if (resultInst != instP)
+ result = true;
+ }
+ for (auto child : dom->getImmediatelyDominatedBlocks(block))
+ {
+ DeduplicateContext subContext;
+ subContext.deduplicateMap = deduplicateContext.deduplicateMap;
+ result |= removeRedundancyInBlock(subContext, child);
+ }
+ return result;
+ }
+};
+
+bool removeRedundancy(IRModule* module)
+{
+ bool changed = false;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ removeRedundancyInFunc(genericInst);
+ inst = findGenericReturnVal(genericInst);
+ }
+ if (auto func = as<IRFunc>(inst))
+ {
+ changed |= removeRedundancyInFunc(func);
+ }
+ }
+ return changed;
+}
+
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
+{
+ auto root = func->getFirstBlock();
+ if (!root)
+ return false;
+
+ RedundancyRemovalContext context;
+ context.dom = computeDominatorTree(func);
+ DeduplicateContext deduplicateCtx;
+ return context.removeRedundancyInBlock(deduplicateCtx, root);
+}
+
+}
diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h
new file mode 100644
index 000000000..26b265e77
--- /dev/null
+++ b/source/slang/slang-ir-redundancy-removal.h
@@ -0,0 +1,11 @@
+// slang-ir-redundancy-removal.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRGlobalValueWithCode;
+
+ bool removeRedundancy(IRModule* module);
+ bool removeRedundancyInFunc(IRGlobalValueWithCode* func);
+}
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 1e247d1d9..54a1f7e08 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -196,6 +196,10 @@ bool simplifyCFG(IRModule* module)
bool changed = false;
for (auto inst : module->getGlobalInsts())
{
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ inst = findGenericReturnVal(genericInst);
+ }
if (auto func = as<IRFunc>(inst))
{
changed |= processFunc(func);
diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp
index f76e35040..30e933133 100644
--- a/source/slang/slang-ir-single-return.cpp
+++ b/source/slang/slang-ir-single-return.cpp
@@ -91,4 +91,20 @@ void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* fu
context.processFunc(func);
}
+bool isSingleReturnFunc(IRGlobalValueWithCode* func)
+{
+ int returnCount = 0;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ returnCount++;
+ }
+ }
+ }
+ return returnCount <= 1;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h
index 2ddfa280b..bb186634d 100644
--- a/source/slang/slang-ir-single-return.h
+++ b/source/slang/slang-ir-single-return.h
@@ -9,4 +9,5 @@ namespace Slang
// Convert the CFG of `func` to have only a single `return` at the end.
void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func);
+ bool isSingleReturnFunc(IRGlobalValueWithCode* func);
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index fd5f41f49..f06fafcb3 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -9,6 +9,7 @@
#include "slang-ir-hoist-constants.h"
#include "slang-ir-deduplicate-generic-children.h"
#include "slang-ir-remove-unused-generic-param.h"
+#include "slang-ir-redundancy-removal.h"
namespace Slang
{
@@ -26,6 +27,7 @@ namespace Slang
changed |= deduplicateGenericChildren(module);
changed |= applySparseConditionalConstantPropagation(module);
changed |= peepholeOptimize(module);
+ changed |= removeRedundancy(module);
changed |= simplifyCFG(module);
// Note: we disregard the `changed` state from dead code elimination pass since
@@ -49,6 +51,7 @@ namespace Slang
changed = false;
changed |= applySparseConditionalConstantPropagation(func);
changed |= peepholeOptimize(func);
+ changed |= removeRedundancyInFunc(func);
changed |= simplifyCFG(func);
// Note: we disregard the `changed` state from dead code elimination pass since
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 881f041c0..319a23989 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -219,12 +219,20 @@ void moveInstChildren(IRInst* dest, IRInst* src)
}
}
+String dumpIRToString(IRInst* root)
+{
+ StringBuilder sb;
+ StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
+ dumpIR(root, IRDumpOptions(), nullptr, &writer);
+ return sb.ToString();
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
IRGeneric* srcGeneric;
IRGeneric* dstGeneric;
- Dictionary<IRInstKey, IRInst*> deduplicateMap;
+ DeduplicateContext deduplicateContext;
void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore)
{
@@ -251,42 +259,34 @@ struct GenericChildrenMigrationContextImpl
inst = inst->getNextInst())
{
IRInstKey key = { inst };
- deduplicateMap.AddIfNotExists(key, inst);
+ deduplicateContext.deduplicateMap.AddIfNotExists(key, inst);
}
}
}
IRInst* deduplicate(IRInst* value)
{
- if (!value) return nullptr;
- if (value->getParent() != dstGeneric->getFirstBlock())
- return value;
- switch (value->getOp())
- {
- case kIROp_Param:
- case kIROp_StructType:
- case kIROp_StructKey:
- case kIROp_InterfaceType:
- case kIROp_ClassType:
- case kIROp_Func:
- case kIROp_Generic:
- return value;
- default:
- break;
- }
- if (as<IRConstant>(value))
- return value;
-
- for (UInt i = 0; i < value->getOperandCount(); i++)
- {
- value->setOperand(i, deduplicate(value->getOperand(i)));
- }
- value->setFullType((IRType*)deduplicate(value->getFullType()));
- IRInstKey key = { value };
- if (auto newValue = deduplicateMap.TryGetValue(key))
- return *newValue;
- deduplicateMap[key] = value;
- return value;
+ return deduplicateContext.deduplicate(value, [this](IRInst* inst)
+ {
+ if (inst->getParent() != dstGeneric->getFirstBlock())
+ return false;
+ switch (inst->getOp())
+ {
+ case kIROp_Param:
+ case kIROp_StructType:
+ case kIROp_StructKey:
+ case kIROp_InterfaceType:
+ case kIROp_ClassType:
+ case kIROp_Func:
+ case kIROp_Generic:
+ return false;
+ default:
+ break;
+ }
+ if (as<IRConstant>(inst))
+ return false;
+ return true;
+ });
}
IRInst* cloneInst(IRBuilder* builder, IRInst* src)
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 92446138f..a250fc6a6 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -5,6 +5,7 @@
// This file contains utility functions for operating with Slang IR.
//
#include "slang-ir.h"
+#include "slang-ir-insts.h"
namespace Slang
{
@@ -32,6 +33,32 @@ public:
IRInst* cloneInst(IRBuilder* builder, IRInst* src);
};
+
+struct DeduplicateContext
+{
+ Dictionary<IRInstKey, IRInst*> deduplicateMap;
+
+ template<typename TFunc>
+ IRInst* deduplicate(IRInst* value, const TFunc& shouldDeduplicate)
+ {
+ if (!value) return nullptr;
+ if (!shouldDeduplicate(value))
+ return value;
+ IRInstKey key = { value };
+ if (auto newValue = deduplicateMap.TryGetValue(key))
+ return *newValue;
+ for (UInt i = 0; i < value->getOperandCount(); i++)
+ {
+ value->setOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate));
+ }
+ value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate));
+ if (auto newValue = deduplicateMap.TryGetValue(key))
+ return *newValue;
+ deduplicateMap[key] = value;
+ return value;
+ }
+};
+
bool isPtrToClassType(IRInst* type);
bool isPtrToArrayType(IRInst* type);
@@ -126,6 +153,8 @@ inline IRInst* unwrapAttributedType(IRInst* type)
return type;
}
+String dumpIRToString(IRInst* root);
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e400d0a17..b79221900 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1982,7 +1982,6 @@ namespace Slang
return getStringSlice() == rhs->getStringSlice();
}
case kIROp_VoidLit:
- case kIROp_DifferentialBottomValue:
{
return true;
}
@@ -2025,7 +2024,6 @@ namespace Slang
return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength()));
}
case kIROp_VoidLit:
- case kIROp_DifferentialBottomValue:
{
return code;
}
@@ -2110,14 +2108,6 @@ namespace Slang
irValue->value.ptrVal = keyInst.value.ptrVal;
break;
}
- case kIROp_DifferentialBottomValue:
- {
- const size_t instSize = prefixSize + sizeof(void*);
- irValue = static_cast<IRConstant*>(
- _createInst(instSize, keyInst.getFullType(), keyInst.getOp()));
- irValue->value.ptrVal = nullptr;
- break;
- }
case kIROp_StringLit:
{
const UnownedStringSlice slice = keyInst.getStringSlice();
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ec51c7bfa..d0527eef8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9157,7 +9157,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
checkForMissingReturns(module, compileRequest->getSink());
// Check for invalid differentiable function body.
- checkAutoDiffUsages(module, compileRequest->getSink());
+ checkAutoDiffUsages(sharedBuilder, module, compileRequest->getSink());
// The "mandatory" optimization passes may make use of the
// `IRHighLevelDeclDecoration` type to relate IR instructions