diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 48 |
1 files changed, 24 insertions, 24 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 355381559..e1601c39a 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -93,7 +93,7 @@ public: return false; } - if (auto existingLevel = differentiableFunctions.TryGetValue(func)) + if (auto existingLevel = differentiableFunctions.tryGetValue(func)) return *existingLevel >= level; if (func->findDecoration<IRTreatAsDifferentiableDecoration>()) @@ -124,7 +124,7 @@ public: { if (as<IRGeneric>(func)) { - if (auto existingLevel = differentiableFunctions.TryGetValue(func)) + if (auto existingLevel = differentiableFunctions.tryGetValue(func)) { if (*existingLevel >= level) return true; @@ -242,8 +242,8 @@ public: { if (as<IROutTypeBase>(param->getFullType())) differentiableOutputs++; - produceDiffSet.Add(param); - carryNonTrivialDiffSet.Add(param); + produceDiffSet.add(param); + carryNonTrivialDiffSet.add(param); } } if (auto funcType = as<IRFuncType>(funcInst->getDataType())) @@ -283,7 +283,7 @@ public: return false; for (UInt i = 0; i < inst->getOperandCount(); i++) { - if (produceDiffSet.Contains(inst->getOperand(i))) + if (produceDiffSet.contains(inst->getOperand(i))) { return true; } @@ -315,7 +315,7 @@ public: return false; for (UInt i = 0; i < inst->getOperandCount(); i++) { - if (carryNonTrivialDiffSet.Contains(inst->getOperand(i))) + if (carryNonTrivialDiffSet.contains(inst->getOperand(i))) { return true; } @@ -330,7 +330,7 @@ public: { if (isInstInFunc(inst, funcInst)) { - if (expectDiffInstWorkListSet.Add(inst)) + if (expectDiffInstWorkListSet.add(inst)) { expectDiffInstWorkList.add(inst); } @@ -341,7 +341,7 @@ public: Index lastProduceDiffCount = 0; do { - lastProduceDiffCount = produceDiffSet.Count(); + lastProduceDiffCount = produceDiffSet.getCount(); for (auto block : funcInst->getBlocks()) { if (block != funcInst->getFirstBlock()) @@ -357,10 +357,10 @@ public: if (branch->getArgCount() > paramIndex) { auto arg = branch->getArg(paramIndex); - if (produceDiffSet.Contains(arg)) - produceDiffSet.Add(param); - if (carryNonTrivialDiffSet.Contains(arg)) - carryNonTrivialDiffSet.Add(param); + if (produceDiffSet.contains(arg)) + produceDiffSet.add(param); + if (carryNonTrivialDiffSet.contains(arg)) + carryNonTrivialDiffSet.add(param); } } } @@ -370,9 +370,9 @@ public: for (auto inst : block->getChildren()) { if (isInstProducingDiff(inst)) - produceDiffSet.Add(inst); + produceDiffSet.add(inst); if (isInstCarryingOverDiff(inst)) - carryNonTrivialDiffSet.Add(inst); + carryNonTrivialDiffSet.add(inst); switch (inst->getOp()) { case kIROp_Call: @@ -406,14 +406,14 @@ public: } } } - } while (produceDiffSet.Count() != lastProduceDiffCount); + } while (produceDiffSet.getCount() != lastProduceDiffCount); // Reverse propagate `expectDiffSet`. for (int i = 0; i < expectDiffInstWorkList.getCount(); i++) { auto inst = expectDiffInstWorkList[i]; // Is inst in produceDiffSet? - if (!produceDiffSet.Contains(inst)) + if (!produceDiffSet.contains(inst)) { if (auto call = as<IRCall>(inst)) { @@ -526,7 +526,7 @@ public: { if (auto storeInst = as<IRStore>(inst)) { - if (carryNonTrivialDiffSet.Contains(storeInst->getVal()) && + if (carryNonTrivialDiffSet.contains(storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); @@ -569,24 +569,24 @@ public: if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Backward)) { if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) - bwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName()); - differentiableFunctions.Add(inst, DifferentiableLevel::Backward); + bwdDifferentiableSymbolNames.add(linkageDecor->getMangledName()); + differentiableFunctions.add(inst, DifferentiableLevel::Backward); } else if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Forward)) { if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) - fwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName()); - differentiableFunctions.Add(inst, DifferentiableLevel::Forward); + fwdDifferentiableSymbolNames.add(linkageDecor->getMangledName()); + differentiableFunctions.add(inst, DifferentiableLevel::Forward); } } for (auto inst : module->getGlobalInsts()) { if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) { - if (bwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName())) + if (bwdDifferentiableSymbolNames.contains(linkageDecor->getMangledName())) differentiableFunctions[inst] = DifferentiableLevel::Backward; - else if (fwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName())) - differentiableFunctions.AddIfNotExists(inst, DifferentiableLevel::Forward); + else if (fwdDifferentiableSymbolNames.contains(linkageDecor->getMangledName())) + differentiableFunctions.addIfNotExists(inst, DifferentiableLevel::Forward); } } |
