diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-20 14:22:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-20 11:22:00 -0700 |
| commit | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch) | |
| tree | e85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-lower-to-ir.cpp | |
| parent | 576c8407e60143682cd40c68101c6eae8563ca3d (diff) | |
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions.
* Changed a few asserts to release asserts to avoid unreferenced variable errors
* Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl
* Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types.
* Moved the auto-diff passes to after the specialization step, added a more complex generics test
* Added a generics stress test and fixed AST-side logic. IR side needs some more work
* Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions
* Changed differential getters to use pointer types, added getter type checking
* Fixed some bugs related to diff type registration and differential getters
* Removed some superfluous code
* Removed some more unused code.
* Fixed an issue with witness substitution
* Minor fix
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 111 |
1 files changed, 97 insertions, 14 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b03f3ae62..dc6067868 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1146,10 +1146,6 @@ static void addLinkageDecoration( { builder->addExternCppDecoration(inst, mangledName); } - if (decl->findModifier<JVPDerivativeModifier>()) - { - builder->addJVPDerivativeMarkerDecoration(inst); - } if (as<InterfaceDecl>(decl->parentDecl) && decl->parentDecl->hasModifier<ComInterfaceAttribute>()) { @@ -3042,6 +3038,38 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return info; } + LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) + { + LoweredValInfo info = lowerSubExpr(expr->inner); + + IRInst* irBaseVal = nullptr; + switch (info.flavor) + { + case LoweredValInfo::Flavor::Simple: + irBaseVal = getSimpleVal(context, info); + break; + + case LoweredValInfo::Flavor::Ptr: + irBaseVal = info.val; + break; + + default: + SLANG_UNEXPECTED("Unhandled lowered value cases"); + } + + // If the differentiable expr has an associated getter or setter, lower it + // and put it in a decoration. + // + if (expr->getterExpr != nullptr) + { + auto irGetter = lowerSubExpr(expr->getterExpr); + SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple); + getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val); + } + + return info; + } + // Emit IR to denote the forward-mode derivative // of the inner func-expr. This will be resolved // to a concrete function during the derivative @@ -5844,6 +5872,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo(); } + LoweredValInfo visitDifferentiableTypeDictionary(DifferentiableTypeDictionary* decl) + { + for (auto & member : decl->members) + { + if (auto entry = as<DifferentiableTypeDictionaryItem>(member)) + { + + // Lower type and witness. + IRType* irType = lowerType(context, entry->baseType); + IRInst* irWitness = lowerVal(context, entry->confWitness).val; + + SLANG_ASSERT(irType); + + // If the witness can be lowered, and the differentiable type entry exists, + // add an entry to the context. + // + if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType)) + getBuilder()->addDifferentiableTypeEntry(irType, irWitness); + } + else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member)) + { + ensureDecl(context, importEntry->dictionaryRef.getDecl()); + } + else + { + SLANG_UNEXPECTED("Unrecognized item in DifferentiableTypeDictionary"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + } + + if (auto diffTypeDict = getBuilder()->findOrEmitDifferentiableTypeDictionary()) + { + // Place the dictionary at the end of modules and generic blocks. + diffTypeDict->moveToEnd(); + } + + return LoweredValInfo(); + } + #define IGNORED_CASE(NAME) \ LoweredValInfo visit##NAME(NAME*) { return LoweredValInfo(); } @@ -5853,6 +5920,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IGNORED_CASE(SyntaxDecl) IGNORED_CASE(AttributeDecl) IGNORED_CASE(NamespaceDecl) + IGNORED_CASE(DifferentiableTypeDictionaryItem) #undef IGNORED_CASE @@ -6130,7 +6198,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr); // Register the value now, rather than later, to avoid any possible infinite recursion. - setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable)); + setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); auto irSubType = lowerType(subContext, subType); irWitnessTable->setOperand(0, irSubType); @@ -7219,6 +7287,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } + // We only need dictionaries to be lowered for decls with executable code (i.e. statements) + // Do not lower type dictionaries for inhertiance decls or decls + // that are declaring a type, since this can create a cyclic dependancy. + // + if (as<FunctionDeclBase>(leafDecl)) + { + for (auto diffTypeDict : genericDecl->getMembersOfType<DifferentiableTypeDictionary>()) + { + // We directly use lowerDecl() instead of ensureDecl() to emit to + // the current generic block instead of the top-level module. + // + lowerDecl(subContext, diffTypeDict); + } + } + return irGeneric; } @@ -7372,6 +7455,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam); } + + // Add a differentiable type dictionary if necessary. + if (auto diffTypeDict = subBuilder->findDifferentiableTypeDictionary(parentGeneric->getFirstBlock())) + markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), diffTypeDict); } if (valuesToClone.Count() == 0) { @@ -7723,6 +7810,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(context, irFunc, decl); addLinkageDecoration(context, irFunc, decl); + if (decl->findModifier<JVPDerivativeModifier>()) + { + getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, @@ -8788,15 +8880,6 @@ RefPtr<IRModule> generateIRForTranslationUnit( // temporaries whenever possible. constructSSA(module); - // Process higher-order-function calls before any optimization passes - // to allow the optimizations to affect the generated funcitons. - // 1. Process JVP derivative functions. - processJVPDerivativeMarkers(module, compileRequest->getSink()); - // 2. Process VJP derivative functions. - // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. - // 3. Replace JVP & VJP calls. - processDerivativeCalls(module); - // Do basic constant folding and dead code elimination // using Sparse Conditional Constant Propagation (SCCP) // |
