diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-29 17:05:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-29 17:05:07 -0700 |
| commit | 082c48d96c5f8f6b4f560d705fe731da14409cb4 (patch) | |
| tree | fe9860aea3326cd321365bc5530a917fcef94718 | |
| parent | a862f5b7007ef50b5def30506f0cea138b73c710 (diff) | |
Update checkpoint policy to make obvious recompute decisions. (#2753)
* Update checkpoint policy to make obvious recompute decisions.
Also adds an optimization to fold updateElement chains on the same array or struct into a single makeArray or makeStruct.
* Bug fixes around array types with different int typed count.
* change test.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 208 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 168 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 110 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 53 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 2 | ||||
| -rw-r--r-- | tests/autodiff/array-param.slang | 83 | ||||
| -rw-r--r-- | tests/autodiff/array-param.slang.expected.txt | 2 |
12 files changed, 495 insertions, 197 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index cfcb15269..a14ed38d8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1546,7 +1546,12 @@ namespace Slang // it is possible that we are referring to a generic value param if (auto declRefExpr = expr.as<DeclRefExpr>()) { - auto declRef = getDeclRef(m_astBuilder, declRefExpr); + auto checkedExpr = as<DeclRefExpr>(CheckTerm(expr.getExpr())); + if (!checkedExpr) + return nullptr; + + SubstExpr<DeclRefExpr> substExpr(checkedExpr, expr.getSubsts()); + auto declRef = getDeclRef(m_astBuilder, substExpr); if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 793a8ff07..04d5560d9 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -23,7 +23,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal HashSet<IRUse*> processedUses; HashSet<IRUse*> usesToReplace; - + auto addPrimalOperandsToWorkList = [&](IRInst* inst) { UIndex opIndex = 0; @@ -144,7 +144,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal if (auto var = as<IRVar>(result.instToRecompute)) { IRUse* storeUse = findUniqueStoredVal(var); - if (!storeUse) + if (storeUse) workList.add(storeUse); } else @@ -635,40 +635,216 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( return hoistInfo; } -void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode*) +void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) { - // Do nothing.. This is an (almost) always-store policy. + domTree = computeDominatorTree(func); return; } -HoistResult DefaultCheckpointPolicy::classify(IRUse* use) +static bool doesInstHaveDiffUse(IRInst* inst) { - // Store all that we can.. by default, classify will only be called on relevant differential - // uses (or on uses in a 'recompute' inst) - // - if (auto var = as<IRVar>(use->get())) + bool hasDiffUser = false; + + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (isDiffInst(user)) + { + // Ignore uses that is a return or MakeDiffPair + switch (user->getOp()) + { + case kIROp_Return: + continue; + case kIROp_MakeDifferentialPair: + if (!user->hasMoreThanOneUse() && user->firstUse && + user->firstUse->getUser()->getOp() == kIROp_Return) + continue; + break; + default: + break; + } + hasDiffUser = true; + break; + } + } + + return hasDiffUser; +} + +static bool doesInstHaveStore(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType())); + + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (as<IRStore>(use->getUser())) + return true; + + if (as<IRPtrTypeBase>(use->getUser()->getDataType())) + { + if (doesInstHaveStore(use->getUser())) + return true; + } + } + + return false; +} + +static bool isIntermediateContextType(IRType* type) +{ + switch (type->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: + return true; + case kIROp_PtrType: + return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType()); + case kIROp_ArrayType: + return isIntermediateContextType(as<IRArrayType>(type)->getElementType()); + } + + return false; +} + +static bool shouldStoreVar(IRVar* var) +{ + // Always store intermediate context var. + if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) { + // If we are specializing a callee's intermediate context with types that can't be stored, + // we can't store the entire context. if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType())) { for (UInt i = 0; i < spec->getArgCount(); i++) { if (!canTypeBeStored(spec->getArg(i)->getDataType())) - return HoistResult::recompute(use->get()); + return false; } - return HoistResult::store(use->get()); } - else // if (canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); + return true; + } + + if (isIntermediateContextType(var->getDataType())) + { + return true; + } + + // For now the store policy is simple, we use two conditions: + // 1. Is the var used in a differential block and, + // 2. Does the var have a store + // + + return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); +} + +static bool shouldStoreInst(IRInst* inst) +{ + if (!inst->getDataType()) + { + return false; + } + + if (!canTypeBeStored(inst->getDataType())) + return false; + + // Never store certain opcodes. + switch (inst->getOp()) + { + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrixFromScalar: + case kIROp_Reinterpret: + case kIROp_BitCast: + case kIROp_DefaultConstruct: + case kIROp_MakeStruct: + case kIROp_MakeTuple: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeDifferentialPair: + case kIROp_MakeOptionalNone: + case kIROp_MakeOptionalValue: + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialWitnessTable: + case kIROp_undefined: + return false; + case kIROp_GetElement: + case kIROp_FieldExtract: + case kIROp_swizzle: + case kIROp_UpdateElement: + case kIROp_OptionalHasValue: + case kIROp_GetOptionalValue: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + // If the operand is already stored, don't store the result of these insts. + if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>()) { - return HoistResult::store(use->get()); + return false; } + break; + default: + break; + } + + // Only store if the inst has differential inst user. + bool hasDiffUser = doesInstHaveDiffUse(inst); + if (!hasDiffUser) + return false; + + return true; +} + +bool canRecompute(IRDominatorTree* domTree, IRUse* use) +{ + auto param = as<IRParam>(use->get()); + if (!param) + return true; + auto paramBlock = as<IRBlock>(param->getParent()); + for (auto predecessor : paramBlock->getPredecessors()) + { + // If we hit this, the checkpoint policy is trying to recompute + // values across a loop region boundary (we don't currently support this, + // and in general this is quite inefficient in both compute & memory) + // + if (domTree->dominates(paramBlock, predecessor)) + { + return false; + } + } + return true; +} + +HoistResult DefaultCheckpointPolicy::classify(IRUse* use) +{ + // Store all that we can.. by default, classify will only be called on relevant differential + // uses (or on uses in a 'recompute' inst) + // + if (auto var = as<IRVar>(use->get())) + { + if (shouldStoreVar(var)) + return HoistResult::store(var); + else + return HoistResult::recompute(var); } else { - if (canTypeBeStored(use->get()->getDataType())) + if (shouldStoreInst(use->get())) return HoistResult::store(use->get()); else - return HoistResult::recompute(use->get()); + { + // We may not be able to recompute due to limitations of + // the unzip pass. If so we will store the result. + if (canRecompute(domTree, use)) + return HoistResult::recompute(use->get()); + + // The fallback is to store. + return HoistResult::store(use->get()); + } } } -};
\ No newline at end of file +}; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index dc85942f6..bd2575172 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -218,7 +218,7 @@ namespace Slang class AutodiffCheckpointPolicyBase : public RefObject { - public: + public: AutodiffCheckpointPolicyBase(IRModule* module) : module(module) { } @@ -233,14 +233,14 @@ namespace Slang virtual HoistResult classify(IRUse* diffBlockUse) = 0; - protected: + protected: IRModule* module; }; class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase { - public: + public: DefaultCheckpointPolicy(IRModule* module) : AutodiffCheckpointPolicyBase(module) @@ -248,6 +248,8 @@ namespace Slang virtual void preparePolicy(IRGlobalValueWithCode* func); virtual HoistResult classify(IRUse* use); + + RefPtr<IRDominatorTree> domTree; }; RefPtr<HoistedPrimalsInfo> applyCheckpointSet( @@ -261,4 +263,4 @@ namespace Slang IRGlobalValueWithCode* func, Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo); -};
\ No newline at end of file +}; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 91a1601fb..8c005a5c6 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1621,10 +1621,18 @@ struct DiffTransposePass if (auto varToHoist = as<IRVar>(inst)) { varToHoist->insertBefore(varBlock->getFirstOrdinaryInst()); - inst = findUniqueStoredVal(varToHoist)->getUser(); - SLANG_ASSERT(inst); + auto uniqueStoreUse = findUniqueStoredVal(varToHoist); + if (uniqueStoreUse) + { + inst = uniqueStoreUse->getUser(); + SLANG_ASSERT(inst); - defBlock = getBlock(inst); + defBlock = getBlock(inst); + } + else + { + defBlock = getBlock(inst); + } } else { diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index af7792748..53f0cbba2 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -91,18 +91,6 @@ struct ExtractPrimalFuncContext return newFuncType; } - bool isDiffInst(IRInst* inst) - { - if (inst->findDecoration<IRDifferentialInstDecoration>() || - inst->findDecoration<IRMixedDifferentialInstDecoration>()) - return true; - - if (auto block = as<IRBlock>(inst->getParent())) - return isDiffInst(block); - - return false; - } - IRInst* insertIntoReturnBlock(IRBuilder& builder, IRInst* inst) { if (!isDiffInst(inst)) @@ -130,162 +118,6 @@ struct ExtractPrimalFuncContext } } - bool doesInstHaveDiffUse(IRInst* inst) - { - bool hasDiffUser = false; - - for (auto use = inst->firstUse; use; use = use->nextUse) - { - auto user = use->getUser(); - if (isDiffInst(user)) - { - // Ignore uses that is a return or MakeDiffPair - switch (user->getOp()) - { - case kIROp_Return: - continue; - case kIROp_MakeDifferentialPair: - if (!user->hasMoreThanOneUse() && user->firstUse && - user->firstUse->getUser()->getOp() == kIROp_Return) - continue; - break; - default: - break; - } - hasDiffUser = true; - break; - } - } - - return hasDiffUser; - } - - bool doesInstHaveStore(IRInst* inst) - { - SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType())); - - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (as<IRStore>(use->getUser())) - return true; - - if (as<IRPtrTypeBase>(use->getUser()->getDataType())) - { - if (doesInstHaveStore(use->getUser())) - return true; - } - } - - return false; - } - - bool isIntermediateContextType(IRType* type) - { - switch (type->getOp()) - { - case kIROp_BackwardDiffIntermediateContextType: - return true; - case kIROp_PtrType: - return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType()); - case kIROp_ArrayType: - return isIntermediateContextType(as<IRArrayType>(type)->getElementType()); - } - - return false; - } - - bool shouldStoreVar(IRVar* var) - { - // Always store intermediate context var. - if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) - { - // If we are specializing a callee's intermediate context with types that can't be stored, - // we can't store the entire context. - if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType())) - { - for (UInt i = 0; i < spec->getArgCount(); i++) - { - if (!canTypeBeStored(spec->getArg(i)->getDataType())) - return false; - } - } - return true; - } - - if (isIntermediateContextType(var->getDataType())) - { - return true; - } - - // For now the store policy is simple, we use two conditions: - // 1. Is the var used in a differential block and, - // 2. Does the var have a store - // - - return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); - } - - bool shouldStoreInst(IRInst* inst) - { - if (!inst->getDataType()) - { - return false; - } - - if (!canTypeBeStored(inst->getDataType())) - return false; - - // Never store certain opcodes. - switch (inst->getOp()) - { - case kIROp_CastFloatToInt: - case kIROp_CastIntToFloat: - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrixFromScalar: - case kIROp_Reinterpret: - case kIROp_BitCast: - case kIROp_DefaultConstruct: - case kIROp_MakeStruct: - case kIROp_MakeTuple: - case kIROp_MakeArray: - case kIROp_MakeArrayFromElement: - case kIROp_MakeDifferentialPair: - case kIROp_MakeOptionalNone: - case kIROp_MakeOptionalValue: - case kIROp_DifferentialPairGetDifferential: - case kIROp_DifferentialPairGetPrimal: - case kIROp_ExtractExistentialValue: - case kIROp_ExtractExistentialType: - case kIROp_ExtractExistentialWitnessTable: - return false; - case kIROp_GetElement: - case kIROp_FieldExtract: - case kIROp_swizzle: - case kIROp_UpdateElement: - case kIROp_OptionalHasValue: - case kIROp_GetOptionalValue: - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - // If the operand is already stored, don't store the result of these insts. - if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>()) - { - return false; - } - break; - default: - break; - } - - // Only store if the inst has differential inst user. - bool hasDiffUser = doesInstHaveDiffUse(inst); - if (!hasDiffUser) - return false; - - return true; - } - IRStructField* addIntermediateContextField(IRInst* type, IRInst* intermediateOutput) { IRBuilder genTypeBuilder(module); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 10c751d52..024d31fd8 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -282,15 +282,29 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; - if (pairTypeCache.TryGetValue(originalPairType, result)) - return result; auto pairType = as<IRDifferentialPairTypeBase>(originalPairType); if (!pairType) + return originalPairType; + + // We make our type cache keyed on the primal type, not the pair type. + // This is because there may be duplicate pair types for the same + // primal type but different witness tables, and we don't want to treat + // them as distinct. + // We might want to consider making witness tables part of IR + // deduplication (make them HOISTABLE insts), but that is a bigger + // change. Another alternative is to make the witness operand of + // `IRDifferentialPairTypeBase` be child instead of an operand + // so that it is not considered part of the type for deduplication + // purposes. + + auto primalType = pairType->getValueType(); + if (pairTypeCache.TryGetValue(primalType, result)) + return result; + if (!pairType) { result = originalPairType; return result; } - auto primalType = pairType->getValueType(); if (as<IRParam>(primalType)) { result = nullptr; @@ -301,7 +315,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( if (!diffType) return result; result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - pairTypeCache.Add(originalPairType, result); + pairTypeCache.Add(primalType, result); return result; } @@ -1820,4 +1834,16 @@ bool isDerivativeContextVar(IRVar* var) return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); } +bool isDiffInst(IRInst* inst) +{ + if (inst->findDecoration<IRDifferentialInstDecoration>() || + inst->findDecoration<IRMixedDifferentialInstDecoration>()) + return true; + + if (auto block = as<IRBlock>(inst->getParent())) + return isDiffInst(block); + + return false; +} + } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index da0cdc755..167aa2357 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -346,5 +346,6 @@ IRUse* findUniqueStoredVal(IRVar* var); bool isDerivativeContextVar(IRVar* var); +bool isDiffInst(IRInst* inst); }; diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 65b5d2f45..ab3f0ceab 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -43,6 +43,9 @@ struct PeepholeContext : InstPassBase chainKey.reverse(); if (auto updateInst = as<IRUpdateElement>(chainNode)) { + // If we see an extract(updateElement(x, accessChain, val), accessChain), then + // we can replace the inst with val. + if (updateInst->getAccessKeyCount() > (UInt)chainKey.getCount()) return false; @@ -96,6 +99,8 @@ struct PeepholeContext : InstPassBase } else if (isAccessChainNotEqual) { + // If we see an extract(updateElement(x, accessChain, val), accessChain2), where accessChain!=accessChain2, + // then we can replace the inst with extract(x, accessChain2). IRBuilder builder(module); builder.setInsertBefore(inst); auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView()); @@ -445,12 +450,65 @@ struct PeepholeContext : InstPassBase changed = true; } } + else + { + // Check if the updated value is a chain of `updateElement` instructions that + // updates every element in the same array, and if so we can replace the + // whole chain with a single `makeArray` instruction. + auto arrayType = as<IRArrayType>(inst->getDataType()); + if (!arrayType) break; + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) break; + + List<IRInst*> args; + args.setCount((UInt)arraySize->getValue()); + for (Index i = 0; i < args.getCount(); i++) + args[i] = nullptr; + + for (auto updateElement = updateInst; updateElement; + updateElement = as<IRUpdateElement>(updateElement->getOldValue())) + { + auto subKey = updateElement->getAccessKey(0); + auto subConstIndex = as<IRIntLit>(subKey); + if (!subConstIndex) + break; + auto index = (Index)subConstIndex->getValue(); + if (index >= args.getCount()) + break; + // If we have already seen an update for this index, then we can't + // override it with an earlier update. + if (args[index]) + continue; + args[index] = updateElement->getElementValue(); + } + + bool isComplete = true; + for (auto arg : args) + { + if (!arg) + { + isComplete = false; + break; + } + } + if (isComplete) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + inst->replaceUsesWith(makeArray); + maybeRemoveOldInst(inst); + changed = true; + } + } } else if (auto structKey = as<IRStructKey>(key)) { auto oldVal = inst->getOperand(0); if (oldVal->getOp() == kIROp_MakeStruct) { + // If we see updateElement(makeStruct(...), structKey, ...), we can + // replace it with a makeStruct that has the updated value. auto structType = as<IRStructType>(inst->getDataType()); if (!structType) break; List<IRInst*> args; @@ -484,6 +542,58 @@ struct PeepholeContext : InstPassBase changed = true; } } + else + { + // Check if the updated `oldVal` is a chain of updateElement insts that assigns + // values to every field of the struct, if so, we can just emit a makeStruct instead. + Dictionary<IRStructKey*, IRInst*> mapFieldKeyToVal; + for (auto updateElement = as<IRUpdateElement>(inst); updateElement; + updateElement = as<IRUpdateElement>(updateElement->getOldValue())) + { + if (updateElement->getAccessKeyCount() != 1) + break; + auto subStructKey = as<IRStructKey>(updateElement->getAccessKey(0)); + if (!subStructKey) + break; + + // If the key already exists, it means there is already a later update at this key. + // We need to be careful not to override it with an earlier value. + // AddIfNotExists will ensure this does not happen. + mapFieldKeyToVal.AddIfNotExists( + subStructKey, updateElement->getElementValue()); + } + + // Check if every field of the struct has a value assigned to it, + // while build up arguments for makeStruct inst at the same time. + auto structType = as<IRStructType>(inst->getDataType()); + if (!structType) break; + List<IRInst*> args; + bool isComplete = true; + for (auto field : structType->getFields()) + { + IRInst* arg = nullptr; + if (mapFieldKeyToVal.TryGetValue(field->getKey(), arg)) + { + args.add(arg); + } + else + { + isComplete = false; + break; + } + } + + if (!isComplete) break; + + // Create a makeStruct inst using args. + + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer()); + inst->replaceUsesWith(makeStruct); + maybeRemoveOldInst(inst); + changed = true; + } } } break; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index a4d7840d3..16926a742 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2389,6 +2389,25 @@ namespace Slang capabilitySetType, kIROp_CapabilitySet, args.getCount(), args.getBuffer()); } + static void canonicalizeInstOperands(IRBuilder& builder, IRInst* inst) + { + // For Array types, we always want to make sure its element count + // has an int32_t type. We will convert all other int types to int32_t + // to avoid things like float[8] and float[8U] being distinct types. + if (inst->getOp() == kIROp_ArrayType) + { + IRInst* elementCount = inst->getOperand(1); + if (auto intLit = as<IRIntLit>(elementCount)) + { + if (intLit->getDataType()->getOp() != kIROp_IntType) + { + IRInst* newElementCount = builder.getIntValue(builder.getIntType(), intLit->getValue()); + inst->getOperands()[1].usedValue = newElementCount; + } + } + } + } + IRInst* IRBuilder::_findOrEmitHoistableInst( IRType* type, IROp op, @@ -2448,6 +2467,8 @@ namespace Slang } } + canonicalizeInstOperands(*this, inst); + // Find or add the key/inst { IRInstKey key = { inst }; @@ -2515,6 +2536,38 @@ namespace Slang UInt operandCount, IRInst* const* operands) { + switch (op) + { + case kIROp_ArrayType: + { + ShortList<IRInst*, 2> newOperands; + newOperands.addRange(operands, operandCount); + + // If elementCount does not have int type, then we always cast + // it to an int type, to avoid having to deal with the + // possibility that an array<int, 2> and an array<int, 2U> are + // treated as distinct types. + if (operandCount < 2) break; + auto elementCount = operands[1]; + if (elementCount->getFullType() && elementCount->getFullType()->getOp() != kIROp_IntType) + { + auto intLit = as<IRIntLit>(elementCount); + if (intLit) + elementCount = getIntValue(getIntType(), intLit->getValue()); + else + elementCount = emitIntrinsicInst(getIntType(), kIROp_IntCast, 1, &elementCount); + } + newOperands[1] = elementCount; + return (IRType*)createIntrinsicInst( + nullptr, + op, + operandCount, + newOperands.getArrayView().getBuffer()); + } + default: + break; + } + return (IRType*)createIntrinsicInst( nullptr, op, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 6d4e50463..3affcff44 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6446,7 +6446,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); auto irSubType = lowerType(subContext, subType); - irWitnessTable->setOperand(0, irSubType); + irWitnessTable->setConcreteType(irSubType); // TODO(JS): // Should the mangled name take part in obfuscation if enabled? diff --git a/tests/autodiff/array-param.slang b/tests/autodiff/array-param.slang new file mode 100644 index 000000000..fd78b3246 --- /dev/null +++ b/tests/autodiff/array-param.slang @@ -0,0 +1,83 @@ + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +static const uint32_t N_LATENT_DIMS = 4; +static const uint32_t kDecoderInputCount = 6; +struct LatentTexture +{ + static const uint32_t kLatentDimsCount = N_LATENT_DIMS; + static const uint32_t kLatentTextureCount = N_LATENT_DIMS / 4; + + [BackwardDifferentiable] + void getCodeStochastic(float2 uv, out float code[kLatentDimsCount]) + { + return getCode(uint2(1,2), code); + } + + void getCode(uint2 texel, out float code[kLatentDimsCount]) + { + for (uint i = 0; i < kLatentTextureCount; ++i) + { + for (uint j = 0; j < 4; ++j) + { + code[i * 4 + j] = j; + } + } + } + [BackwardDerivativeOf(getCode)] + void bwd_getCode(uint2 texel, float d_out[kLatentDimsCount]) + { + outputBuffer[0] = d_out[0]; + } +} + +static LatentTexture gLatents; + +[BackwardDifferentiable] +void test(float arr[10], out float result[3]) +{ + float sum = 0; + [ForceUnroll] + for (int i = 0; i < LatentTexture.kLatentDimsCount + kDecoderInputCount; i++) + sum += arr[i]; + result[0] = sum; + result[1] = sum; + result[2] = sum; +} + +[BackwardDifferentiable] +float evalDecoder() +{ + // Latent code. + float latentCode[LatentTexture.kLatentDimsCount]; + gLatents.getCodeStochastic(float2(1,2), latentCode); + + // Model input. + float input[kDecoderInputCount + LatentTexture.kLatentDimsCount]; + input[0] = 0; + input[1] = 1; + input[2] = 2; + input[3] = 3; + input[4] = 4; + input[5] = 5; + [ForceUnroll] + for (int i = 0; i < LatentTexture.kLatentDimsCount; i++) + { + input[kDecoderInputCount + i] = latentCode[i]; + } + + float res[3]; + test(input, res); + return res[0] + res[1] + res[2]; +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + __bwd_diff(evalDecoder)(1.0); +} diff --git a/tests/autodiff/array-param.slang.expected.txt b/tests/autodiff/array-param.slang.expected.txt new file mode 100644 index 000000000..f38cc1080 --- /dev/null +++ b/tests/autodiff/array-param.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +3.0 |
