summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2023-04-25 10:43:29 -0400
committerGitHub <noreply@github.com>2023-04-25 10:43:29 -0400
commit7b7c095b37e85ca3a8f55eff1c3d9643d467b8e0 (patch)
tree9c71955dbc956b0058b19818ca127c8132cda512 /source/slang/slang-ir-autodiff.cpp
parent284cee1f246c072f190c87c8fb60c1d2181e458f (diff)
Dictionary using lowerCamel (#2835)
* #include an absolute path didn't work - because paths were taken to always be relative. * WIP lowerCamel Dictionary. * WIP more lowerCamel fixes for Dictionary. * Add/Remove/Clear * GetValue/Contains * Fix tabs in dictionary. Count -> getCount * Fix fields with caps. * Key -> key Value -> value Use m_ for members where appropriate. Use lowerCamel in linked list. * Some small fixes/improvements to Dictionary. * Kick CI.
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp62
1 files changed, 31 insertions, 31 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 656b0e11b..4dac6b347 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -298,7 +298,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
// purposes.
auto primalType = pairType->getValueType();
- if (pairTypeCache.TryGetValue(primalType, result))
+ if (pairTypeCache.tryGetValue(primalType, result))
return result;
if (!pairType)
{
@@ -315,7 +315,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
if (!diffType)
return result;
result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- pairTypeCache.Add(primalType, result);
+ pairTypeCache.add(primalType, result);
return result;
}
@@ -391,20 +391,20 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
{
- auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType());
if (existingItem)
{
*existingItem = item->getWitness();
}
else
{
- differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
// Also register the type's differential type with the same witness.
IRBuilder subBuilder(item->getConcreteType());
if (!as<IRInterfaceType>(item->getConcreteType()))
{
- differentiableWitnessDictionary.AddIfNotExists(
+ differentiableWitnessDictionary.addIfNotExists(
(IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey),
item->getWitness());
}
@@ -418,7 +418,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey);
if (diffType && diffWitness)
{
- differentiableWitnessDictionary.AddIfNotExists((IRType*)diffType, diffWitness);
+ differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness);
}
}
}
@@ -429,7 +429,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
{
IRInst* foundResult = nullptr;
- differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ differentiableWitnessDictionary.tryGetValue(type, foundResult);
return foundResult;
}
@@ -464,7 +464,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
- differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness());
+ differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
}
}
}
@@ -873,9 +873,9 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst*
// Look for equivalent types.
for (auto type : context.differentiableWitnessDictionary)
{
- if (isTypeEqual(type.Key, (IRType*)typeInst))
+ if (isTypeEqual(type.key, (IRType*)typeInst))
{
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
+ context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value;
return true;
}
}
@@ -1010,7 +1010,7 @@ struct AutoDiffPass : public InstPassBase
auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc);
if (type)
{
- loweredIntermediateTypes.Add(type);
+ loweredIntermediateTypes.add(type);
inst->replaceUsesWith(type);
inst->removeAndDeallocate();
changed = true;
@@ -1034,7 +1034,7 @@ struct AutoDiffPass : public InstPassBase
// Utility function for topology sorting the intermediate context types.
bool isIntermediateContextTypeReadyForProcess(OrderedHashSet<IRInst*>& contextTypes, OrderedHashSet<IRInst*>& sortedSet, IRInst* t)
{
- if (!contextTypes.Contains(t))
+ if (!contextTypes.contains(t))
return true;
switch (t->getOp())
@@ -1082,7 +1082,7 @@ struct AutoDiffPass : public InstPassBase
{
if (auto e = as<IRDifferentiableTypeDictionaryItem>(entry))
{
- registeredType.Add(e->getOperand(0));
+ registeredType.add(e->getOperand(0));
}
}
// Use a work list to recursively walk through all sub fields of the struct type.
@@ -1092,9 +1092,9 @@ struct AutoDiffPass : public InstPassBase
{
auto t = wlist[i];
IntermediateContextTypeDifferentialInfo diffInfo;
- if (!diffTypes.TryGetValue(t, diffInfo))
+ if (!diffTypes.tryGetValue(t, diffInfo))
continue;
- if (registeredType.Add(t))
+ if (registeredType.add(t))
builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness);
else
continue;
@@ -1115,16 +1115,16 @@ struct AutoDiffPass : public InstPassBase
OrderedHashSet<IRInst*> sortedContextTypes;
for (;;)
{
- auto lastCount = sortedContextTypes.Count();
+ auto lastCount = sortedContextTypes.getCount();
for (auto t : contextTypes)
{
- if (sortedContextTypes.Contains(t))
+ if (sortedContextTypes.contains(t))
continue;
// Have all dependent types been added yet?
if (isIntermediateContextTypeReadyForProcess(contextTypes, sortedContextTypes, t))
- sortedContextTypes.Add(t);
+ sortedContextTypes.add(t);
}
- if (lastCount == sortedContextTypes.Count())
+ if (lastCount == sortedContextTypes.getCount())
break;
}
@@ -1149,7 +1149,7 @@ struct AutoDiffPass : public InstPassBase
// A specialize of a context type translates to a specialize of its differential type/witness.
IntermediateContextTypeDifferentialInfo baseInfo;
- SLANG_RELEASE_ASSERT(diffTypes.TryGetValue(specialize->getBase(), baseInfo));
+ SLANG_RELEASE_ASSERT(diffTypes.tryGetValue(specialize->getBase(), baseInfo));
builder.setInsertBefore(t);
List<IRInst*> args;
for (UInt i = 0; i < specialize->getArgCount(); i++)
@@ -1170,7 +1170,7 @@ struct AutoDiffPass : public InstPassBase
// We currently don't support the `LookupInterfaceMethod` case, since it can't
// appear in a derivative function because we will only call the backward diff function without a intermediate-type
// via an interface.
- SLANG_RELEASE_ASSERT(diffTypes.ContainsKey(t));
+ SLANG_RELEASE_ASSERT(diffTypes.containsKey(t));
}
}
@@ -1178,16 +1178,16 @@ struct AutoDiffPass : public InstPassBase
for (auto t : diffTypes)
{
HashSet<IRFunc*> registeredFuncs;
- for (auto use = t.Key->firstUse; use; use = use->nextUse)
+ for (auto use = t.key->firstUse; use; use = use->nextUse)
{
auto parentFunc = getParentFunc(use->getUser());
if (!parentFunc)
continue;
- if (!registeredFuncs.Add(parentFunc))
+ if (!registeredFuncs.add(parentFunc))
continue;
if (auto dictDecor = parentFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
{
- registerDiffContextType(builder, dictDecor, diffTypes, t.Key);
+ registerDiffContextType(builder, dictDecor, diffTypes, t.key);
}
}
}
@@ -1222,7 +1222,7 @@ struct AutoDiffPass : public InstPassBase
else
{
IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
- diffTypes.TryGetValue(field->getFieldType(), diffFieldTypeInfo);
+ diffTypes.tryGetValue(field->getFieldType(), diffFieldTypeInfo);
diffFieldWitness = diffFieldTypeInfo.diffWitness;
}
if (diffFieldWitness)
@@ -1370,7 +1370,7 @@ struct AutoDiffPass : public InstPassBase
// any unmaterialized intermediate context types.
bool isTypeFullyDifferentiated(IRInst* type)
{
- if (fullyDifferentiatedInsts.Contains(type))
+ if (fullyDifferentiatedInsts.contains(type))
return true;
if (type->getOp() == kIROp_BackwardDiffIntermediateContextType)
return false;
@@ -1384,7 +1384,7 @@ struct AutoDiffPass : public InstPassBase
{
bool result = isTypeFullyDifferentiated(findGenericReturnVal(genType));
if (result)
- fullyDifferentiatedInsts.Add(genType);
+ fullyDifferentiatedInsts.add(genType);
return result;
}
switch (type->getOp())
@@ -1401,7 +1401,7 @@ struct AutoDiffPass : public InstPassBase
if (!isTypeFullyDifferentiated(type->getOperand(i)))
return false;
default:
- fullyDifferentiatedInsts.Add(type);
+ fullyDifferentiatedInsts.add(type);
return true;
}
}
@@ -1410,7 +1410,7 @@ struct AutoDiffPass : public InstPassBase
// any differentiate insts.
bool isFullyDifferentiated(IRFunc* func)
{
- if (fullyDifferentiatedInsts.Contains(func))
+ if (fullyDifferentiatedInsts.contains(func))
return true;
for (auto block : func->getBlocks())
@@ -1430,7 +1430,7 @@ struct AutoDiffPass : public InstPassBase
return false;
}
}
- fullyDifferentiatedInsts.Add(func);
+ fullyDifferentiatedInsts.add(func);
return true;
}
@@ -1439,7 +1439,7 @@ struct AutoDiffPass : public InstPassBase
//
bool processReferencedFunctions(IRBuilder* builder)
{
- fullyDifferentiatedInsts.Clear();
+ fullyDifferentiatedInsts.clear();
bool hasChanges = false;
for (;;)
{