diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-23 06:59:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-23 06:59:25 -0800 |
| commit | 46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch) | |
| tree | c89f3a1c416330f859887d00f896b18bcc7488a5 | |
| parent | 263ca18ea516cfce43fda703c0a411aaf1938e42 (diff) | |
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
40 files changed, 1230 insertions, 171 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index b50712cb1..da62d25f3 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -340,6 +340,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-hlsl-intrinsic-set.h" />
<ClInclude Include="..\..\..\source\slang\slang-image-format-defs.h" />
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-address-analysis.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
@@ -404,6 +405,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-missing-return.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure-scoping.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure.h" />
@@ -521,8 +523,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-glsl-extension-tracker.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-hlsl-intrinsic-set.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-address-analysis.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-addr-inst-elimination.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
@@ -582,6 +586,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-missing-return.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure-scoping.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 748654b98..4c61f48d9 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -126,6 +126,9 @@ <ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-address-analysis.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -318,6 +321,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -665,12 +671,18 @@ <ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-address-analysis.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-addr-inst-elimination.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -848,6 +860,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 4820c430f..0d4088d75 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -581,6 +581,8 @@ DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable fun DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.") DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal") +DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.") +DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.") // // 5xxxx - Target code generation. // diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp new file mode 100644 index 000000000..aba59e1de --- /dev/null +++ b/source/slang/slang-ir-address-analysis.cpp @@ -0,0 +1,173 @@ +#include "slang-ir-address-analysis.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ + void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst) + { + if (!as<IRBlock>(inst->getParent())) + return; + if (domTree->isUnreachable(as<IRBlock>(inst->getParent()))) + return; + + List<IRBlock*> blocks; + HashSet<IRInst*> operandInsts; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + operandInsts.Add(inst->getOperand(i)); + auto parentBlock = as<IRBlock>(inst->getOperand(i)->getParent()); + if (parentBlock) + { + if (!domTree->isUnreachable(parentBlock)) + blocks.add(parentBlock); + } + } + { + operandInsts.Add(inst->getFullType()); + auto parentBlock = as<IRBlock>(inst->getFullType()->getParent()); + if (parentBlock) + { + if (!domTree->isUnreachable(parentBlock)) + blocks.add(parentBlock); + } + } + // Find earliest block that is dominated by all operand blocks. + IRBlock* earliestBlock = as<IRBlock>(inst->getParent()); + for (auto block : func->getBlocks()) + { + bool dominated = true; + for (auto opBlock : blocks) + { + if (!domTree->dominates(opBlock, block)) + { + dominated = false; + break; + } + } + if (dominated) + { + earliestBlock = block; + break; + } + } + + if (!earliestBlock) + return; + + IRInst* latestOperand = nullptr; + for (auto childInst : earliestBlock->getChildren()) + { + if (operandInsts.Contains(childInst)) + { + latestOperand = childInst; + } + } + + if (!latestOperand || as<IRParam>(latestOperand)) + inst->insertBefore(earliestBlock->getFirstOrdinaryInst()); + else + inst->insertAfter(latestOperand); + } + + AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func) + { + DeduplicateContext deduplicateContext; + + AddressAccessInfo info; + + // Deduplicate and move known address insts. + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstChild(); inst;) + { + auto next = inst->getNextInst(); + switch (inst->getOp()) + { + case kIROp_Var: + { + RefPtr<AddressInfo> addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + addrInfo->isConstant = true; + addrInfo->parentAddress = nullptr; + info.addressInfos[inst] = addrInfo; + } + break; + case kIROp_Param: + if (as<IRPtrTypeBase>(inst->getFullType())) + { + RefPtr<AddressInfo> addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + addrInfo->isConstant = (block == func->getFirstBlock() ? true : false); + addrInfo->parentAddress = nullptr; + info.addressInfos[inst] = addrInfo; + } + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + moveInstToEarliestPoint(dom, func, inst->getFullType()); + moveInstToEarliestPoint(dom, func, inst); + auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* inst) + { + if (!inst->getParent()) + return false; + if (inst->getParent()->getParent() != func) + return false; + switch (inst->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + return true; + default: + return false; + } + }); + + if (deduplicated != inst) + { + SLANG_RELEASE_ASSERT(dom->dominates( + as<IRBlock>(deduplicated->getParent()), + as<IRBlock>(inst->getParent()))); + + inst->replaceUsesWith(deduplicated); + inst->removeAndDeallocate(); + } + else + { + RefPtr<AddressInfo> addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + if (inst->getOp() == kIROp_FieldAddress) + { + addrInfo->isConstant = true; + } + else + { + addrInfo->isConstant = + as<IRConstant>(inst->getOperand(1)) ? true : false; + } + info.addressInfos[inst] = addrInfo; + } + } + break; + } + inst = next; + } + } + + // Construct address info tree. + for (auto& addr : info.addressInfos) + { + RefPtr<AddressInfo> parentInfo; + if (addr.Value->addrInst->getOperandCount() > 1 && + info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo)) + { + addr.Value->parentAddress = parentInfo; + parentInfo->children.add(addr.Value); + if (!parentInfo->isConstant) + addr.Value->isConstant = false; + } + } + return info; + } +} diff --git a/source/slang/slang-ir-address-analysis.h b/source/slang/slang-ir-address-analysis.h new file mode 100644 index 000000000..450e8b9eb --- /dev/null +++ b/source/slang/slang-ir-address-analysis.h @@ -0,0 +1,24 @@ +// slang-ir-address-analysis.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-dominators.h" + +namespace Slang +{ + struct AddressInfo : public RefObject + { + IRInst* addrInst = nullptr; + AddressInfo* parentAddress = nullptr; + bool isConstant = false; + List<AddressInfo*> children; + }; + + struct AddressAccessInfo + { + OrderedDictionary<IRInst*, RefPtr<AddressInfo>> addressInfos; + }; + + // Gather info on all addresses used by `func`. + AddressAccessInfo analyzeAddressUse(IRDominatorTree* domTree, IRGlobalValueWithCode* func); +} diff --git a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp new file mode 100644 index 000000000..c60995595 --- /dev/null +++ b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp @@ -0,0 +1,476 @@ +#include "slang-ir-address-analysis.h" +#include "slang-ir-autodiff-fwd.h" +#include "slang-ir-autodiff-pairs.h" +#include "slang-ir-autodiff-rev.h" +#include "slang-ir-autodiff.h" +#include "slang-ir-single-return.h" +#include "slang-ir-ssa-simplification.h" +#include "slang-ir-validate.h" + +namespace Slang +{ +bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); + +struct AddressInstEliminationContext +{ + OrderedDictionary<IRInst*, IRInst*> mapAddrInstToTempVar; + + IRInst* _reconstructStruct( + IRBuilder& builder, IRStructType* type, IRInst* tempVar, List<AddressInfo*>& childAddrs) + { + List<IRInst*> args; + IRInst* loadedTempVar = nullptr; + for (auto child : type->getChildren()) + { + if (auto field = as<IRStructField>(child)) + { + IRInst* childVar = nullptr; + for (auto subAddr : childAddrs) + { + auto fieldAddrInst = cast<IRFieldAddress>(subAddr->addrInst); + if (fieldAddrInst->getField() == field->getKey()) + { + mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar); + break; + } + } + if (childVar) + { + args.add(builder.emitLoad(childVar)); + } + else + { + if (!loadedTempVar) + loadedTempVar = builder.emitLoad(tempVar); + args.add(builder.emitFieldExtract( + field->getFieldType(), loadedTempVar, field->getKey())); + } + } + } + return builder.emitMakeStruct(type, args); + } + + IRInst* _reconstructArray( + IRBuilder& builder, + IRArrayType* type, + IRIntegerValue arraySize, + IRInst* tempVar, + List<AddressInfo*>& childAddrs) + { + IRInst* loadedTempVar = nullptr; + List<IRInst*> args; + for (IRIntegerValue index = 0; index < arraySize; index++) + { + IRInst* childVar = nullptr; + for (auto subAddr : childAddrs) + { + auto elementPtrInst = cast<IRGetElementPtr>(subAddr->addrInst); + auto elementIndex = as<IRIntLit>(elementPtrInst->getIndex()); + if (elementIndex && elementIndex->getValue() == index) + { + mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar); + break; + } + } + if (childVar) + { + args.add(builder.emitLoad(childVar)); + } + else + { + if (!loadedTempVar) + loadedTempVar = builder.emitLoad(tempVar); + args.add(builder.emitElementExtract( + type->getElementType(), + loadedTempVar, + builder.getIntValue(builder.getIntType(), index))); + } + } + return builder.emitMakeArray(type, args.getCount(), args.getBuffer()); + } + + void updateChildTempVarRecursive( + IRBuilder& builder, + AddressInfo* addr, + IRInst* val) + { + for (auto child : addr->children) + { + IRInst* childVar = nullptr; + if (mapAddrInstToTempVar.TryGetValue(child->addrInst, childVar)) + { + switch (child->addrInst->getOp()) + { + case kIROp_FieldAddress: + { + auto subVal = builder.emitFieldExtract( + cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(), + val, + child->addrInst->getOperand(1)); + builder.emitStore(childVar, subVal); + updateChildTempVarRecursive(builder, child, subVal); + } + break; + case kIROp_GetElementPtr: + { + auto subVal = builder.emitElementExtract( + cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(), + val, + child->addrInst->getOperand(1)); + builder.emitStore(childVar, subVal); + updateChildTempVarRecursive(builder, child, subVal); + } + break; + default: + { + } + break; + } + } + } + } + + IRInst* getLoadedValue( + IRBuilder& builder, + AddressInfo* addr, + IRInst* tempVar) + { + if (addr->children.getCount()) + { + // Reconstruct val. + auto type = + cast<IRPtrTypeBase>(unwrapAttributedType(tempVar->getFullType()))->getValueType(); + switch (type->getOp()) + { + case kIROp_StructType: + return _reconstructStruct( + builder, as<IRStructType>(type), tempVar, addr->children); + case kIROp_ArrayType: + { + auto arrayType = as<IRArrayType>(type); + auto size = as<IRIntLit>(arrayType->getElementCount()); + if (!size || size->getValue() < 0) + { + // Unsupported array type. + } + else + { + return _reconstructArray( + builder, + arrayType, + size->getValue(), + tempVar, + addr->children); + } + } + break; + default: + // Unsupported address type. + break; + } + } + return builder.emitLoad(tempVar); + }; + + void updateParentTempVarRecursive( + IRBuilder& builder, + AddressInfo* addr) + { + for (auto parent = addr->parentAddress; parent; parent = parent->parentAddress) + { + IRInst* parentVar = nullptr; + if (mapAddrInstToTempVar.TryGetValue(parent->addrInst, parentVar)) + { + auto val = getLoadedValue(builder, parent, parentVar); + builder.emitStore(parentVar, val); + } + } + } + + String getAddrName(IRInst* addrInst) + { + StringBuilder sb; + List<IRInst*> bases; + bases.add(addrInst); + for (; addrInst;) + { + if (auto fieldAddr = as<IRFieldAddress>(addrInst)) + bases.add(fieldAddr->getBase()); + else if (auto index = as<IRGetElementPtr>(addrInst)) + bases.add(index->getBase()); + else + break; + } + for (Index i = bases.getCount() - 1; i >= 0; i--) + { + if (bases[i]->getOp() == kIROp_FieldAddress) + { + sb << "."; + auto field = bases[i]->getOperand(1); + auto nameDecor = field->findDecoration<IRNameHintDecoration>(); + sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>")); + } + else if (bases[i]->getOp() == kIROp_FieldAddress) + { + sb << "["; + auto index = bases[i]->getOperand(1); + auto nameDecor = index->findDecoration<IRNameHintDecoration>(); + if (nameDecor) + { + sb << nameDecor->getName(); + } + else if (auto intLit = as<IRIntLit>(index)) + { + sb << intLit->getValue(); + } + else + { + sb << "..."; + } + sb << "]"; + } + else + { + auto nameDecor = bases[i]->findDecoration<IRNameHintDecoration>(); + sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>")); + } + } + return sb.ProduceString(); + } + + SlangResult eliminateAddressInstsImpl( + SharedIRBuilder* sharedBuilder, + DifferentiableTypeConformanceContext& diffContext, + IRFunc* func, + DiagnosticSink* sink) + { + bool hasError = false; + + if (!isSingleReturnFunc(func)) + { + convertFuncToSingleReturnForm(func->getModule(), func); + } + + IRBuilder builder(sharedBuilder); + + auto dom = computeDominatorTree(func); + auto addrUse = analyzeAddressUse(dom, func); + List<AddressInfo*> workList; + HashSet<AddressInfo*> workListSet; + + // Process leaf addresses first. + for (auto addr : addrUse.addressInfos) + { + if (addr.Value->children.getCount() == 0) + workList.add(addr.Value); + } + + auto createTempVarForAddr = [&](IRInst* addrInst) + { + if (as<IRParam>(addrInst)) + builder.setInsertAfter(as<IRBlock>(addrInst->getParent())->getLastParam()); + else + builder.setInsertAfter(addrInst); + auto ptrType = as<IRPtrTypeBase>(addrInst->getFullType()); + SLANG_RELEASE_ASSERT(ptrType); + auto tempVar = builder.emitVar(ptrType->getValueType()); + mapAddrInstToTempVar[addrInst] = tempVar; + }; + + // In the first pass, we create temp vars for addresses with non-trivial access pattern. + for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++) + { + auto addr = workList[workListIndex]; + + if (!isDifferentiableType(diffContext, addr->addrInst->getDataType())) + continue; + + List<IRUse*> readUses, writeUses, callUses, subAddrUses, unknownUses; + + for (auto node = addr; node; node = node->parentAddress) + { + auto addrInst = node->addrInst; + + for (auto use = addrInst->firstUse; use; use = use->nextUse) + { + if (as<IRDecoration>(use->getUser())) + continue; + switch (use->getUser()->getOp()) + { + case kIROp_Load: + readUses.add(use); + break; + case kIROp_Store: + writeUses.add(use); + break; + case kIROp_Call: + callUses.add(use); + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + if (node == addr) + subAddrUses.add(use); + break; + default: + unknownUses.add(use); + break; + } + } + } + + if (unknownUses.getCount() != 0) + { + // Diagnose about unknown use. + sink->diagnose( + unknownUses.getFirst()->getUser(), + Diagnostics::unsupportedUseOfLValueForAutoDiff); + hasError = true; + continue; + } + + if (addr->isConstant) + { + // Otherwise, the address must be a constant, and we need to create a temp var for + // it. The exception is when the variable is a temp var for a call. + if (callUses.getCount() == 1 && writeUses.getCount() <= 1 && + readUses.getCount() <= 1) + { + if (writeUses.getCount() == 0) + continue; + + // The uses must be in write->call->read order. + auto callUse = callUses.getFirst(); + auto writeUse = writeUses.getFirst(); + auto readUse = readUses.getCount() ? readUses.getFirst() : writeUse; + if (dom->dominates(writeUse->getUser(), callUse->getUser()) && + dom->dominates(callUse->getUser(), readUse->getUser())) + { + continue; + } + } + + // Create a temp var for the address and replace all uses of the address to the temp + // var. + createTempVarForAddr(addr->addrInst); + } + else + { + // This is a dynamic address. We can only allow at most one write access to it. + bool hasNonTrivialAccess = false; + if (readUses.getCount() + callUses.getCount() != 0 && + writeUses.getCount() + callUses.getCount() > 1) + hasNonTrivialAccess = true; + + if (hasNonTrivialAccess) + { + // Mixed use of a non-constant address is unsupported right now. + sink->diagnose( + addr->addrInst, + Diagnostics::cannotDifferentiateDynamicallyIndexedData, + getAddrName(addr->addrInst)); + } + } + if (addr->parentAddress && workListSet.Add(addr->parentAddress)) + workList.add(addr->parentAddress); + } + + if (hasError) + return SLANG_FAIL; + + // Actually replace addresses with temp vars. + for (auto addr : workList) + { + IRInst* tempVar = nullptr; + if (!mapAddrInstToTempVar.TryGetValue(addr->addrInst, tempVar)) + continue; + for (auto use = addr->addrInst->firstUse; use;) + { + auto nextUse = use->nextUse; + auto user = use->getUser(); + + builder.setInsertBefore(user); + switch (user->getOp()) + { + case kIROp_Load: + use->set(tempVar); + break; + case kIROp_Store: + use->set(tempVar); + updateChildTempVarRecursive( + builder, addr, as<IRStore>(user)->getVal()); + updateParentTempVarRecursive(builder, addr); + case kIROp_Call: + { + use->set(tempVar); + builder.setInsertAfter(user); + auto newVal = builder.emitLoad(tempVar); + updateChildTempVarRecursive(builder, addr, newVal); + updateParentTempVarRecursive(builder, addr); + } + break; + default: + use->set(tempVar); + break; + } + use = nextUse; + } + } + + // Assign initial values to tempVar. + for (auto tempVar : mapAddrInstToTempVar) + { + builder.setInsertAfter(tempVar.Value); + IRInst* initVal = nullptr; + if (tempVar.Key->getOp() == kIROp_Var || + tempVar.Key->getOp() == kIROp_Param && as<IROutType>(tempVar.Key->getFullType())) + { + initVal = builder.emitDefaultConstruct( + cast<IRPtrTypeBase>(tempVar.Key->getFullType())->getValueType()); + } + else + { + initVal = builder.emitLoad(tempVar.Key); + } + builder.emitStore(tempVar.Value, initVal); + } + + // Store final values to out parameters before exiting function. + IRInst* returnInst = nullptr; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + returnInst = inst; + break; + } + } + } + SLANG_RELEASE_ASSERT(returnInst); + builder.setInsertBefore(returnInst); + for (auto param : func->getParams()) + { + IRInst* tempVar = nullptr; + if (mapAddrInstToTempVar.TryGetValue(param, tempVar)) + { + auto val = builder.emitLoad(tempVar); + builder.emitStore(param, val); + } + } + if (hasError) + return SLANG_FAIL; + return SLANG_OK; + } +}; + +SlangResult eliminateAddressInsts( + SharedIRBuilder* sharedBuilder, + DifferentiableTypeConformanceContext& diffContext, + IRFunc* func, + DiagnosticSink* sink) +{ + AddressInstEliminationContext ctx; + return ctx.eliminateAddressInstsImpl(sharedBuilder, diffContext, func, sink); +} +} // namespace Slang diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index ee159b80b..3f3618b44 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1130,6 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_VectorReshape: case kIROp_IntCast: case kIROp_FloatCast: + case kIROp_MakeStruct: + case kIROp_MakeArray: return transcribeConstruct(builder, origInst); case kIROp_LookupWitness: diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index d3a6137c1..779a4f1a3 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -5,10 +5,9 @@ #include "slang-ir-eliminate-phis.h" #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" - +#include "slang-ir-ssa-simplification.h" #include "slang-ir-autodiff-fwd.h" - namespace Slang { IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType) @@ -502,6 +501,17 @@ namespace Slang stripDerivativeDecorations(primalFunc); eliminateDeadCode(primalOuterParent); + // Perform preparation and simplification. + differentiableTypeConformanceContext.setFunc(primalFunc); + if (SLANG_FAILED(eliminateAddressInsts( + builder->getSharedBuilder(), + differentiableTypeConformanceContext, + primalFunc, + sink))) + return nullptr; + + simplifyFunc(primalFunc); + // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( autoDiffSharedContext->transcriberSet.forwardTranscriber); @@ -567,7 +577,9 @@ namespace Slang } auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc); - + if (!fwdDiffFunc) + return; + // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index ae1a5dd70..e799456bb 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1061,6 +1061,10 @@ struct DiffTransposePass case kIROp_MakeVector: return transposeMakeVector(builder, fwdInst, revValue); + case kIROp_MakeStruct: + return transposeMakeStruct(builder, fwdInst, revValue); + case kIROp_MakeArray: + return transposeMakeArray(builder, fwdInst, revValue); case kIROp_Specialize: case kIROp_unconditionalBranch: @@ -1218,6 +1222,55 @@ struct DiffTransposePass return TranspositionResult(gradients); } + TranspositionResult transposeMakeStruct(IRBuilder* builder, IRInst* fwdMakeStruct, IRInst* revValue) + { + List<RevGradient> gradients; + auto structType = cast<IRStructType>(fwdMakeStruct->getFullType()); + UInt ii = 0; + for (auto field : structType->getFields()) + { + auto gradAtField = builder->emitFieldExtract( + field->getFieldType(), + revValue, + field->getKey()); + SLANG_RELEASE_ASSERT(ii < fwdMakeStruct->getOperandCount()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeStruct->getOperand(ii), + gradAtField, + fwdMakeStruct)); + ii++; + } + + // (A = MakeStruct(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)] + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeArray(IRBuilder* builder, IRInst* fwdMakeArray, IRInst* revValue) + { + List<RevGradient> gradients; + auto arrayType = cast<IRArrayType>(fwdMakeArray->getFullType()); + auto arraySize = cast<IRIntLit>(arrayType->getElementCount()); + + for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++) + { + auto gradAtField = builder->emitElementExtract( + arrayType->getElementType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + SLANG_RELEASE_ASSERT(ii < fwdMakeArray->getOperandCount()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeArray->getOperand(ii), + gradAtField, + fwdMakeArray)); + ii++; + } + + // (A = MakeArray(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)] + return TranspositionResult(gradients); + } + // Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr. // void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 363006f58..4d33d3743 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1,7 +1,10 @@ #include "slang-ir-autodiff.h" +#include "slang-ir-address-analysis.h" #include "slang-ir-autodiff-rev.h" #include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-pairs.h" +#include "slang-ir-single-return.h" +#include "slang-ir-ssa-simplification.h" #include "slang-ir-validate.h" namespace Slang diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 7479e4eee..cb767c20a 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -260,4 +260,11 @@ bool finalizeAutoDiffPass(IRModule* module); void stripDerivativeDecorations(IRInst* inst); bool isBackwardDifferentiableFunc(IRInst* func); + +SlangResult eliminateAddressInsts( + SharedIRBuilder* sharedBuilder, + DifferentiableTypeConformanceContext& diffContext, + IRFunc* func, + DiagnosticSink* sink); + }; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index cb7290036..67b7e92b0 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -5,12 +5,45 @@ namespace Slang { +bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) +{ + HashSet<IRInst*> processedSet; + while (auto ptrType = as<IRPtrTypeBase>(typeInst)) + { + typeInst = ptrType->getValueType(); + if (!processedSet.Add(typeInst)) + return false; + } + if (!typeInst) + return false; + switch (typeInst->getOp()) + { + case kIROp_FloatType: + case kIROp_DifferentialPairType: + return true; + default: + break; + } + if (context.lookUpConformanceForType(typeInst)) + return true; + // Look for equivalent types. + for (auto type : context.differentiableWitnessDictionary) + { + if (isTypeEqual(type.Key, (IRType*)typeInst)) + { + context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; + return true; + } + } + return false; +} struct CheckDifferentiabilityPassContext : public InstPassBase { public: DiagnosticSink* sink; AutoDiffSharedContext sharedContext; + SharedIRBuilder* sharedBuilder; enum DifferentiableLevel { @@ -18,8 +51,8 @@ public: }; Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions; - CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink) - : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst()) + CheckDifferentiabilityPassContext(SharedIRBuilder* inSharedBuilder, IRModule* inModule, DiagnosticSink* inSink) + : InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst()) {} IRInst* getSpecializedVal(IRInst* inst) @@ -161,39 +194,6 @@ public: return false; } - bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) - { - HashSet<IRInst*> processedSet; - while (auto ptrType = as<IRPtrTypeBase>(typeInst)) - { - typeInst = ptrType->getValueType(); - if (!processedSet.Add(typeInst)) - return false; - } - if (!typeInst) - return false; - switch (typeInst->getOp()) - { - case kIROp_FloatType: - case kIROp_DifferentialPairType: - return true; - default: - break; - } - if (context.lookUpConformanceForType(typeInst)) - return true; - // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) - { - if (isTypeEqual(type.Key, (IRType*)typeInst)) - { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; - return true; - } - } - return false; - } - int getParamIndexInBlock(IRParam* paramInst) { auto block = as<IRBlock>(paramInst->getParent()); @@ -228,6 +228,14 @@ public: DifferentiableTypeConformanceContext diffTypeContext(&sharedContext); diffTypeContext.setFunc(funcInst); + if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + { + if (auto func = as<IRFunc>(funcInst)) + { + if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink))) + return; + } + } HashSet<IRInst*> produceDiffSet; HashSet<IRInst*> expectDiffSet; @@ -468,9 +476,9 @@ public: } }; -void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink) +void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink) { - CheckDifferentiabilityPassContext context(module, sink); + CheckDifferentiabilityPassContext context(sharedBuilder, module, sink); context.processModule(); } diff --git a/source/slang/slang-ir-check-differentiability.h b/source/slang/slang-ir-check-differentiability.h index 735a918c9..16ae16b6f 100644 --- a/source/slang/slang-ir-check-differentiability.h +++ b/source/slang/slang-ir-check-differentiability.h @@ -7,8 +7,9 @@ namespace Slang { struct IRModule; class DiagnosticSink; +struct SharedIRBuilder; // Check all auto diff usages are valid. -void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink); +void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink); } // namespace Slang diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp index 72b156228..1ffa7ba5d 100644 --- a/source/slang/slang-ir-dominators.cpp +++ b/source/slang/slang-ir-dominators.cpp @@ -86,6 +86,31 @@ bool IRDominatorTree::dominates(IRBlock* dominator, IRBlock* dominated) return properlyDominates(dominator, dominated); } +bool IRDominatorTree::dominates(IRInst* dominator, IRInst* dominated) +{ + auto dominatorBlock = as<IRBlock>(dominator); + if (!dominatorBlock) + dominatorBlock = as<IRBlock>(dominator->getParent()); + + auto dominatedBlock = as<IRBlock>(dominated); + if (!dominatedBlock) + dominatedBlock = as<IRBlock>(dominated->getParent()); + + if (dominatorBlock == dominatedBlock) + { + for (auto inst = dominator; inst; inst = inst->getNextInst()) + { + if (inst == dominated) + return true; + } + return false; + } + else + { + return dominates(dominatorBlock, dominatedBlock); + } +} + IRBlock* IRDominatorTree::getImmediateDominator(IRBlock* block) { // An unreachable block has no immediate dominator. diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h index be01830b0..1fb12c89e 100644 --- a/source/slang/slang-ir-dominators.h +++ b/source/slang/slang-ir-dominators.h @@ -7,6 +7,7 @@ namespace Slang { struct IRBlock; struct IRGlobalValueWithCode; + struct IRInst; /// The computed dominator tree for an IR control flow graph. struct IRDominatorTree : public RefObject @@ -22,6 +23,8 @@ namespace Slang /// bool dominates(IRBlock* dominator, IRBlock* dominated); + bool dominates(IRInst* dominator, IRInst* dominated); + /// Does the first block properly dominate the second? /// /// Block A properly dominates block B iff A dominates B diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 68afbbb95..134a45bf5 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -305,7 +305,6 @@ INST(GetOptionalValue, getOptionalValue, 1, 0) INST(OptionalHasValue, optionalHasValue, 1, 0) INST(MakeOptionalValue, makeOptionalValue, 1, 0) INST(MakeOptionalNone, makeOptionalNone, 1, 0) -INST(DifferentialBottomValue, differentialBottomVal, 0, 0) INST(Call, call, 1, 0) INST(RTTIObject, rtti_object, 0, 0) diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp new file mode 100644 index 000000000..a57bfce3e --- /dev/null +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -0,0 +1,125 @@ +#include "slang-ir-redundancy-removal.h" +#include "slang-ir-dominators.h" +#include "slang-ir-util.h" + +namespace Slang +{ + +struct RedundancyRemovalContext +{ + RefPtr<IRDominatorTree> dom; + bool removeRedundancyInBlock(DeduplicateContext& deduplicateContext, IRBlock* block) + { + bool result = false; + for (auto instP : block->getChildren()) + { + auto resultInst = deduplicateContext.deduplicate(instP, [&](IRInst* inst) + { + auto parentBlock = as<IRBlock>(inst->getParent()); + if (!parentBlock) + return false; + if (dom->isUnreachable(parentBlock)) + return false; + + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Module: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_Not: + case kIROp_FieldExtract: + case kIROp_FieldAddress: + case kIROp_GetElement: + case kIROp_GetElementPtr: + case kIROp_LookupWitness: + case kIROp_Specialize: + case kIROp_OptionalHasValue: + case kIROp_GetOptionalValue: + case kIROp_MakeOptionalValue: + case kIROp_MakeTuple: + case kIROp_GetTupleElement: + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + case kIROp_swizzle: + case kIROp_MatrixReshape: + case kIROp_MakeString: + case kIROp_MakeResultError: + case kIROp_MakeResultValue: + case kIROp_GetResultError: + case kIROp_GetResultValue: + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_CastIntToPtr: + case kIROp_CastPtrToBool: + case kIROp_CastPtrToInt: + case kIROp_BitAnd: + case kIROp_BitNot: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitCast: + case kIROp_Reinterpret: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + case kIROp_Neq: + case kIROp_Eql: + return true; + default: + return false; + } + }); + if (resultInst != instP) + result = true; + } + for (auto child : dom->getImmediatelyDominatedBlocks(block)) + { + DeduplicateContext subContext; + subContext.deduplicateMap = deduplicateContext.deduplicateMap; + result |= removeRedundancyInBlock(subContext, child); + } + return result; + } +}; + +bool removeRedundancy(IRModule* module) +{ + bool changed = false; + for (auto inst : module->getGlobalInsts()) + { + if (auto genericInst = as<IRGeneric>(inst)) + { + removeRedundancyInFunc(genericInst); + inst = findGenericReturnVal(genericInst); + } + if (auto func = as<IRFunc>(inst)) + { + changed |= removeRedundancyInFunc(func); + } + } + return changed; +} + +bool removeRedundancyInFunc(IRGlobalValueWithCode* func) +{ + auto root = func->getFirstBlock(); + if (!root) + return false; + + RedundancyRemovalContext context; + context.dom = computeDominatorTree(func); + DeduplicateContext deduplicateCtx; + return context.removeRedundancyInBlock(deduplicateCtx, root); +} + +} diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h new file mode 100644 index 000000000..26b265e77 --- /dev/null +++ b/source/slang/slang-ir-redundancy-removal.h @@ -0,0 +1,11 @@ +// slang-ir-redundancy-removal.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGlobalValueWithCode; + + bool removeRedundancy(IRModule* module); + bool removeRedundancyInFunc(IRGlobalValueWithCode* func); +} diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 1e247d1d9..54a1f7e08 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -196,6 +196,10 @@ bool simplifyCFG(IRModule* module) bool changed = false; for (auto inst : module->getGlobalInsts()) { + if (auto genericInst = as<IRGeneric>(inst)) + { + inst = findGenericReturnVal(genericInst); + } if (auto func = as<IRFunc>(inst)) { changed |= processFunc(func); diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp index f76e35040..30e933133 100644 --- a/source/slang/slang-ir-single-return.cpp +++ b/source/slang/slang-ir-single-return.cpp @@ -91,4 +91,20 @@ void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* fu context.processFunc(func); } +bool isSingleReturnFunc(IRGlobalValueWithCode* func) +{ + int returnCount = 0; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + returnCount++; + } + } + } + return returnCount <= 1; +} + } // namespace Slang diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h index 2ddfa280b..bb186634d 100644 --- a/source/slang/slang-ir-single-return.h +++ b/source/slang/slang-ir-single-return.h @@ -9,4 +9,5 @@ namespace Slang // Convert the CFG of `func` to have only a single `return` at the end. void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func); + bool isSingleReturnFunc(IRGlobalValueWithCode* func); } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index fd5f41f49..f06fafcb3 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -9,6 +9,7 @@ #include "slang-ir-hoist-constants.h" #include "slang-ir-deduplicate-generic-children.h" #include "slang-ir-remove-unused-generic-param.h" +#include "slang-ir-redundancy-removal.h" namespace Slang { @@ -26,6 +27,7 @@ namespace Slang changed |= deduplicateGenericChildren(module); changed |= applySparseConditionalConstantPropagation(module); changed |= peepholeOptimize(module); + changed |= removeRedundancy(module); changed |= simplifyCFG(module); // Note: we disregard the `changed` state from dead code elimination pass since @@ -49,6 +51,7 @@ namespace Slang changed = false; changed |= applySparseConditionalConstantPropagation(func); changed |= peepholeOptimize(func); + changed |= removeRedundancyInFunc(func); changed |= simplifyCFG(func); // Note: we disregard the `changed` state from dead code elimination pass since diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 881f041c0..319a23989 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -219,12 +219,20 @@ void moveInstChildren(IRInst* dest, IRInst* src) } } +String dumpIRToString(IRInst* root) +{ + StringBuilder sb; + StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); + dumpIR(root, IRDumpOptions(), nullptr, &writer); + return sb.ToString(); +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; IRGeneric* srcGeneric; IRGeneric* dstGeneric; - Dictionary<IRInstKey, IRInst*> deduplicateMap; + DeduplicateContext deduplicateContext; void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore) { @@ -251,42 +259,34 @@ struct GenericChildrenMigrationContextImpl inst = inst->getNextInst()) { IRInstKey key = { inst }; - deduplicateMap.AddIfNotExists(key, inst); + deduplicateContext.deduplicateMap.AddIfNotExists(key, inst); } } } IRInst* deduplicate(IRInst* value) { - if (!value) return nullptr; - if (value->getParent() != dstGeneric->getFirstBlock()) - return value; - switch (value->getOp()) - { - case kIROp_Param: - case kIROp_StructType: - case kIROp_StructKey: - case kIROp_InterfaceType: - case kIROp_ClassType: - case kIROp_Func: - case kIROp_Generic: - return value; - default: - break; - } - if (as<IRConstant>(value)) - return value; - - for (UInt i = 0; i < value->getOperandCount(); i++) - { - value->setOperand(i, deduplicate(value->getOperand(i))); - } - value->setFullType((IRType*)deduplicate(value->getFullType())); - IRInstKey key = { value }; - if (auto newValue = deduplicateMap.TryGetValue(key)) - return *newValue; - deduplicateMap[key] = value; - return value; + return deduplicateContext.deduplicate(value, [this](IRInst* inst) + { + if (inst->getParent() != dstGeneric->getFirstBlock()) + return false; + switch (inst->getOp()) + { + case kIROp_Param: + case kIROp_StructType: + case kIROp_StructKey: + case kIROp_InterfaceType: + case kIROp_ClassType: + case kIROp_Func: + case kIROp_Generic: + return false; + default: + break; + } + if (as<IRConstant>(inst)) + return false; + return true; + }); } IRInst* cloneInst(IRBuilder* builder, IRInst* src) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 92446138f..a250fc6a6 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -5,6 +5,7 @@ // This file contains utility functions for operating with Slang IR. // #include "slang-ir.h" +#include "slang-ir-insts.h" namespace Slang { @@ -32,6 +33,32 @@ public: IRInst* cloneInst(IRBuilder* builder, IRInst* src); }; + +struct DeduplicateContext +{ + Dictionary<IRInstKey, IRInst*> deduplicateMap; + + template<typename TFunc> + IRInst* deduplicate(IRInst* value, const TFunc& shouldDeduplicate) + { + if (!value) return nullptr; + if (!shouldDeduplicate(value)) + return value; + IRInstKey key = { value }; + if (auto newValue = deduplicateMap.TryGetValue(key)) + return *newValue; + for (UInt i = 0; i < value->getOperandCount(); i++) + { + value->setOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate)); + } + value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate)); + if (auto newValue = deduplicateMap.TryGetValue(key)) + return *newValue; + deduplicateMap[key] = value; + return value; + } +}; + bool isPtrToClassType(IRInst* type); bool isPtrToArrayType(IRInst* type); @@ -126,6 +153,8 @@ inline IRInst* unwrapAttributedType(IRInst* type) return type; } +String dumpIRToString(IRInst* root); + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e400d0a17..b79221900 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1982,7 +1982,6 @@ namespace Slang return getStringSlice() == rhs->getStringSlice(); } case kIROp_VoidLit: - case kIROp_DifferentialBottomValue: { return true; } @@ -2025,7 +2024,6 @@ namespace Slang return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength())); } case kIROp_VoidLit: - case kIROp_DifferentialBottomValue: { return code; } @@ -2110,14 +2108,6 @@ namespace Slang irValue->value.ptrVal = keyInst.value.ptrVal; break; } - case kIROp_DifferentialBottomValue: - { - const size_t instSize = prefixSize + sizeof(void*); - irValue = static_cast<IRConstant*>( - _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); - irValue->value.ptrVal = nullptr; - break; - } case kIROp_StringLit: { const UnownedStringSlice slice = keyInst.getStringSlice(); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ec51c7bfa..d0527eef8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9157,7 +9157,7 @@ RefPtr<IRModule> generateIRForTranslationUnit( checkForMissingReturns(module, compileRequest->getSink()); // Check for invalid differentiable function body. - checkAutoDiffUsages(module, compileRequest->getSink()); + checkAutoDiffUsages(sharedBuilder, module, compileRequest->getSink()); // The "mandatory" optimization passes may make use of the // `IRHighLevelDeclDecoration` type to relate IR instructions diff --git a/tests/autodiff/reverse-struct-multi-write.slang b/tests/autodiff/reverse-struct-multi-write.slang new file mode 100644 index 000000000..dd12c7d3d --- /dev/null +++ b/tests/autodiff/reverse-struct-multi-write.slang @@ -0,0 +1,48 @@ + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct A : IDifferentiable +{ + float x; + float y; +}; + +[BackwardDifferentiable] +A f(A a) +{ + // Read/writes to local struct variables won't be SSA'd out by default. + // The backward diff preparation pass will kick in to create temp vars for them. + A aout; + aout.y = 2 * a.x; + aout.y = aout.y + 2 * a.x; + aout.x = aout.y + 5 * a.x; + + // The result should be equivalent to: + /* + A aout; + var tmp = 2 * a.x; + tmp = tmp + 2 * a.x; + aout.y = tmp; + aout.x = tmp + 5 * a.x; + */ + return aout; + +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {1.0, 2.0}; + + var dpa = diffPair(a); + + A.Differential dout = {1.0, 1.0}; + + __bwd_diff(f)(dpa, dout); + outputBuffer[0] = dpa.d.x; // Expect: 13 + outputBuffer[1] = dpa.d.y; // Expect: 0 +} diff --git a/tests/autodiff/reverse-struct-multi-write.slang.expected.txt b/tests/autodiff/reverse-struct-multi-write.slang.expected.txt new file mode 100644 index 000000000..403f2ffd4 --- /dev/null +++ b/tests/autodiff/reverse-struct-multi-write.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +13.000000 +0.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/compute/half-texture.slang.glsl b/tests/compute/half-texture.slang.glsl index 88f585378..0eccccaaf 100644 --- a/tests/compute/half-texture.slang.glsl +++ b/tests/compute/half-texture.slang.glsl @@ -21,20 +21,23 @@ layout(std430, binding = 0) buffer _S1 { int _data[]; } outputBuffer_0; -layout(local_size_x = 4, local_size_y = 4, local_size_z = 1) in;void main() +layout(local_size_x = 4, local_size_y = 4, local_size_z = 1) in; +void main() { ivec2 pos_0 = ivec2(gl_GlobalInvocationID.xy); const float _S2 = 1.00000000000000000000 / 3.00000000000000000000; - ivec2 pos2_0 = ivec2(3 - pos_0.y, 3 - pos_0.x); + int _S3 = pos_0.y; + int _S4 = pos_0.x; + ivec2 pos2_0 = ivec2(3 - _S3, 3 - _S4); float16_t h_0 = (float16_t(imageLoad((halfTexture_0), ivec2((uvec2(pos2_0)))).x)); f16vec2 h2_0 = (f16vec2(imageLoad((halfTexture2_0), ivec2((uvec2(pos2_0)))).xy)); f16vec4 h4_0 = (f16vec4(imageLoad((halfTexture4_0), ivec2((uvec2(pos2_0)))))); - imageStore((halfTexture_0), ivec2((uvec2(pos_0))), f16vec4(h2_0.x + h2_0.y, float16_t(0), float16_t(0), float16_t(0))); - imageStore((halfTexture2_0), ivec2((uvec2(pos_0))), f16vec4(h4_0.xy, float16_t(0), float16_t(0))); - imageStore((halfTexture4_0), ivec2((uvec2(pos_0))), f16vec4(h2_0, h_0, h_0)); + imageStore((halfTexture_0), ivec2((uvec2(pos_0))), f16vec4(h2_0.x + h2_0.y, float16_t(0), float16_t(0), float16_t(0))); + imageStore((halfTexture2_0), ivec2((uvec2(pos_0))), f16vec4(h4_0.xy, float16_t(0), float16_t(0))); + imageStore((halfTexture4_0), ivec2((uvec2(pos_0))), f16vec4(h2_0, h_0, h_0)); - int index_0 = pos_0.x + pos_0.y * 4; + int index_0 = _S4 + _S3 * 4; ((outputBuffer_0)._data[(uint(index_0))]) = index_0; return; diff --git a/tests/compute/half-texture.slang.hlsl b/tests/compute/half-texture.slang.hlsl index c606703a4..2d04ee17f 100644 --- a/tests/compute/half-texture.slang.hlsl +++ b/tests/compute/half-texture.slang.hlsl @@ -8,19 +8,21 @@ RWStructuredBuffer<int > outputBuffer_0 : register(u0); [shader("compute")][numthreads(4, 4, 1)] void computeMain(uint3 dispatchThreadID_0 : SV_DISPATCHTHREADID) { - int2 pos_0 = (int2) dispatchThreadID_0.xy; + int2 pos_0 = int2(dispatchThreadID_0.xy); float _S1 = 1.00000000000000000000 / 3.00000000000000000000; - int2 pos2_0 = int2(int(3) - pos_0.y, int(3) - pos_0.x); + int _S2 = pos_0.y; + int _S3 = pos_0.x; + int2 pos2_0 = int2(int(3) - _S2, int(3) - _S3); - half h_0 = halfTexture_0[(uint2) pos2_0]; - vector<half,2> h2_0 = halfTexture2_0[(uint2) pos2_0]; - vector<half,4> h4_0 = halfTexture4_0[(uint2) pos2_0]; + half h_0 = halfTexture_0[uint2(pos2_0)]; + vector<half, 2> h2_0 = halfTexture2_0[uint2(pos2_0)]; + vector<half, 4> h4_0 = halfTexture4_0[uint2(pos2_0)]; - halfTexture_0[(uint2) pos_0] = h2_0.x + h2_0.y; - halfTexture2_0[(uint2) pos_0] = h4_0.xy; - halfTexture4_0[(uint2) pos_0] = vector<half,4>(h2_0, h_0, h_0); + halfTexture_0[uint2(pos_0)] = h2_0.x + h2_0.y; + halfTexture2_0[uint2(pos_0)] = h4_0.xy; + halfTexture4_0[uint2(pos_0)] = vector<half, 4>(h2_0, h_0, h_0); - int index_0 = pos_0.x + pos_0.y * int(4); - outputBuffer_0[(uint) index_0] = index_0; + int index_0 = _S3 + _S2 * int(4); + outputBuffer_0[uint(index_0)] = index_0; return; } diff --git a/tests/cross-compile/precise-keyword.slang.glsl b/tests/cross-compile/precise-keyword.slang.glsl index 17fed739e..027a8eb3b 100644 --- a/tests/cross-compile/precise-keyword.slang.glsl +++ b/tests/cross-compile/precise-keyword.slang.glsl @@ -11,15 +11,18 @@ in vec2 _S2; void main() { + float _S3 = _S2.x; + precise float z_0; - if(_S2.x > float(0)) + if(_S3 > 0.00000000000000000000) { - z_0 = _S2.x * _S2.y + _S2.x; + z_0 = _S3 * _S2.y + _S3; } else { - z_0 = _S2.y * _S2.x + _S2.y; + float _S4 = _S2.y; + z_0 = _S4 * _S3 + _S4; } _S1 = vec4(z_0); return; diff --git a/tests/cross-compile/precise-keyword.slang.hlsl b/tests/cross-compile/precise-keyword.slang.hlsl index 54017868b..7a07fdc5e 100644 --- a/tests/cross-compile/precise-keyword.slang.hlsl +++ b/tests/cross-compile/precise-keyword.slang.hlsl @@ -3,15 +3,17 @@ float4 main(float2 v_0 : V) : SV_TARGET { + float _S1 = v_0.x; precise float z_0; - if(v_0.x > (float) 0) + if (_S1 > 0.00000000000000000000) { - z_0 = v_0.x * v_0.y + v_0.x; + z_0 = _S1 * v_0.y + _S1; } else { - z_0 = v_0.y * v_0.x + v_0.y; + float _S2 = v_0.y; + z_0 = _S2 * _S1 + _S2; } return (float4) z_0; diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected index ac1894f95..26a537330 100644 --- a/tests/experimental/liveness/liveness-6.slang.expected +++ b/tests/experimental/liveness/liveness-6.slang.expected @@ -60,15 +60,16 @@ int calcThing_0(int offset_0) i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S3 = another_0[k_0 & 1]; - int _S4 = total_0; + int _S3 = k_0 & 1; + int _S4 = another_0[_S3]; + int _S5 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S4 + _S3; - int _S5 = arr_0[k_0 & 1]; + int total_1 = _S5 + _S4; + int _S6 = arr_0[_S3]; livenessEnd_1(arr_0, 0); - int total_2 = total_1 + _S5; - int _S6 = (k_0 + 7) % 5; - if(_S6 == 4) + int total_2 = total_1 + _S6; + int _S7 = (k_0 + 7) % 5; + if(_S7 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -83,32 +84,32 @@ int calcThing_0(int offset_0) int total_3; if(total_0 > 4) { - int _S7 = total_0; + int _S8 = total_0; livenessEnd_0(total_0, 0); - int _S8 = - _S7; + int _S9 = - _S8; livenessStart_1(total_3, 0); - total_3 = _S8; + total_3 = _S9; } else { - int _S9 = total_0; + int _S10 = total_0; livenessEnd_0(total_0, 0); livenessStart_1(total_3, 0); - total_3 = _S9; + total_3 = _S10; } return total_3; } -layout(std430, binding = 0) buffer _S10 { +layout(std430, binding = 0) buffer _S11 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S11 = uint(index_0); - int _S12 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S11)]) = _S12; + uint _S12 = uint(index_0); + int _S13 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S12)]) = _S13; return; } diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected index 8fc391feb..15221b921 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected @@ -57,15 +57,16 @@ uint calcValue_0(hitObjectNV hit_0) uint hitKind_0 = (hitObjectGetHitKindNV((hit_0))); uint r_1 = 0U + hitKind_0 + instanceIndex_0 + instanceID_0 + geometryIndex_0 + primitiveIndex_0; RayDesc_0 ray_1 = HitObject_GetRayDesc_0(hit_0); - uint r_2 = r_1 + uint(ray_1.TMin_0 > 0.00000000000000000000) + uint(ray_1.TMax_0 < ray_1.TMin_0); + float _S6 = ray_1.TMin_0; + uint r_2 = r_1 + uint(_S6 > 0.00000000000000000000) + uint(ray_1.TMax_0 < _S6); SomeValues_0 objSomeValues_0 = HitObject_GetAttributes_0(hit_0); r_0 = r_2 + uint(objSomeValues_0.a_0); } else { - bool _S6 = (hitObjectIsMissNV((hit_0))); + bool _S7 = (hitObjectIsMissNV((hit_0))); uint r_3; - if(_S6) + if(_S7) { r_3 = 1U; } @@ -78,29 +79,29 @@ uint calcValue_0(hitObjectNV hit_0) return r_0; } -layout(std430, binding = 1) buffer _S7 { +layout(std430, binding = 1) buffer _S8 { uint _data[]; } outputBuffer_0; void main() { - uvec3 _S8 = ((gl_LaunchIDEXT)); - ivec2 launchID_0 = ivec2(_S8.xy); - uvec3 _S9 = ((gl_LaunchSizeEXT)); + uvec3 _S9 = ((gl_LaunchIDEXT)); + ivec2 launchID_0 = ivec2(_S9.xy); + uvec3 _S10 = ((gl_LaunchSizeEXT)); int idx_0 = launchID_0.x; RayDesc_0 ray_2; ray_2.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000); ray_2.TMin_0 = 0.00999999977648258209; ray_2.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000); ray_2.TMax_0 = 10000.00000000000000000000; - RayDesc_0 _S10 = ray_2; + RayDesc_0 _S11 = ray_2; hitObjectNV hitObj_0; - hitObjectRecordHitWithIndexNV(hitObj_0, scene_0, int(uint(idx_0)), int(uint(idx_0 * 2)), int(uint(idx_0 * 3)), 0U, 0U, _S10.Origin_0, _S10.TMin_0, _S10.Direction_0, _S10.TMax_0, (0)); + hitObjectRecordHitWithIndexNV(hitObj_0, scene_0, int(uint(idx_0)), int(uint(idx_0 * 2)), int(uint(idx_0 * 3)), 0U, 0U, _S11.Origin_0, _S11.TMin_0, _S11.Direction_0, _S11.TMax_0, (0)); uint r_4 = calcValue_0(hitObj_0); - RayDesc_0 _S11 = ray_2; + RayDesc_0 _S12 = ray_2; hitObjectNV hitObj_1; - hitObjectRecordHitNV(hitObj_1, scene_0, int(uint(idx_0)), int(uint(idx_0 * 3)), int(uint(idx_0 * 2)), 0U, 0U, 4U, _S11.Origin_0, _S11.TMin_0, _S11.Direction_0, _S11.TMax_0, (0)); - uint _S12 = calcValue_0(hitObj_1); - uint r_5 = r_4 + _S12; + hitObjectRecordHitNV(hitObj_1, scene_0, int(uint(idx_0)), int(uint(idx_0 * 3)), int(uint(idx_0 * 2)), 0U, 0U, 4U, _S12.Origin_0, _S12.TMin_0, _S12.Direction_0, _S12.TMax_0, (0)); + uint _S13 = calcValue_0(hitObj_1); + uint r_5 = r_4 + _S13; ((outputBuffer_0)._data[(uint(idx_0))]) = r_5; return; } diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected index 90223115b..f250c1c92 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected @@ -79,36 +79,37 @@ void main() ivec2 launchID_0 = ivec2(_S3.xy); uvec3 _S4 = ((gl_LaunchSizeEXT)); int idx_0 = launchID_0.x; - SomeValues_0 someValues_0 = { idx_0, float(idx_0) * 2.00000000000000000000 }; + float _S5 = float(idx_0); + SomeValues_0 someValues_0 = { idx_0, _S5 * 2.00000000000000000000 }; RayDesc_0 ray_0; - ray_0.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000); + ray_0.Origin_0 = vec3(_S5, 0.00000000000000000000, 0.00000000000000000000); ray_0.TMin_0 = 0.00999999977648258209; ray_0.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000); ray_0.TMax_0 = 10000.00000000000000000000; - RayDesc_0 _S5 = ray_0; + RayDesc_0 _S6 = ray_0; p_0 = someValues_0; hitObjectNV hitObj_0; - hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S5.Origin_0, _S5.TMin_0, _S5.Direction_0, _S5.TMax_0, (0)); + hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S6.Origin_0, _S6.TMin_0, _S6.Direction_0, _S6.TMax_0, (0)); uint r_1 = calcValue_0(hitObj_0); reorderThreadNV(hitObj_0); SomeValues_0 otherValues_0; - SomeValues_0 _S6 = { idx_0 * -1, float(idx_0) * 4.00000000000000000000 }; - otherValues_0 = _S6; + SomeValues_0 _S7 = { idx_0 * -1, _S5 * 4.00000000000000000000 }; + otherValues_0 = _S7; HitObject_Invoke_0(scene_0, hitObj_0, otherValues_0); - uint _S7 = calcValue_0(hitObj_0); - uint r_2 = r_1 + _S7; + uint _S8 = calcValue_0(hitObj_0); + uint r_2 = r_1 + _S8; reorderThreadNV(hitObj_0, uint(idx_0 & 3), 2U); - SomeValues_0 _S8 = { idx_0 * -2, float(idx_0) * 8.00000000000000000000 }; - otherValues_0 = _S8; + SomeValues_0 _S9 = { idx_0 * -2, _S5 * 8.00000000000000000000 }; + otherValues_0 = _S9; HitObject_Invoke_0(scene_0, hitObj_0, otherValues_0); - uint _S9 = calcValue_0(hitObj_0); - uint r_3 = r_2 + _S9; + uint _S10 = calcValue_0(hitObj_0); + uint r_3 = r_2 + _S10; reorderThreadNV(uint(idx_0 & 1), 1U); - SomeValues_0 _S10 = { idx_0 * -4, float(idx_0) * 16.00000000000000000000 }; - otherValues_0 = _S10; + SomeValues_0 _S11 = { idx_0 * -4, _S5 * 16.00000000000000000000 }; + otherValues_0 = _S11; HitObject_Invoke_0(scene_0, hitObj_0, otherValues_0); - uint _S11 = calcValue_0(hitObj_0); - uint r_4 = r_3 + _S11; + uint _S12 = calcValue_0(hitObj_0); + uint r_4 = r_3 + _S12; ((outputBuffer_0)._data[(uint(idx_0))]) = r_4; return; } diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected index a86dc6aa7..f6f6f132d 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected @@ -70,19 +70,20 @@ void main() int idx_0 = launchID_0.x; int _S5 = idx_0 / 4; float currentTime_0 = float(_S5); - SomeValues_0 someValues_0 = { idx_0, float(idx_0) * 2.00000000000000000000 }; + float _S6 = float(idx_0); + SomeValues_0 someValues_0 = { idx_0, _S6 * 2.00000000000000000000 }; RayDesc_0 ray_0; - ray_0.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000); + ray_0.Origin_0 = vec3(_S6, 0.00000000000000000000, 0.00000000000000000000); ray_0.TMin_0 = 0.00999999977648258209; ray_0.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000); ray_0.TMax_0 = 10000.00000000000000000000; - RayDesc_0 _S6 = ray_0; + RayDesc_0 _S7 = ray_0; p_0 = someValues_0; hitObjectNV hitObj_0; - hitObjectTraceRayMotionNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S6.Origin_0, _S6.TMin_0, _S6.Direction_0, _S6.TMax_0, currentTime_0, (0)); - uint _S7 = uint(idx_0); - uint _S8 = calcValue_0(hitObj_0); - ((outputBuffer_0)._data[(_S7)]) = _S8; + hitObjectTraceRayMotionNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S7.Origin_0, _S7.TMin_0, _S7.Direction_0, _S7.TMax_0, currentTime_0, (0)); + uint _S8 = uint(idx_0); + uint _S9 = calcValue_0(hitObj_0); + ((outputBuffer_0)._data[(_S8)]) = _S9; return; } diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected index 38ddbf233..16099b5e2 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected @@ -67,19 +67,20 @@ void main() ivec2 launchID_0 = ivec2(_S3.xy); uvec3 _S4 = ((gl_LaunchSizeEXT)); int idx_0 = launchID_0.x; - SomeValues_0 someValues_0 = { idx_0, float(idx_0) * 2.00000000000000000000 }; + float _S5 = float(idx_0); + SomeValues_0 someValues_0 = { idx_0, _S5 * 2.00000000000000000000 }; RayDesc_0 ray_0; - ray_0.Origin_0 = vec3(float(idx_0), 0.00000000000000000000, 0.00000000000000000000); + ray_0.Origin_0 = vec3(_S5, 0.00000000000000000000, 0.00000000000000000000); ray_0.TMin_0 = 0.00999999977648258209; ray_0.Direction_0 = vec3(0.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000); ray_0.TMax_0 = 10000.00000000000000000000; - RayDesc_0 _S5 = ray_0; + RayDesc_0 _S6 = ray_0; p_0 = someValues_0; hitObjectNV hitObj_0; - hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S5.Origin_0, _S5.TMin_0, _S5.Direction_0, _S5.TMax_0, (0)); - uint _S6 = uint(idx_0); - uint _S7 = calcValue_0(hitObj_0); - ((outputBuffer_0)._data[(_S6)]) = _S7; + hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S6.Origin_0, _S6.TMin_0, _S6.Direction_0, _S6.TMax_0, (0)); + uint _S7 = uint(idx_0); + uint _S8 = calcValue_0(hitObj_0); + ((outputBuffer_0)._data[(_S7)]) = _S8; return; } diff --git a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl index 1818b7789..84eba46f0 100644 --- a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl +++ b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl @@ -19,11 +19,13 @@ void main() { beginInvocationInterlockARB(); - vec4 _S3 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))))); - imageStore((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))), _S3 + _S1); + vec2 _S3 = _S1.xy; + + vec4 _S4 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S3))))); + imageStore((entryPointParams_texture_0), ivec2((uvec2(_S3))), _S4 + _S1); endInvocationInterlockARB(); - _S2 = _S3; + _S2 = _S4; return; } diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl index 1da5f4f8a..864f44eb3 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl @@ -14,6 +14,7 @@ out vec4 _S2; void main() { + uvec2 _S3 = uvec2(0U, 0U); _S2 = gl_BaryCoordNV.x * ((_S1)[(0U)]) + gl_BaryCoordNV.y * ((_S1)[(1U)]) + gl_BaryCoordNV.z * ((_S1)[(2U)]); return; } diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl index 257b334bf..ce23492c9 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl @@ -8,7 +8,7 @@ void main( vector<float,3> bary_0 : SV_BARYCENTRICS, out vector<float,4> result_0 : SV_TARGET) { - result_0 = bary_0.x * GetAttributeAtVertex(color_0, (uint) int(0)) - + bary_0.y * GetAttributeAtVertex(color_0, (uint) int(1)) - + bary_0.z * GetAttributeAtVertex(color_0, (uint) int(2)); + result_0 = bary_0.x * GetAttributeAtVertex(color_0, 0U) + + bary_0.y * GetAttributeAtVertex(color_0, 1U) + + bary_0.z * GetAttributeAtVertex(color_0, 2U); } |
