summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-lower-to-ir.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp111
1 files changed, 97 insertions, 14 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b03f3ae62..dc6067868 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1146,10 +1146,6 @@ static void addLinkageDecoration(
{
builder->addExternCppDecoration(inst, mangledName);
}
- if (decl->findModifier<JVPDerivativeModifier>())
- {
- builder->addJVPDerivativeMarkerDecoration(inst);
- }
if (as<InterfaceDecl>(decl->parentDecl) &&
decl->parentDecl->hasModifier<ComInterfaceAttribute>())
{
@@ -3042,6 +3038,38 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return info;
}
+ LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
+ {
+ LoweredValInfo info = lowerSubExpr(expr->inner);
+
+ IRInst* irBaseVal = nullptr;
+ switch (info.flavor)
+ {
+ case LoweredValInfo::Flavor::Simple:
+ irBaseVal = getSimpleVal(context, info);
+ break;
+
+ case LoweredValInfo::Flavor::Ptr:
+ irBaseVal = info.val;
+ break;
+
+ default:
+ SLANG_UNEXPECTED("Unhandled lowered value cases");
+ }
+
+ // If the differentiable expr has an associated getter or setter, lower it
+ // and put it in a decoration.
+ //
+ if (expr->getterExpr != nullptr)
+ {
+ auto irGetter = lowerSubExpr(expr->getterExpr);
+ SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple);
+ getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val);
+ }
+
+ return info;
+ }
+
// Emit IR to denote the forward-mode derivative
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
@@ -5844,6 +5872,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
+ LoweredValInfo visitDifferentiableTypeDictionary(DifferentiableTypeDictionary* decl)
+ {
+ for (auto & member : decl->members)
+ {
+ if (auto entry = as<DifferentiableTypeDictionaryItem>(member))
+ {
+
+ // Lower type and witness.
+ IRType* irType = lowerType(context, entry->baseType);
+ IRInst* irWitness = lowerVal(context, entry->confWitness).val;
+
+ SLANG_ASSERT(irType);
+
+ // If the witness can be lowered, and the differentiable type entry exists,
+ // add an entry to the context.
+ //
+ if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType))
+ getBuilder()->addDifferentiableTypeEntry(irType, irWitness);
+ }
+ else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member))
+ {
+ ensureDecl(context, importEntry->dictionaryRef.getDecl());
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unrecognized item in DifferentiableTypeDictionary");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+ }
+
+ if (auto diffTypeDict = getBuilder()->findOrEmitDifferentiableTypeDictionary())
+ {
+ // Place the dictionary at the end of modules and generic blocks.
+ diffTypeDict->moveToEnd();
+ }
+
+ return LoweredValInfo();
+ }
+
#define IGNORED_CASE(NAME) \
LoweredValInfo visit##NAME(NAME*) { return LoweredValInfo(); }
@@ -5853,6 +5920,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IGNORED_CASE(SyntaxDecl)
IGNORED_CASE(AttributeDecl)
IGNORED_CASE(NamespaceDecl)
+ IGNORED_CASE(DifferentiableTypeDictionaryItem)
#undef IGNORED_CASE
@@ -6130,7 +6198,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr);
// Register the value now, rather than later, to avoid any possible infinite recursion.
- setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable));
+ setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
auto irSubType = lowerType(subContext, subType);
irWitnessTable->setOperand(0, irSubType);
@@ -7219,6 +7287,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ // We only need dictionaries to be lowered for decls with executable code (i.e. statements)
+ // Do not lower type dictionaries for inhertiance decls or decls
+ // that are declaring a type, since this can create a cyclic dependancy.
+ //
+ if (as<FunctionDeclBase>(leafDecl))
+ {
+ for (auto diffTypeDict : genericDecl->getMembersOfType<DifferentiableTypeDictionary>())
+ {
+ // We directly use lowerDecl() instead of ensureDecl() to emit to
+ // the current generic block instead of the top-level module.
+ //
+ lowerDecl(subContext, diffTypeDict);
+ }
+ }
+
return irGeneric;
}
@@ -7372,6 +7455,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam);
}
+
+ // Add a differentiable type dictionary if necessary.
+ if (auto diffTypeDict = subBuilder->findDifferentiableTypeDictionary(parentGeneric->getFirstBlock()))
+ markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), diffTypeDict);
}
if (valuesToClone.Count() == 0)
{
@@ -7723,6 +7810,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(context, irFunc, decl);
addLinkageDecoration(context, irFunc, decl);
+ if (decl->findModifier<JVPDerivativeModifier>())
+ {
+ getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
+ }
+
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,
@@ -8788,15 +8880,6 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// temporaries whenever possible.
constructSSA(module);
- // Process higher-order-function calls before any optimization passes
- // to allow the optimizations to affect the generated funcitons.
- // 1. Process JVP derivative functions.
- processJVPDerivativeMarkers(module, compileRequest->getSink());
- // 2. Process VJP derivative functions.
- // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
- // 3. Replace JVP & VJP calls.
- processDerivativeCalls(module);
-
// Do basic constant folding and dead code elimination
// using Sparse Conditional Constant Propagation (SCCP)
//