summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp48
1 files changed, 24 insertions, 24 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 355381559..e1601c39a 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -93,7 +93,7 @@ public:
return false;
}
- if (auto existingLevel = differentiableFunctions.TryGetValue(func))
+ if (auto existingLevel = differentiableFunctions.tryGetValue(func))
return *existingLevel >= level;
if (func->findDecoration<IRTreatAsDifferentiableDecoration>())
@@ -124,7 +124,7 @@ public:
{
if (as<IRGeneric>(func))
{
- if (auto existingLevel = differentiableFunctions.TryGetValue(func))
+ if (auto existingLevel = differentiableFunctions.tryGetValue(func))
{
if (*existingLevel >= level)
return true;
@@ -242,8 +242,8 @@ public:
{
if (as<IROutTypeBase>(param->getFullType()))
differentiableOutputs++;
- produceDiffSet.Add(param);
- carryNonTrivialDiffSet.Add(param);
+ produceDiffSet.add(param);
+ carryNonTrivialDiffSet.add(param);
}
}
if (auto funcType = as<IRFuncType>(funcInst->getDataType()))
@@ -283,7 +283,7 @@ public:
return false;
for (UInt i = 0; i < inst->getOperandCount(); i++)
{
- if (produceDiffSet.Contains(inst->getOperand(i)))
+ if (produceDiffSet.contains(inst->getOperand(i)))
{
return true;
}
@@ -315,7 +315,7 @@ public:
return false;
for (UInt i = 0; i < inst->getOperandCount(); i++)
{
- if (carryNonTrivialDiffSet.Contains(inst->getOperand(i)))
+ if (carryNonTrivialDiffSet.contains(inst->getOperand(i)))
{
return true;
}
@@ -330,7 +330,7 @@ public:
{
if (isInstInFunc(inst, funcInst))
{
- if (expectDiffInstWorkListSet.Add(inst))
+ if (expectDiffInstWorkListSet.add(inst))
{
expectDiffInstWorkList.add(inst);
}
@@ -341,7 +341,7 @@ public:
Index lastProduceDiffCount = 0;
do
{
- lastProduceDiffCount = produceDiffSet.Count();
+ lastProduceDiffCount = produceDiffSet.getCount();
for (auto block : funcInst->getBlocks())
{
if (block != funcInst->getFirstBlock())
@@ -357,10 +357,10 @@ public:
if (branch->getArgCount() > paramIndex)
{
auto arg = branch->getArg(paramIndex);
- if (produceDiffSet.Contains(arg))
- produceDiffSet.Add(param);
- if (carryNonTrivialDiffSet.Contains(arg))
- carryNonTrivialDiffSet.Add(param);
+ if (produceDiffSet.contains(arg))
+ produceDiffSet.add(param);
+ if (carryNonTrivialDiffSet.contains(arg))
+ carryNonTrivialDiffSet.add(param);
}
}
}
@@ -370,9 +370,9 @@ public:
for (auto inst : block->getChildren())
{
if (isInstProducingDiff(inst))
- produceDiffSet.Add(inst);
+ produceDiffSet.add(inst);
if (isInstCarryingOverDiff(inst))
- carryNonTrivialDiffSet.Add(inst);
+ carryNonTrivialDiffSet.add(inst);
switch (inst->getOp())
{
case kIROp_Call:
@@ -406,14 +406,14 @@ public:
}
}
}
- } while (produceDiffSet.Count() != lastProduceDiffCount);
+ } while (produceDiffSet.getCount() != lastProduceDiffCount);
// Reverse propagate `expectDiffSet`.
for (int i = 0; i < expectDiffInstWorkList.getCount(); i++)
{
auto inst = expectDiffInstWorkList[i];
// Is inst in produceDiffSet?
- if (!produceDiffSet.Contains(inst))
+ if (!produceDiffSet.contains(inst))
{
if (auto call = as<IRCall>(inst))
{
@@ -526,7 +526,7 @@ public:
{
if (auto storeInst = as<IRStore>(inst))
{
- if (carryNonTrivialDiffSet.Contains(storeInst->getVal()) &&
+ if (carryNonTrivialDiffSet.contains(storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
@@ -569,24 +569,24 @@ public:
if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Backward))
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
- bwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
- differentiableFunctions.Add(inst, DifferentiableLevel::Backward);
+ bwdDifferentiableSymbolNames.add(linkageDecor->getMangledName());
+ differentiableFunctions.add(inst, DifferentiableLevel::Backward);
}
else if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Forward))
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
- fwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
- differentiableFunctions.Add(inst, DifferentiableLevel::Forward);
+ fwdDifferentiableSymbolNames.add(linkageDecor->getMangledName());
+ differentiableFunctions.add(inst, DifferentiableLevel::Forward);
}
}
for (auto inst : module->getGlobalInsts())
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
{
- if (bwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ if (bwdDifferentiableSymbolNames.contains(linkageDecor->getMangledName()))
differentiableFunctions[inst] = DifferentiableLevel::Backward;
- else if (fwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
- differentiableFunctions.AddIfNotExists(inst, DifferentiableLevel::Forward);
+ else if (fwdDifferentiableSymbolNames.contains(linkageDecor->getMangledName()))
+ differentiableFunctions.addIfNotExists(inst, DifferentiableLevel::Forward);
}
}