summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-23 06:59:25 -0800
committerGitHub <noreply@github.com>2023-01-23 06:59:25 -0800
commit46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch)
treec89f3a1c416330f859887d00f896b18bcc7488a5
parent263ca18ea516cfce43fda703c0a411aaf1938e42 (diff)
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--build/visual-studio/slang/slang.vcxproj5
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters15
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-address-analysis.cpp173
-rw-r--r--source/slang/slang-ir-address-analysis.h24
-rw-r--r--source/slang/slang-ir-autodiff-addr-inst-elimination.cpp476
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp18
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h53
-rw-r--r--source/slang/slang-ir-autodiff.cpp3
-rw-r--r--source/slang/slang-ir-autodiff.h7
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp82
-rw-r--r--source/slang/slang-ir-check-differentiability.h3
-rw-r--r--source/slang/slang-ir-dominators.cpp25
-rw-r--r--source/slang/slang-ir-dominators.h3
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp125
-rw-r--r--source/slang/slang-ir-redundancy-removal.h11
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp4
-rw-r--r--source/slang/slang-ir-single-return.cpp16
-rw-r--r--source/slang/slang-ir-single-return.h1
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp3
-rw-r--r--source/slang/slang-ir-util.cpp62
-rw-r--r--source/slang/slang-ir-util.h29
-rw-r--r--source/slang/slang-ir.cpp10
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
-rw-r--r--tests/autodiff/reverse-struct-multi-write.slang48
-rw-r--r--tests/autodiff/reverse-struct-multi-write.slang.expected.txt6
-rw-r--r--tests/compute/half-texture.slang.glsl15
-rw-r--r--tests/compute/half-texture.slang.hlsl22
-rw-r--r--tests/cross-compile/precise-keyword.slang.glsl9
-rw-r--r--tests/cross-compile/precise-keyword.slang.hlsl8
-rw-r--r--tests/experimental/liveness/liveness-6.slang.expected33
-rw-r--r--tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected27
-rw-r--r--tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected33
-rw-r--r--tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected15
-rw-r--r--tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected15
-rw-r--r--tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl8
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl1
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl6
40 files changed, 1230 insertions, 171 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index b50712cb1..da62d25f3 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -340,6 +340,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-hlsl-intrinsic-set.h" />
<ClInclude Include="..\..\..\source\slang\slang-image-format-defs.h" />
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-address-analysis.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
@@ -404,6 +405,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-missing-return.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure-scoping.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure.h" />
@@ -521,8 +523,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-glsl-extension-tracker.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-hlsl-intrinsic-set.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-address-analysis.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-addr-inst-elimination.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
@@ -582,6 +586,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-missing-return.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure-scoping.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 748654b98..4c61f48d9 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -126,6 +126,9 @@
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-address-analysis.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -318,6 +321,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -665,12 +671,18 @@
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-address-analysis.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-addr-inst-elimination.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -848,6 +860,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 4820c430f..0d4088d75 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -581,6 +581,8 @@ DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable fun
DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal")
+DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.")
+DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.")
//
// 5xxxx - Target code generation.
//
diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp
new file mode 100644
index 000000000..aba59e1de
--- /dev/null
+++ b/source/slang/slang-ir-address-analysis.cpp
@@ -0,0 +1,173 @@
+#include "slang-ir-address-analysis.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+ void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst)
+ {
+ if (!as<IRBlock>(inst->getParent()))
+ return;
+ if (domTree->isUnreachable(as<IRBlock>(inst->getParent())))
+ return;
+
+ List<IRBlock*> blocks;
+ HashSet<IRInst*> operandInsts;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ operandInsts.Add(inst->getOperand(i));
+ auto parentBlock = as<IRBlock>(inst->getOperand(i)->getParent());
+ if (parentBlock)
+ {
+ if (!domTree->isUnreachable(parentBlock))
+ blocks.add(parentBlock);
+ }
+ }
+ {
+ operandInsts.Add(inst->getFullType());
+ auto parentBlock = as<IRBlock>(inst->getFullType()->getParent());
+ if (parentBlock)
+ {
+ if (!domTree->isUnreachable(parentBlock))
+ blocks.add(parentBlock);
+ }
+ }
+ // Find earliest block that is dominated by all operand blocks.
+ IRBlock* earliestBlock = as<IRBlock>(inst->getParent());
+ for (auto block : func->getBlocks())
+ {
+ bool dominated = true;
+ for (auto opBlock : blocks)
+ {
+ if (!domTree->dominates(opBlock, block))
+ {
+ dominated = false;
+ break;
+ }
+ }
+ if (dominated)
+ {
+ earliestBlock = block;
+ break;
+ }
+ }
+
+ if (!earliestBlock)
+ return;
+
+ IRInst* latestOperand = nullptr;
+ for (auto childInst : earliestBlock->getChildren())
+ {
+ if (operandInsts.Contains(childInst))
+ {
+ latestOperand = childInst;
+ }
+ }
+
+ if (!latestOperand || as<IRParam>(latestOperand))
+ inst->insertBefore(earliestBlock->getFirstOrdinaryInst());
+ else
+ inst->insertAfter(latestOperand);
+ }
+
+ AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func)
+ {
+ DeduplicateContext deduplicateContext;
+
+ AddressAccessInfo info;
+
+ // Deduplicate and move known address insts.
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstChild(); inst;)
+ {
+ auto next = inst->getNextInst();
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ addrInfo->isConstant = true;
+ addrInfo->parentAddress = nullptr;
+ info.addressInfos[inst] = addrInfo;
+ }
+ break;
+ case kIROp_Param:
+ if (as<IRPtrTypeBase>(inst->getFullType()))
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ addrInfo->isConstant = (block == func->getFirstBlock() ? true : false);
+ addrInfo->parentAddress = nullptr;
+ info.addressInfos[inst] = addrInfo;
+ }
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ moveInstToEarliestPoint(dom, func, inst->getFullType());
+ moveInstToEarliestPoint(dom, func, inst);
+ auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* inst)
+ {
+ if (!inst->getParent())
+ return false;
+ if (inst->getParent()->getParent() != func)
+ return false;
+ switch (inst->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ return true;
+ default:
+ return false;
+ }
+ });
+
+ if (deduplicated != inst)
+ {
+ SLANG_RELEASE_ASSERT(dom->dominates(
+ as<IRBlock>(deduplicated->getParent()),
+ as<IRBlock>(inst->getParent())));
+
+ inst->replaceUsesWith(deduplicated);
+ inst->removeAndDeallocate();
+ }
+ else
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ if (inst->getOp() == kIROp_FieldAddress)
+ {
+ addrInfo->isConstant = true;
+ }
+ else
+ {
+ addrInfo->isConstant =
+ as<IRConstant>(inst->getOperand(1)) ? true : false;
+ }
+ info.addressInfos[inst] = addrInfo;
+ }
+ }
+ break;
+ }
+ inst = next;
+ }
+ }
+
+ // Construct address info tree.
+ for (auto& addr : info.addressInfos)
+ {
+ RefPtr<AddressInfo> parentInfo;
+ if (addr.Value->addrInst->getOperandCount() > 1 &&
+ info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo))
+ {
+ addr.Value->parentAddress = parentInfo;
+ parentInfo->children.add(addr.Value);
+ if (!parentInfo->isConstant)
+ addr.Value->isConstant = false;
+ }
+ }
+ return info;
+ }
+}
diff --git a/source/slang/slang-ir-address-analysis.h b/source/slang/slang-ir-address-analysis.h
new file mode 100644
index 000000000..450e8b9eb
--- /dev/null
+++ b/source/slang/slang-ir-address-analysis.h
@@ -0,0 +1,24 @@
+// slang-ir-address-analysis.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-dominators.h"
+
+namespace Slang
+{
+ struct AddressInfo : public RefObject
+ {
+ IRInst* addrInst = nullptr;
+ AddressInfo* parentAddress = nullptr;
+ bool isConstant = false;
+ List<AddressInfo*> children;
+ };
+
+ struct AddressAccessInfo
+ {
+ OrderedDictionary<IRInst*, RefPtr<AddressInfo>> addressInfos;
+ };
+
+ // Gather info on all addresses used by `func`.
+ AddressAccessInfo analyzeAddressUse(IRDominatorTree* domTree, IRGlobalValueWithCode* func);
+}
diff --git a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp
new file mode 100644
index 000000000..c60995595
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp
@@ -0,0 +1,476 @@
+#include "slang-ir-address-analysis.h"
+#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-pairs.h"
+#include "slang-ir-autodiff-rev.h"
+#include "slang-ir-autodiff.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-validate.h"
+
+namespace Slang
+{
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
+
+struct AddressInstEliminationContext
+{
+ OrderedDictionary<IRInst*, IRInst*> mapAddrInstToTempVar;
+
+ IRInst* _reconstructStruct(
+ IRBuilder& builder, IRStructType* type, IRInst* tempVar, List<AddressInfo*>& childAddrs)
+ {
+ List<IRInst*> args;
+ IRInst* loadedTempVar = nullptr;
+ for (auto child : type->getChildren())
+ {
+ if (auto field = as<IRStructField>(child))
+ {
+ IRInst* childVar = nullptr;
+ for (auto subAddr : childAddrs)
+ {
+ auto fieldAddrInst = cast<IRFieldAddress>(subAddr->addrInst);
+ if (fieldAddrInst->getField() == field->getKey())
+ {
+ mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar);
+ break;
+ }
+ }
+ if (childVar)
+ {
+ args.add(builder.emitLoad(childVar));
+ }
+ else
+ {
+ if (!loadedTempVar)
+ loadedTempVar = builder.emitLoad(tempVar);
+ args.add(builder.emitFieldExtract(
+ field->getFieldType(), loadedTempVar, field->getKey()));
+ }
+ }
+ }
+ return builder.emitMakeStruct(type, args);
+ }
+
+ IRInst* _reconstructArray(
+ IRBuilder& builder,
+ IRArrayType* type,
+ IRIntegerValue arraySize,
+ IRInst* tempVar,
+ List<AddressInfo*>& childAddrs)
+ {
+ IRInst* loadedTempVar = nullptr;
+ List<IRInst*> args;
+ for (IRIntegerValue index = 0; index < arraySize; index++)
+ {
+ IRInst* childVar = nullptr;
+ for (auto subAddr : childAddrs)
+ {
+ auto elementPtrInst = cast<IRGetElementPtr>(subAddr->addrInst);
+ auto elementIndex = as<IRIntLit>(elementPtrInst->getIndex());
+ if (elementIndex && elementIndex->getValue() == index)
+ {
+ mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar);
+ break;
+ }
+ }
+ if (childVar)
+ {
+ args.add(builder.emitLoad(childVar));
+ }
+ else
+ {
+ if (!loadedTempVar)
+ loadedTempVar = builder.emitLoad(tempVar);
+ args.add(builder.emitElementExtract(
+ type->getElementType(),
+ loadedTempVar,
+ builder.getIntValue(builder.getIntType(), index)));
+ }
+ }
+ return builder.emitMakeArray(type, args.getCount(), args.getBuffer());
+ }
+
+ void updateChildTempVarRecursive(
+ IRBuilder& builder,
+ AddressInfo* addr,
+ IRInst* val)
+ {
+ for (auto child : addr->children)
+ {
+ IRInst* childVar = nullptr;
+ if (mapAddrInstToTempVar.TryGetValue(child->addrInst, childVar))
+ {
+ switch (child->addrInst->getOp())
+ {
+ case kIROp_FieldAddress:
+ {
+ auto subVal = builder.emitFieldExtract(
+ cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(),
+ val,
+ child->addrInst->getOperand(1));
+ builder.emitStore(childVar, subVal);
+ updateChildTempVarRecursive(builder, child, subVal);
+ }
+ break;
+ case kIROp_GetElementPtr:
+ {
+ auto subVal = builder.emitElementExtract(
+ cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(),
+ val,
+ child->addrInst->getOperand(1));
+ builder.emitStore(childVar, subVal);
+ updateChildTempVarRecursive(builder, child, subVal);
+ }
+ break;
+ default:
+ {
+ }
+ break;
+ }
+ }
+ }
+ }
+
+ IRInst* getLoadedValue(
+ IRBuilder& builder,
+ AddressInfo* addr,
+ IRInst* tempVar)
+ {
+ if (addr->children.getCount())
+ {
+ // Reconstruct val.
+ auto type =
+ cast<IRPtrTypeBase>(unwrapAttributedType(tempVar->getFullType()))->getValueType();
+ switch (type->getOp())
+ {
+ case kIROp_StructType:
+ return _reconstructStruct(
+ builder, as<IRStructType>(type), tempVar, addr->children);
+ case kIROp_ArrayType:
+ {
+ auto arrayType = as<IRArrayType>(type);
+ auto size = as<IRIntLit>(arrayType->getElementCount());
+ if (!size || size->getValue() < 0)
+ {
+ // Unsupported array type.
+ }
+ else
+ {
+ return _reconstructArray(
+ builder,
+ arrayType,
+ size->getValue(),
+ tempVar,
+ addr->children);
+ }
+ }
+ break;
+ default:
+ // Unsupported address type.
+ break;
+ }
+ }
+ return builder.emitLoad(tempVar);
+ };
+
+ void updateParentTempVarRecursive(
+ IRBuilder& builder,
+ AddressInfo* addr)
+ {
+ for (auto parent = addr->parentAddress; parent; parent = parent->parentAddress)
+ {
+ IRInst* parentVar = nullptr;
+ if (mapAddrInstToTempVar.TryGetValue(parent->addrInst, parentVar))
+ {
+ auto val = getLoadedValue(builder, parent, parentVar);
+ builder.emitStore(parentVar, val);
+ }
+ }
+ }
+
+ String getAddrName(IRInst* addrInst)
+ {
+ StringBuilder sb;
+ List<IRInst*> bases;
+ bases.add(addrInst);
+ for (; addrInst;)
+ {
+ if (auto fieldAddr = as<IRFieldAddress>(addrInst))
+ bases.add(fieldAddr->getBase());
+ else if (auto index = as<IRGetElementPtr>(addrInst))
+ bases.add(index->getBase());
+ else
+ break;
+ }
+ for (Index i = bases.getCount() - 1; i >= 0; i--)
+ {
+ if (bases[i]->getOp() == kIROp_FieldAddress)
+ {
+ sb << ".";
+ auto field = bases[i]->getOperand(1);
+ auto nameDecor = field->findDecoration<IRNameHintDecoration>();
+ sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>"));
+ }
+ else if (bases[i]->getOp() == kIROp_FieldAddress)
+ {
+ sb << "[";
+ auto index = bases[i]->getOperand(1);
+ auto nameDecor = index->findDecoration<IRNameHintDecoration>();
+ if (nameDecor)
+ {
+ sb << nameDecor->getName();
+ }
+ else if (auto intLit = as<IRIntLit>(index))
+ {
+ sb << intLit->getValue();
+ }
+ else
+ {
+ sb << "...";
+ }
+ sb << "]";
+ }
+ else
+ {
+ auto nameDecor = bases[i]->findDecoration<IRNameHintDecoration>();
+ sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>"));
+ }
+ }
+ return sb.ProduceString();
+ }
+
+ SlangResult eliminateAddressInstsImpl(
+ SharedIRBuilder* sharedBuilder,
+ DifferentiableTypeConformanceContext& diffContext,
+ IRFunc* func,
+ DiagnosticSink* sink)
+ {
+ bool hasError = false;
+
+ if (!isSingleReturnFunc(func))
+ {
+ convertFuncToSingleReturnForm(func->getModule(), func);
+ }
+
+ IRBuilder builder(sharedBuilder);
+
+ auto dom = computeDominatorTree(func);
+ auto addrUse = analyzeAddressUse(dom, func);
+ List<AddressInfo*> workList;
+ HashSet<AddressInfo*> workListSet;
+
+ // Process leaf addresses first.
+ for (auto addr : addrUse.addressInfos)
+ {
+ if (addr.Value->children.getCount() == 0)
+ workList.add(addr.Value);
+ }
+
+ auto createTempVarForAddr = [&](IRInst* addrInst)
+ {
+ if (as<IRParam>(addrInst))
+ builder.setInsertAfter(as<IRBlock>(addrInst->getParent())->getLastParam());
+ else
+ builder.setInsertAfter(addrInst);
+ auto ptrType = as<IRPtrTypeBase>(addrInst->getFullType());
+ SLANG_RELEASE_ASSERT(ptrType);
+ auto tempVar = builder.emitVar(ptrType->getValueType());
+ mapAddrInstToTempVar[addrInst] = tempVar;
+ };
+
+ // In the first pass, we create temp vars for addresses with non-trivial access pattern.
+ for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++)
+ {
+ auto addr = workList[workListIndex];
+
+ if (!isDifferentiableType(diffContext, addr->addrInst->getDataType()))
+ continue;
+
+ List<IRUse*> readUses, writeUses, callUses, subAddrUses, unknownUses;
+
+ for (auto node = addr; node; node = node->parentAddress)
+ {
+ auto addrInst = node->addrInst;
+
+ for (auto use = addrInst->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRDecoration>(use->getUser()))
+ continue;
+ switch (use->getUser()->getOp())
+ {
+ case kIROp_Load:
+ readUses.add(use);
+ break;
+ case kIROp_Store:
+ writeUses.add(use);
+ break;
+ case kIROp_Call:
+ callUses.add(use);
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ if (node == addr)
+ subAddrUses.add(use);
+ break;
+ default:
+ unknownUses.add(use);
+ break;
+ }
+ }
+ }
+
+ if (unknownUses.getCount() != 0)
+ {
+ // Diagnose about unknown use.
+ sink->diagnose(
+ unknownUses.getFirst()->getUser(),
+ Diagnostics::unsupportedUseOfLValueForAutoDiff);
+ hasError = true;
+ continue;
+ }
+
+ if (addr->isConstant)
+ {
+ // Otherwise, the address must be a constant, and we need to create a temp var for
+ // it. The exception is when the variable is a temp var for a call.
+ if (callUses.getCount() == 1 && writeUses.getCount() <= 1 &&
+ readUses.getCount() <= 1)
+ {
+ if (writeUses.getCount() == 0)
+ continue;
+
+ // The uses must be in write->call->read order.
+ auto callUse = callUses.getFirst();
+ auto writeUse = writeUses.getFirst();
+ auto readUse = readUses.getCount() ? readUses.getFirst() : writeUse;
+ if (dom->dominates(writeUse->getUser(), callUse->getUser()) &&
+ dom->dominates(callUse->getUser(), readUse->getUser()))
+ {
+ continue;
+ }
+ }
+
+ // Create a temp var for the address and replace all uses of the address to the temp
+ // var.
+ createTempVarForAddr(addr->addrInst);
+ }
+ else
+ {
+ // This is a dynamic address. We can only allow at most one write access to it.
+ bool hasNonTrivialAccess = false;
+ if (readUses.getCount() + callUses.getCount() != 0 &&
+ writeUses.getCount() + callUses.getCount() > 1)
+ hasNonTrivialAccess = true;
+
+ if (hasNonTrivialAccess)
+ {
+ // Mixed use of a non-constant address is unsupported right now.
+ sink->diagnose(
+ addr->addrInst,
+ Diagnostics::cannotDifferentiateDynamicallyIndexedData,
+ getAddrName(addr->addrInst));
+ }
+ }
+ if (addr->parentAddress && workListSet.Add(addr->parentAddress))
+ workList.add(addr->parentAddress);
+ }
+
+ if (hasError)
+ return SLANG_FAIL;
+
+ // Actually replace addresses with temp vars.
+ for (auto addr : workList)
+ {
+ IRInst* tempVar = nullptr;
+ if (!mapAddrInstToTempVar.TryGetValue(addr->addrInst, tempVar))
+ continue;
+ for (auto use = addr->addrInst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ auto user = use->getUser();
+
+ builder.setInsertBefore(user);
+ switch (user->getOp())
+ {
+ case kIROp_Load:
+ use->set(tempVar);
+ break;
+ case kIROp_Store:
+ use->set(tempVar);
+ updateChildTempVarRecursive(
+ builder, addr, as<IRStore>(user)->getVal());
+ updateParentTempVarRecursive(builder, addr);
+ case kIROp_Call:
+ {
+ use->set(tempVar);
+ builder.setInsertAfter(user);
+ auto newVal = builder.emitLoad(tempVar);
+ updateChildTempVarRecursive(builder, addr, newVal);
+ updateParentTempVarRecursive(builder, addr);
+ }
+ break;
+ default:
+ use->set(tempVar);
+ break;
+ }
+ use = nextUse;
+ }
+ }
+
+ // Assign initial values to tempVar.
+ for (auto tempVar : mapAddrInstToTempVar)
+ {
+ builder.setInsertAfter(tempVar.Value);
+ IRInst* initVal = nullptr;
+ if (tempVar.Key->getOp() == kIROp_Var ||
+ tempVar.Key->getOp() == kIROp_Param && as<IROutType>(tempVar.Key->getFullType()))
+ {
+ initVal = builder.emitDefaultConstruct(
+ cast<IRPtrTypeBase>(tempVar.Key->getFullType())->getValueType());
+ }
+ else
+ {
+ initVal = builder.emitLoad(tempVar.Key);
+ }
+ builder.emitStore(tempVar.Value, initVal);
+ }
+
+ // Store final values to out parameters before exiting function.
+ IRInst* returnInst = nullptr;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ returnInst = inst;
+ break;
+ }
+ }
+ }
+ SLANG_RELEASE_ASSERT(returnInst);
+ builder.setInsertBefore(returnInst);
+ for (auto param : func->getParams())
+ {
+ IRInst* tempVar = nullptr;
+ if (mapAddrInstToTempVar.TryGetValue(param, tempVar))
+ {
+ auto val = builder.emitLoad(tempVar);
+ builder.emitStore(param, val);
+ }
+ }
+ if (hasError)
+ return SLANG_FAIL;
+ return SLANG_OK;
+ }
+};
+
+SlangResult eliminateAddressInsts(
+ SharedIRBuilder* sharedBuilder,
+ DifferentiableTypeConformanceContext& diffContext,
+ IRFunc* func,
+ DiagnosticSink* sink)
+{
+ AddressInstEliminationContext ctx;
+ return ctx.eliminateAddressInstsImpl(sharedBuilder, diffContext, func, sink);
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index ee159b80b..3f3618b44 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1130,6 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_VectorReshape:
case kIROp_IntCast:
case kIROp_FloatCast:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
return transcribeConstruct(builder, origInst);
case kIROp_LookupWitness:
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index d3a6137c1..779a4f1a3 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -5,10 +5,9 @@
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
-
+#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
-
namespace Slang
{
IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType)
@@ -502,6 +501,17 @@ namespace Slang
stripDerivativeDecorations(primalFunc);
eliminateDeadCode(primalOuterParent);
+ // Perform preparation and simplification.
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ if (SLANG_FAILED(eliminateAddressInsts(
+ builder->getSharedBuilder(),
+ differentiableTypeConformanceContext,
+ primalFunc,
+ sink)))
+ return nullptr;
+
+ simplifyFunc(primalFunc);
+
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
autoDiffSharedContext->transcriberSet.forwardTranscriber);
@@ -567,7 +577,9 @@ namespace Slang
}
auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc);
-
+ if (!fwdDiffFunc)
+ return;
+
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index ae1a5dd70..e799456bb 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1061,6 +1061,10 @@ struct DiffTransposePass
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
+ case kIROp_MakeStruct:
+ return transposeMakeStruct(builder, fwdInst, revValue);
+ case kIROp_MakeArray:
+ return transposeMakeArray(builder, fwdInst, revValue);
case kIROp_Specialize:
case kIROp_unconditionalBranch:
@@ -1218,6 +1222,55 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+ TranspositionResult transposeMakeStruct(IRBuilder* builder, IRInst* fwdMakeStruct, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto structType = cast<IRStructType>(fwdMakeStruct->getFullType());
+ UInt ii = 0;
+ for (auto field : structType->getFields())
+ {
+ auto gradAtField = builder->emitFieldExtract(
+ field->getFieldType(),
+ revValue,
+ field->getKey());
+ SLANG_RELEASE_ASSERT(ii < fwdMakeStruct->getOperandCount());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeStruct->getOperand(ii),
+ gradAtField,
+ fwdMakeStruct));
+ ii++;
+ }
+
+ // (A = MakeStruct(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeArray(IRBuilder* builder, IRInst* fwdMakeArray, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto arrayType = cast<IRArrayType>(fwdMakeArray->getFullType());
+ auto arraySize = cast<IRIntLit>(arrayType->getElementCount());
+
+ for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++)
+ {
+ auto gradAtField = builder->emitElementExtract(
+ arrayType->getElementType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+ SLANG_RELEASE_ASSERT(ii < fwdMakeArray->getOperandCount());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeArray->getOperand(ii),
+ gradAtField,
+ fwdMakeArray));
+ ii++;
+ }
+
+ // (A = MakeArray(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
+ return TranspositionResult(gradients);
+ }
+
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 363006f58..4d33d3743 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1,7 +1,10 @@
#include "slang-ir-autodiff.h"
+#include "slang-ir-address-analysis.h"
#include "slang-ir-autodiff-rev.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-pairs.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-ssa-simplification.h"
#include "slang-ir-validate.h"
namespace Slang
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 7479e4eee..cb767c20a 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -260,4 +260,11 @@ bool finalizeAutoDiffPass(IRModule* module);
void stripDerivativeDecorations(IRInst* inst);
bool isBackwardDifferentiableFunc(IRInst* func);
+
+SlangResult eliminateAddressInsts(
+ SharedIRBuilder* sharedBuilder,
+ DifferentiableTypeConformanceContext& diffContext,
+ IRFunc* func,
+ DiagnosticSink* sink);
+
};
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index cb7290036..67b7e92b0 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -5,12 +5,45 @@
namespace Slang
{
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+{
+ HashSet<IRInst*> processedSet;
+ while (auto ptrType = as<IRPtrTypeBase>(typeInst))
+ {
+ typeInst = ptrType->getValueType();
+ if (!processedSet.Add(typeInst))
+ return false;
+ }
+ if (!typeInst)
+ return false;
+ switch (typeInst->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_DifferentialPairType:
+ return true;
+ default:
+ break;
+ }
+ if (context.lookUpConformanceForType(typeInst))
+ return true;
+ // Look for equivalent types.
+ for (auto type : context.differentiableWitnessDictionary)
+ {
+ if (isTypeEqual(type.Key, (IRType*)typeInst))
+ {
+ context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
+ return true;
+ }
+ }
+ return false;
+}
struct CheckDifferentiabilityPassContext : public InstPassBase
{
public:
DiagnosticSink* sink;
AutoDiffSharedContext sharedContext;
+ SharedIRBuilder* sharedBuilder;
enum DifferentiableLevel
{
@@ -18,8 +51,8 @@ public:
};
Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
- CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
- : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
+ CheckDifferentiabilityPassContext(SharedIRBuilder* inSharedBuilder, IRModule* inModule, DiagnosticSink* inSink)
+ : InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst())
{}
IRInst* getSpecializedVal(IRInst* inst)
@@ -161,39 +194,6 @@ public:
return false;
}
- bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
- {
- HashSet<IRInst*> processedSet;
- while (auto ptrType = as<IRPtrTypeBase>(typeInst))
- {
- typeInst = ptrType->getValueType();
- if (!processedSet.Add(typeInst))
- return false;
- }
- if (!typeInst)
- return false;
- switch (typeInst->getOp())
- {
- case kIROp_FloatType:
- case kIROp_DifferentialPairType:
- return true;
- default:
- break;
- }
- if (context.lookUpConformanceForType(typeInst))
- return true;
- // Look for equivalent types.
- for (auto type : context.differentiableWitnessDictionary)
- {
- if (isTypeEqual(type.Key, (IRType*)typeInst))
- {
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
- return true;
- }
- }
- return false;
- }
-
int getParamIndexInBlock(IRParam* paramInst)
{
auto block = as<IRBlock>(paramInst->getParent());
@@ -228,6 +228,14 @@ public:
DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
diffTypeContext.setFunc(funcInst);
+ if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ {
+ if (auto func = as<IRFunc>(funcInst))
+ {
+ if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink)))
+ return;
+ }
+ }
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
@@ -468,9 +476,9 @@ public:
}
};
-void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink)
+void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink)
{
- CheckDifferentiabilityPassContext context(module, sink);
+ CheckDifferentiabilityPassContext context(sharedBuilder, module, sink);
context.processModule();
}
diff --git a/source/slang/slang-ir-check-differentiability.h b/source/slang/slang-ir-check-differentiability.h
index 735a918c9..16ae16b6f 100644
--- a/source/slang/slang-ir-check-differentiability.h
+++ b/source/slang/slang-ir-check-differentiability.h
@@ -7,8 +7,9 @@ namespace Slang
{
struct IRModule;
class DiagnosticSink;
+struct SharedIRBuilder;
// Check all auto diff usages are valid.
-void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink);
+void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink);
} // namespace Slang
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp
index 72b156228..1ffa7ba5d 100644
--- a/source/slang/slang-ir-dominators.cpp
+++ b/source/slang/slang-ir-dominators.cpp
@@ -86,6 +86,31 @@ bool IRDominatorTree::dominates(IRBlock* dominator, IRBlock* dominated)
return properlyDominates(dominator, dominated);
}
+bool IRDominatorTree::dominates(IRInst* dominator, IRInst* dominated)
+{
+ auto dominatorBlock = as<IRBlock>(dominator);
+ if (!dominatorBlock)
+ dominatorBlock = as<IRBlock>(dominator->getParent());
+
+ auto dominatedBlock = as<IRBlock>(dominated);
+ if (!dominatedBlock)
+ dominatedBlock = as<IRBlock>(dominated->getParent());
+
+ if (dominatorBlock == dominatedBlock)
+ {
+ for (auto inst = dominator; inst; inst = inst->getNextInst())
+ {
+ if (inst == dominated)
+ return true;
+ }
+ return false;
+ }
+ else
+ {
+ return dominates(dominatorBlock, dominatedBlock);
+ }
+}
+
IRBlock* IRDominatorTree::getImmediateDominator(IRBlock* block)
{
// An unreachable block has no immediate dominator.
diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h
index be01830b0..1fb12c89e 100644
--- a/source/slang/slang-ir-dominators.h
+++ b/source/slang/slang-ir-dominators.h
@@ -7,6 +7,7 @@ namespace Slang
{
struct IRBlock;
struct IRGlobalValueWithCode;
+ struct IRInst;
/// The computed dominator tree for an IR control flow graph.
struct IRDominatorTree : public RefObject
@@ -22,6 +23,8 @@ namespace Slang
///
bool dominates(IRBlock* dominator, IRBlock* dominated);
+ bool dominates(IRInst* dominator, IRInst* dominated);
+
/// Does the first block properly dominate the second?
///
/// Block A properly dominates block B iff A dominates B
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 68afbbb95..134a45bf5 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -305,7 +305,6 @@ INST(GetOptionalValue, getOptionalValue, 1, 0)
INST(OptionalHasValue, optionalHasValue, 1, 0)
INST(MakeOptionalValue, makeOptionalValue, 1, 0)
INST(MakeOptionalNone, makeOptionalNone, 1, 0)
-INST(DifferentialBottomValue, differentialBottomVal, 0, 0)
INST(Call, call, 1, 0)
INST(RTTIObject, rtti_object, 0, 0)
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
new file mode 100644
index 000000000..a57bfce3e
--- /dev/null
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -0,0 +1,125 @@
+#include "slang-ir-redundancy-removal.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+struct RedundancyRemovalContext
+{
+ RefPtr<IRDominatorTree> dom;
+ bool removeRedundancyInBlock(DeduplicateContext& deduplicateContext, IRBlock* block)
+ {
+ bool result = false;
+ for (auto instP : block->getChildren())
+ {
+ auto resultInst = deduplicateContext.deduplicate(instP, [&](IRInst* inst)
+ {
+ auto parentBlock = as<IRBlock>(inst->getParent());
+ if (!parentBlock)
+ return false;
+ if (dom->isUnreachable(parentBlock))
+ return false;
+
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Module:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ case kIROp_GetElement:
+ case kIROp_GetElementPtr:
+ case kIROp_LookupWitness:
+ case kIROp_Specialize:
+ case kIROp_OptionalHasValue:
+ case kIROp_GetOptionalValue:
+ case kIROp_MakeOptionalValue:
+ case kIROp_MakeTuple:
+ case kIROp_GetTupleElement:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_swizzle:
+ case kIROp_MatrixReshape:
+ case kIROp_MakeString:
+ case kIROp_MakeResultError:
+ case kIROp_MakeResultValue:
+ case kIROp_GetResultError:
+ case kIROp_GetResultValue:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastIntToPtr:
+ case kIROp_CastPtrToBool:
+ case kIROp_CastPtrToInt:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitCast:
+ case kIROp_Reinterpret:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ case kIROp_Neq:
+ case kIROp_Eql:
+ return true;
+ default:
+ return false;
+ }
+ });
+ if (resultInst != instP)
+ result = true;
+ }
+ for (auto child : dom->getImmediatelyDominatedBlocks(block))
+ {
+ DeduplicateContext subContext;
+ subContext.deduplicateMap = deduplicateContext.deduplicateMap;
+ result |= removeRedundancyInBlock(subContext, child);
+ }
+ return result;
+ }
+};
+
+bool removeRedundancy(IRModule* module)
+{
+ bool changed = false;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ removeRedundancyInFunc(genericInst);
+ inst = findGenericReturnVal(genericInst);
+ }
+ if (auto func = as<IRFunc>(inst))
+ {
+ changed |= removeRedundancyInFunc(func);
+ }
+ }
+ return changed;
+}
+
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
+{
+ auto root = func->getFirstBlock();
+ if (!root)
+ return false;
+
+ RedundancyRemovalContext context;
+ context.dom = computeDominatorTree(func);
+ DeduplicateContext deduplicateCtx;
+ return context.removeRedundancyInBlock(deduplicateCtx, root);
+}
+
+}
diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h
new file mode 100644
index 000000000..26b265e77
--- /dev/null
+++ b/source/slang/slang-ir-redundancy-removal.h
@@ -0,0 +1,11 @@
+// slang-ir-redundancy-removal.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRGlobalValueWithCode;
+
+ bool removeRedundancy(IRModule* module);
+ bool removeRedundancyInFunc(IRGlobalValueWithCode* func);
+}
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 1e247d1d9..54a1f7e08 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -196,6 +196,10 @@ bool simplifyCFG(IRModule* module)
bool changed = false;
for (auto inst : module->getGlobalInsts())
{
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ inst = findGenericReturnVal(genericInst);
+ }
if (auto func = as<IRFunc>(inst))
{
changed |= processFunc(func);
diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp
index f76e35040..30e933133 100644
--- a/source/slang/slang-ir-single-return.cpp
+++ b/source/slang/slang-ir-single-return.cpp
@@ -91,4 +91,20 @@ void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* fu
context.processFunc(func);
}
+bool isSingleReturnFunc(IRGlobalValueWithCode* func)
+{
+ int returnCount = 0;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ returnCount++;
+ }
+ }
+ }
+ return returnCount <= 1;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h
index 2ddfa280b..bb186634d 100644
--- a/source/slang/slang-ir-single-return.h
+++ b/source/slang/slang-ir-single-return.h
@@ -9,4 +9,5 @@ namespace Slang
// Convert the CFG of `func` to have only a single `return` at the end.
void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func);
+ bool isSingleReturnFunc(IRGlobalValueWithCode* func);
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index fd5f41f49..f06fafcb3 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -9,6 +9,7 @@
#include "slang-ir-hoist-constants.h"
#include "slang-ir-deduplicate-generic-children.h"
#include "slang-ir-remove-unused-generic-param.h"
+#include "slang-ir-redundancy-removal.h"
namespace Slang
{
@@ -26,6 +27,7 @@ namespace Slang
changed |= deduplicateGenericChildren(module);
changed |= applySparseConditionalConstantPropagation(module);
changed |= peepholeOptimize(module);
+ changed |= removeRedundancy(module);
changed |= simplifyCFG(module);
// Note: we disregard the `changed` state from dead code elimination pass since
@@ -49,6 +51,7 @@ namespace Slang
changed = false;
changed |= applySparseConditionalConstantPropagation(func);
changed |= peepholeOptimize(func);
+ changed |= removeRedundancyInFunc(func);
changed |= simplifyCFG(func);
// Note: we disregard the `changed` state from dead code elimination pass since
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 881f041c0..319a23989 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -219,12 +219,20 @@ void moveInstChildren(IRInst* dest, IRInst* src)
}
}
+String dumpIRToString(IRInst* root)
+{
+ StringBuilder sb;
+ StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
+ dumpIR(root, IRDumpOptions(), nullptr, &writer);
+ return sb.ToString();
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
IRGeneric* srcGeneric;
IRGeneric* dstGeneric;
- Dictionary<IRInstKey, IRInst*> deduplicateMap;
+ DeduplicateContext deduplicateContext;
void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore)
{
@@ -251,42 +259,34 @@ struct GenericChildrenMigrationContextImpl
inst = inst->getNextInst())
{
IRInstKey key = { inst };
- deduplicateMap.AddIfNotExists(key, inst);
+ deduplicateContext.deduplicateMap.AddIfNotExists(key, inst);
}
}
}
IRInst* deduplicate(IRInst* value)
{
- if (!value) return nullptr;
- if (value->getParent() != dstGeneric->getFirstBlock())
- return value;
- switch (value->getOp())
- {
- case kIROp_Param:
- case kIROp_StructType:
- case kIROp_StructKey:
- case kIROp_InterfaceType:
- case kIROp_ClassType:
- case kIROp_Func:
- case kIROp_Generic:
- return value;
- default:
- break;
- }
- if (as<IRConstant>(value))
- return value;
-
- for (UInt i = 0; i < value->getOperandCount(); i++)
- {
- value->setOperand(i, deduplicate(value->getOperand(i)));
- }
- value->setFullType((IRType*)deduplicate(value->getFullType()));
- IRInstKey key = { value };
- if (auto newValue = deduplicateMap.TryGetValue(key))
- return *newValue;
- deduplicateMap[key] = value;
- return value;
+ return deduplicateContext.deduplicate(value, [this](IRInst* inst)
+ {
+ if (inst->getParent() != dstGeneric->getFirstBlock())
+ return false;
+ switch (inst->getOp())
+ {
+ case kIROp_Param:
+ case kIROp_StructType:
+ case kIROp_StructKey:
+ case kIROp_InterfaceType:
+ case kIROp_ClassType:
+ case kIROp_Func:
+ case kIROp_Generic:
+ return false;
+ default:
+ break;
+ }
+ if (as<IRConstant>(inst))
+ return false;
+ return true;
+ });
}
IRInst* cloneInst(IRBuilder* builder, IRInst* src)
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 92446138f..a250fc6a6 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -5,6 +5,7 @@
// This file contains utility functions for operating with Slang IR.
//
#include "slang-ir.h"
+#include "slang-ir-insts.h"
namespace Slang
{
@@ -32,6 +33,32 @@ public:
IRInst* cloneInst(IRBuilder* builder, IRInst* src);
};
+
+struct DeduplicateContext
+{
+ Dictionary<IRInstKey, IRInst*> deduplicateMap;
+
+ template<typename TFunc>
+ IRInst* deduplicate(IRInst* value, const TFunc& shouldDeduplicate)
+ {
+ if (!value) return nullptr;
+ if (!shouldDeduplicate(value))
+ return value;
+ IRInstKey key = { value };
+ if (auto newValue = deduplicateMap.TryGetValue(key))
+ return *newValue;
+ for (UInt i = 0; i < value->getOperandCount(); i++)
+ {
+ value->setOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate));
+ }
+ value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate));
+ if (auto newValue = deduplicateMap.TryGetValue(key))
+ return *newValue;
+ deduplicateMap[key] = value;
+ return value;
+ }
+};
+
bool isPtrToClassType(IRInst* type);
bool isPtrToArrayType(IRInst* type);
@@ -126,6 +153,8 @@ inline IRInst* unwrapAttributedType(IRInst* type)
return type;
}
+String dumpIRToString(IRInst* root);
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e400d0a17..b79221900 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1982,7 +1982,6 @@ namespace Slang
return getStringSlice() == rhs->getStringSlice();
}
case kIROp_VoidLit:
- case kIROp_DifferentialBottomValue:
{
return true;
}
@@ -2025,7 +2024,6 @@ namespace Slang
return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength()));
}
case kIROp_VoidLit:
- case kIROp_DifferentialBottomValue:
{
return code;
}
@@ -2110,14 +2108,6 @@ namespace Slang
irValue->value.ptrVal = keyInst.value.ptrVal;
break;
}
- case kIROp_DifferentialBottomValue:
- {
- const size_t instSize = prefixSize + sizeof(void*);
- irValue = static_cast<IRConstant*>(
- _createInst(instSize, keyInst.getFullType(), keyInst.getOp()));
- irValue->value.ptrVal = nullptr;
- break;
- }
case kIROp_StringLit:
{
const UnownedStringSlice slice = keyInst.getStringSlice();
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ec51c7bfa..d0527eef8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9157,7 +9157,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
checkForMissingReturns(module, compileRequest->getSink());
// Check for invalid differentiable function body.
- checkAutoDiffUsages(module, compileRequest->getSink());
+ checkAutoDiffUsages(sharedBuilder, module, compileRequest->getSink());
// The "mandatory" optimization passes may make use of the
// `IRHighLevelDeclDecoration` type to relate IR instructions
diff --git a/tests/autodiff/reverse-struct-multi-write.slang b/tests/autodiff/reverse-struct-multi-write.slang
new file mode 100644
index 000000000..dd12c7d3d
--- /dev/null
+++ b/tests/autodiff/reverse-struct-multi-write.slang
@@ -0,0 +1,48 @@
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A : IDifferentiable
+{
+ float x;
+ float y;
+};
+
+[BackwardDifferentiable]
+A f(A a)
+{
+ // Read/writes to local struct variables won't be SSA'd out by default.
+ // The backward diff preparation pass will kick in to create temp vars for them.
+ A aout;
+ aout.y = 2 * a.x;
+ aout.y = aout.y + 2 * a.x;
+ aout.x = aout.y + 5 * a.x;
+
+ // The result should be equivalent to:
+ /*
+ A aout;
+ var tmp = 2 * a.x;
+ tmp = tmp + 2 * a.x;
+ aout.y = tmp;
+ aout.x = tmp + 5 * a.x;
+ */
+ return aout;
+
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A a = {1.0, 2.0};
+
+ var dpa = diffPair(a);
+
+ A.Differential dout = {1.0, 1.0};
+
+ __bwd_diff(f)(dpa, dout);
+ outputBuffer[0] = dpa.d.x; // Expect: 13
+ outputBuffer[1] = dpa.d.y; // Expect: 0
+}
diff --git a/tests/autodiff/reverse-struct-multi-write.slang.expected.txt b/tests/autodiff/reverse-struct-multi-write.slang.expected.txt
new file mode 100644
index 000000000..403f2ffd4
--- /dev/null
+++ b/tests/autodiff/reverse-struct-multi-write.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+13.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/compute/half-texture.slang.glsl b/tests/compute/half-texture.slang.glsl
index 88f585378..0eccccaaf 100644
--- a/tests/compute/half-texture.slang.glsl
+++ b/tests/compute/half-texture.slang.glsl
@@ -21,20 +21,23 @@ layout(std430, binding = 0) buffer _S1 {
int _data[];
} outputBuffer_0;
-layout(local_size_x = 4, local_size_y = 4, local_size_z = 1) in;void main()
+layout(local_size_x = 4, local_size_y = 4, local_size_z = 1) in;
+void main()
{
ivec2 pos_0 = ivec2(gl_GlobalInvocationID.xy);
const float _S2 = 1.00000000000000000000 / 3.00000000000000000000;
- ivec2 pos2_0 = ivec2(3 - pos_0.y, 3 - pos_0.x);
+ int _S3 = pos_0.y;
+ int _S4 = pos_0.x;
+ ivec2 pos2_0 = ivec2(3 - _S3, 3 - _S4);
float16_t h_0 = (float16_t(imageLoad((halfTexture_0), ivec2((uvec2(pos2_0)))).x));
f16vec2 h2_0 = (f16vec2(imageLoad((halfTexture2_0), ivec2((uvec2(pos2_0)))).xy));
f16vec4 h4_0 = (f16vec4(imageLoad((halfTexture4_0), ivec2((uvec2(pos2_0))))));
- imageStore((halfTexture_0), ivec2((uvec2(pos_0))), f16vec4(h2_0.x + h2_0.y, float16_t(0), float16_t(0), float16_t(0)));
- imageStore((halfTexture2_0), ivec2((uvec2(pos_0))), f16vec4(h4_0.xy, float16_t(0), float16_t(0)));
- imageStore((halfTexture4_0), ivec2((uvec2(pos_0))), f16vec4(h2_0, h_0, h_0));
+ imageStore((halfTexture_0), ivec2((uvec2(pos_0))), f16vec4(h2_0.x + h2_0.y, float16_t(0), float16_t(0), float16_t(0)));
+ imageStore((halfTexture2_0), ivec2((uvec2(pos_0))), f16vec4(h4_0.xy, float16_t(0), float16_t(0)));
+ imageStore((halfTexture4_0), ivec2((uvec2(pos_0))), f16vec4(h2_0, h_0, h_0));
- int index_0 = pos_0.x + pos_0.y * 4;
+ int index_0 = _S4 + _S3 * 4;
((outputBuffer_0)._data[(uint(index_0))]) = index_0;
return;
diff --git a/tests/compute/half-texture.slang.hlsl b/tests/compute/half-texture.slang.hlsl
index c606703a4..2d04ee17f 100644
--- a/tests/compute/half-texture.slang.hlsl
+++ b/tests/compute/half-texture.slang.hlsl
@@ -8,19 +8,21 @@ RWStructuredBuffer<int > outputBuffer_0 : register(u0);
[shader("compute")][numthreads(4, 4, 1)]
void computeMain(uint3 dispatchThreadID_0 : SV_DISPATCHTHREADID)
{
- int2 pos_0 = (int2) dispatchThreadID_0.xy;
+ int2 pos_0 = int2(dispatchThreadID_0.xy);
float _S1 = 1.00000000000000000000 / 3.00000000000000000000;
- int2 pos2_0 = int2(int(3) - pos_0.y, int(3) - pos_0.x);
+ int _S2 = pos_0.y;
+ int _S3 = pos_0.x;
+ int2 pos2_0 = int2(int(3) - _S2, int(3) - _S3);
- half h_0 = halfTexture_0[(uint2) pos2_0];
- vector<half,2> h2_0 = halfTexture2_0[(uint2) pos2_0];
- vector<half,4> h4_0 = halfTexture4_0[(uint2) pos2_0];
+ half h_0 = halfTexture_0[uint2(pos2_0)];
+ vector<half, 2> h2_0 = halfTexture2_0[uint2(pos2_0)];
+ vector<half, 4> h4_0 = halfTexture4_0[uint2(pos2_0)];
- halfTexture_0[(uint2) pos_0] = h2_0.x + h2_0.y;
- halfTexture2_0[(uint2) pos_0] = h4_0.xy;
- halfTexture4_0[(uint2) pos_0] = vector<half,4>(h2_0, h_0, h_0);
+ halfTexture_0[uint2(pos_0)] = h2_0.x + h2_0.y;
+ halfTexture2_0[uint2(pos_0)] = h4_0.xy;
+ halfTexture4_0[uint2(pos_0)] = vector<half, 4>(h2_0, h_0, h_0);
- int index_0 = pos_0.x + pos_0.y * int(4);
- outputBuffer_0[(uint) index_0] = index_0;
+ int index_0 = _S3 + _S2 * int(4);
+ outputBuffer_0[uint(index_0)] = index_0;
return;
}
diff --git a/tests/cross-compile/precise-keyword.slang.glsl b/tests/cross-compile/precise-keyword.slang.glsl
index 17fed739e..027a8eb3b 100644
--- a/tests/cross-compile/precise-keyword.slang.glsl
+++ b/tests/cross-compile/precise-keyword.slang.glsl
@@ -11,15 +11,18 @@ in vec2 _S2;
void main()
{
+ float _S3 = _S2.x;
+
precise float z_0;
- if(_S2.x > float(0))
+ if(_S3 > 0.00000000000000000000)
{
- z_0 = _S2.x * _S2.y + _S2.x;
+ z_0 = _S3 * _S2.y + _S3;
}
else
{
- z_0 = _S2.y * _S2.x + _S2.y;
+ float _S4 = _S2.y;
+ z_0 = _S4 * _S3 + _S4;
}
_S1 = vec4(z_0);
return;
diff --git a/tests/cross-compile/precise-keyword.slang.hlsl b/tests/cross-compile/precise-keyword.slang.hlsl
index 54017868b..7a07fdc5e 100644
--- a/tests/cross-compile/precise-keyword.slang.hlsl
+++ b/tests/cross-compile/precise-keyword.slang.hlsl
@@ -3,15 +3,17 @@
float4 main(float2 v_0 : V) : SV_TARGET
{
+ float _S1 = v_0.x;
precise float z_0;
- if(v_0.x > (float) 0)
+ if (_S1 > 0.00000000000000000000)
{
- z_0 = v_0.x * v_0.y + v_0.x;
+ z_0 = _S1 * v_0.y + _S1;
}
else
{
- z_0 = v_0.y * v_0.x + v_0.y;
+ float _S2 = v_0.y;
+ z_0 = _S2 * _S1 + _S2;
}
return (float4) z_0;
diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected
index ac1894f95..26a537330 100644
--- a/tests/experimental/liveness/liveness-6.slang.expected
+++ b/tests/experimental/liveness/liveness-6.slang.expected
@@ -60,15 +60,16 @@ int calcThing_0(int offset_0)
i_0 = i_0 + 1;
}
livenessEnd_0(i_0, 0);
- int _S3 = another_0[k_0 & 1];
- int _S4 = total_0;
+ int _S3 = k_0 & 1;
+ int _S4 = another_0[_S3];
+ int _S5 = total_0;
livenessEnd_0(total_0, 0);
- int total_1 = _S4 + _S3;
- int _S5 = arr_0[k_0 & 1];
+ int total_1 = _S5 + _S4;
+ int _S6 = arr_0[_S3];
livenessEnd_1(arr_0, 0);
- int total_2 = total_1 + _S5;
- int _S6 = (k_0 + 7) % 5;
- if(_S6 == 4)
+ int total_2 = total_1 + _S6;
+ int _S7 = (k_0 + 7) % 5;
+ if(_S7 == 4)
{
livenessEnd_0(k_0, 0);
livenessEnd_1(another_0, 0);
@@ -83,32 +84,32 @@ int calcThing_0(int offset_0)
int total_3;
if(total_0 > 4)
{
- int _S7 = total_0;
+ int _S8 = total_0;
livenessEnd_0(total_0, 0);
- int _S8 = - _S7;
+ int _S9 = - _S8;
livenessStart_1(total_3, 0);
- total_3 = _S8;
+ total_3 = _S9;
}
else
{
- int _S9 = total_0;
+ int _S10 = total_0;
livenessEnd_0(total_0, 0);
livenessStart_1(total_3, 0);
- total_3 = _S9;
+ total_3 = _S10;
}
return total_3;
}
-layout(std430, binding = 0) buffer _S10 {
+layout(std430, binding = 0) buffer _S11 {
int _data[];
} outputBuffer_0;
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
void main()
{
int index_0 = int(gl_GlobalInvocationID.x);
- uint _S11 = uint(index_0);
- int _S12 = calcThing_0(index_0);
- ((outputBuffer_0)._data[(_S11)]) = _S12;
+ uint _S12 = uint(index_0);
+ int _S13 = calcThing_0(index_0);
+ ((outputBuffer_0)._data[(_S12)]) = _S13;
return;
}
diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected
index 8fc391feb..15221b921 100644
--- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected
+++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected
@@ -57,15 +57,16 @@ uint calcValue_0(hitObjectNV hit_0)
uint hitKind_0 = (hitObjectGetHitKindNV((hit_0)));
uint r_1 = 0U + hitKind_0 + instanceIndex_0 + instanceID_0 + geometryIndex_0 + primitiveIndex_0;
RayDesc_0 ray_1 = HitObject_GetRayDesc_0(hit_0);
- uint r_2 = r_1 + uint(ray_1.TMin_0 > 0.00000000000000000000) + uint(ray_1.TMax_0 < ray_1.TMin_0);
+ float _S6 = ray_1.TMin_0;
+ uint r_2 = r_1 + uint(_S6 > 0.00000000000000000000) + uint(ray_1.TMax_0 < _S6);
SomeValues_0 objSomeValues_0 = HitObject_GetAttributes_0(hit_0);
r_0 = r_2 + uint(objSomeValues_0.a_0);
}
else
{
- bool _S6 = (hitObjectIsMissNV((hit_0)));
+ bool _S7 = (hitObjectIsMissNV((hit_0)));
uint r_3;
- if(_S6)
+ if(_S7)
{
r_3 = 1U;
}
@@ -78,29 +79,29 @@ uint calcValue_0(hitObjectNV hit_0)
return r_0;
}
-layout(std430, binding = 1) buffer _S7 {
+layout(std430, binding = 1) buffer _S8 {
uint _data[];
} outputBuffer_0;
void main()
{
- uvec3 _S8 = ((gl_LaunchIDEXT));
- ivec2 launchID_0 = ivec2(_S8.xy);
- uvec3 _S9 = ((gl_LaunchSizeEXT));
+ uvec3 _S9 = ((gl_LaunchIDEXT));
+ ivec2 launchID_0 = ivec2(_S9.xy);
+ uvec3 _S10 = ((gl_LaunchSizeEXT));
int idx_0 = launchID_0.x;
RayDesc_0 ray_2;
ray_2.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000);
ray_2.TMin_0 = 0.00999999977648258209;
ray_2.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000);
ray_2.TMax_0 = 10000.00000000000000000000;
- RayDesc_0 _S10 = ray_2;
+ RayDesc_0 _S11 = ray_2;
hitObjectNV hitObj_0;
- hitObjectRecordHitWithIndexNV(hitObj_0, scene_0, int(uint(idx_0)), int(uint(idx_0 * 2)), int(uint(idx_0 * 3)), 0U, 0U, _S10.Origin_0, _S10.TMin_0, _S10.Direction_0, _S10.TMax_0, (0));
+ hitObjectRecordHitWithIndexNV(hitObj_0, scene_0, int(uint(idx_0)), int(uint(idx_0 * 2)), int(uint(idx_0 * 3)), 0U, 0U, _S11.Origin_0, _S11.TMin_0, _S11.Direction_0, _S11.TMax_0, (0));
uint r_4 = calcValue_0(hitObj_0);
- RayDesc_0 _S11 = ray_2;
+ RayDesc_0 _S12 = ray_2;
hitObjectNV hitObj_1;
- hitObjectRecordHitNV(hitObj_1, scene_0, int(uint(idx_0)), int(uint(idx_0 * 3)), int(uint(idx_0 * 2)), 0U, 0U, 4U, _S11.Origin_0, _S11.TMin_0, _S11.Direction_0, _S11.TMax_0, (0));
- uint _S12 = calcValue_0(hitObj_1);
- uint r_5 = r_4 + _S12;
+ hitObjectRecordHitNV(hitObj_1, scene_0, int(uint(idx_0)), int(uint(idx_0 * 3)), int(uint(idx_0 * 2)), 0U, 0U, 4U, _S12.Origin_0, _S12.TMin_0, _S12.Direction_0, _S12.TMax_0, (0));
+ uint _S13 = calcValue_0(hitObj_1);
+ uint r_5 = r_4 + _S13;
((outputBuffer_0)._data[(uint(idx_0))]) = r_5;
return;
}
diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected
index 90223115b..f250c1c92 100644
--- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected
+++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected
@@ -79,36 +79,37 @@ void main()
ivec2 launchID_0 = ivec2(_S3.xy);
uvec3 _S4 = ((gl_LaunchSizeEXT));
int idx_0 = launchID_0.x;
- SomeValues_0 someValues_0 = { idx_0, float(idx_0) * 2.00000000000000000000 };
+ float _S5 = float(idx_0);
+ SomeValues_0 someValues_0 = { idx_0, _S5 * 2.00000000000000000000 };
RayDesc_0 ray_0;
- ray_0.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000);
+ ray_0.Origin_0 = vec3(_S5, 0.00000000000000000000, 0.00000000000000000000);
ray_0.TMin_0 = 0.00999999977648258209;
ray_0.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000);
ray_0.TMax_0 = 10000.00000000000000000000;
- RayDesc_0 _S5 = ray_0;
+ RayDesc_0 _S6 = ray_0;
p_0 = someValues_0;
hitObjectNV hitObj_0;
- hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S5.Origin_0, _S5.TMin_0, _S5.Direction_0, _S5.TMax_0, (0));
+ hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S6.Origin_0, _S6.TMin_0, _S6.Direction_0, _S6.TMax_0, (0));
uint r_1 = calcValue_0(hitObj_0);
reorderThreadNV(hitObj_0);
SomeValues_0 otherValues_0;
- SomeValues_0 _S6 = { idx_0 * -1, float(idx_0) * 4.00000000000000000000 };
- otherValues_0 = _S6;
+ SomeValues_0 _S7 = { idx_0 * -1, _S5 * 4.00000000000000000000 };
+ otherValues_0 = _S7;
HitObject_Invoke_0(scene_0, hitObj_0, otherValues_0);
- uint _S7 = calcValue_0(hitObj_0);
- uint r_2 = r_1 + _S7;
+ uint _S8 = calcValue_0(hitObj_0);
+ uint r_2 = r_1 + _S8;
reorderThreadNV(hitObj_0, uint(idx_0 & 3), 2U);
- SomeValues_0 _S8 = { idx_0 * -2, float(idx_0) * 8.00000000000000000000 };
- otherValues_0 = _S8;
+ SomeValues_0 _S9 = { idx_0 * -2, _S5 * 8.00000000000000000000 };
+ otherValues_0 = _S9;
HitObject_Invoke_0(scene_0, hitObj_0, otherValues_0);
- uint _S9 = calcValue_0(hitObj_0);
- uint r_3 = r_2 + _S9;
+ uint _S10 = calcValue_0(hitObj_0);
+ uint r_3 = r_2 + _S10;
reorderThreadNV(uint(idx_0 & 1), 1U);
- SomeValues_0 _S10 = { idx_0 * -4, float(idx_0) * 16.00000000000000000000 };
- otherValues_0 = _S10;
+ SomeValues_0 _S11 = { idx_0 * -4, _S5 * 16.00000000000000000000 };
+ otherValues_0 = _S11;
HitObject_Invoke_0(scene_0, hitObj_0, otherValues_0);
- uint _S11 = calcValue_0(hitObj_0);
- uint r_4 = r_3 + _S11;
+ uint _S12 = calcValue_0(hitObj_0);
+ uint r_4 = r_3 + _S12;
((outputBuffer_0)._data[(uint(idx_0))]) = r_4;
return;
}
diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected
index a86dc6aa7..f6f6f132d 100644
--- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected
+++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected
@@ -70,19 +70,20 @@ void main()
int idx_0 = launchID_0.x;
int _S5 = idx_0 / 4;
float currentTime_0 = float(_S5);
- SomeValues_0 someValues_0 = { idx_0, float(idx_0) * 2.00000000000000000000 };
+ float _S6 = float(idx_0);
+ SomeValues_0 someValues_0 = { idx_0, _S6 * 2.00000000000000000000 };
RayDesc_0 ray_0;
- ray_0.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000);
+ ray_0.Origin_0 = vec3(_S6, 0.00000000000000000000, 0.00000000000000000000);
ray_0.TMin_0 = 0.00999999977648258209;
ray_0.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000);
ray_0.TMax_0 = 10000.00000000000000000000;
- RayDesc_0 _S6 = ray_0;
+ RayDesc_0 _S7 = ray_0;
p_0 = someValues_0;
hitObjectNV hitObj_0;
- hitObjectTraceRayMotionNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S6.Origin_0, _S6.TMin_0, _S6.Direction_0, _S6.TMax_0, currentTime_0, (0));
- uint _S7 = uint(idx_0);
- uint _S8 = calcValue_0(hitObj_0);
- ((outputBuffer_0)._data[(_S7)]) = _S8;
+ hitObjectTraceRayMotionNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S7.Origin_0, _S7.TMin_0, _S7.Direction_0, _S7.TMax_0, currentTime_0, (0));
+ uint _S8 = uint(idx_0);
+ uint _S9 = calcValue_0(hitObj_0);
+ ((outputBuffer_0)._data[(_S8)]) = _S9;
return;
}
diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected
index 38ddbf233..16099b5e2 100644
--- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected
+++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected
@@ -67,19 +67,20 @@ void main()
ivec2 launchID_0 = ivec2(_S3.xy);
uvec3 _S4 = ((gl_LaunchSizeEXT));
int idx_0 = launchID_0.x;
- SomeValues_0 someValues_0 = { idx_0, float(idx_0) * 2.00000000000000000000 };
+ float _S5 = float(idx_0);
+ SomeValues_0 someValues_0 = { idx_0, _S5 * 2.00000000000000000000 };
RayDesc_0 ray_0;
- ray_0.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000);
+ ray_0.Origin_0 = vec3(_S5, 0.00000000000000000000, 0.00000000000000000000);
ray_0.TMin_0 = 0.00999999977648258209;
ray_0.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000);
ray_0.TMax_0 = 10000.00000000000000000000;
- RayDesc_0 _S5 = ray_0;
+ RayDesc_0 _S6 = ray_0;
p_0 = someValues_0;
hitObjectNV hitObj_0;
- hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S5.Origin_0, _S5.TMin_0, _S5.Direction_0, _S5.TMax_0, (0));
- uint _S6 = uint(idx_0);
- uint _S7 = calcValue_0(hitObj_0);
- ((outputBuffer_0)._data[(_S6)]) = _S7;
+ hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S6.Origin_0, _S6.TMin_0, _S6.Direction_0, _S6.TMax_0, (0));
+ uint _S7 = uint(idx_0);
+ uint _S8 = calcValue_0(hitObj_0);
+ ((outputBuffer_0)._data[(_S7)]) = _S8;
return;
}
diff --git a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl
index 1818b7789..84eba46f0 100644
--- a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl
+++ b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl
@@ -19,11 +19,13 @@ void main()
{
beginInvocationInterlockARB();
- vec4 _S3 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S1.xy)))));
- imageStore((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))), _S3 + _S1);
+ vec2 _S3 = _S1.xy;
+
+ vec4 _S4 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S3)))));
+ imageStore((entryPointParams_texture_0), ivec2((uvec2(_S3))), _S4 + _S1);
endInvocationInterlockARB();
- _S2 = _S3;
+ _S2 = _S4;
return;
}
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl
index 1da5f4f8a..864f44eb3 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl
@@ -14,6 +14,7 @@ out vec4 _S2;
void main()
{
+ uvec2 _S3 = uvec2(0U, 0U);
_S2 = gl_BaryCoordNV.x * ((_S1)[(0U)]) + gl_BaryCoordNV.y * ((_S1)[(1U)]) + gl_BaryCoordNV.z * ((_S1)[(2U)]);
return;
}
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl
index 257b334bf..ce23492c9 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl
@@ -8,7 +8,7 @@ void main(
vector<float,3> bary_0 : SV_BARYCENTRICS,
out vector<float,4> result_0 : SV_TARGET)
{
- result_0 = bary_0.x * GetAttributeAtVertex(color_0, (uint) int(0))
- + bary_0.y * GetAttributeAtVertex(color_0, (uint) int(1))
- + bary_0.z * GetAttributeAtVertex(color_0, (uint) int(2));
+ result_0 = bary_0.x * GetAttributeAtVertex(color_0, 0U)
+ + bary_0.y * GetAttributeAtVertex(color_0, 1U)
+ + bary_0.z * GetAttributeAtVertex(color_0, 2U);
}