From 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Oct 2022 14:22:00 -0400 Subject: Modified the new type system to support generic differentiable types … (#2413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 150 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 46d6d445d..2aaeb4ac3 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3547,6 +3547,125 @@ namespace Slang } } + + IRInst* IRBuilder::emitDifferentiableTypeDictionary() + { + auto inst = createInst( + this, + kIROp_DifferentiableTypeDictionary, + nullptr); + + addGlobalValue(this, inst); + return inst; + } + + IRInst* IRBuilder::findOrEmitDifferentiableTypeDictionary() + { + auto currentLoc = this->getInsertLoc(); + auto currentInst = currentLoc.getInst(); + + if (auto diffTypeDictionary = findDifferentiableTypeDictionary(currentInst)) + return diffTypeDictionary; + + return emitDifferentiableTypeDictionary(); + } + + IRInst* IRBuilder::findDifferentiableTypeDictionary(IRInst* parent) + { + //auto parent = inst->getParent(); + while (parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as(parent)) + break; + + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as(parent)) + { + if (as(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + for (auto child = parent->getFirstChild(); child; child = child->getNextInst()) + { + if (child->getOp() == kIROp_DifferentiableTypeDictionary) + return child; + } + + return nullptr; + } + + IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness) + { + auto oldLoc = this->getInsertLoc(); + + IRDifferentiableTypeDictionaryItem* item = nullptr; + + if (auto diffTypeDictionary = findOrEmitDifferentiableTypeDictionary()) + { + this->setInsertInto(diffTypeDictionary); + + IRInst* args[2] = {irType, conformanceWitness}; + item = createInstWithTrailingArgs( + this, + kIROp_DifferentiableTypeDictionaryItem, + nullptr, + 2, + args); + + addInst(item); + } + + this->setInsertLoc(oldLoc); + + return item; + } + + IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope) + { + for (auto child = scope->getFirstChild(); child; child = child->getNextInst()) + { + if (child->getOp() == kIROp_DifferentiableTypeDictionary) + { + for (auto entry = child->getFirstChild(); entry; entry = entry->getNextInst()) + { + IRInst* entryType = entry->getOperand(0); + IRInst* entryConformanceWitness = entry->getOperand(1); + + if (irType == entryType) + { + return entryConformanceWitness; + } + } + } + } + + return nullptr; + } + + IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType) + { + auto instScope = this->getInsertLoc().getInst(); + + while (instScope) + { + if (auto witness = findDifferentiableTypeEntry(irType, instScope)) + { + return witness; + } + instScope = instScope->getParent(); + } + + return nullptr; + } + + IRFunc* IRBuilder::createFunc() { IRFunc* rsFunc = createInst( @@ -6322,6 +6441,37 @@ namespace Slang return inst; } + IRInst* findOuterGeneric(IRInst* inst) + { + if (inst) + { + inst = inst->getParent(); + } + else + { + return nullptr; + } + + while(inst) + { + if (as(inst)) + return inst; + + inst = inst->getParent(); + } + return nullptr; + } + + IRInst* findOuterMostGeneric(IRInst* inst) + { + IRInst* currInst = inst; + while(auto outerGeneric = findOuterGeneric(currInst)) + { + currInst = outerGeneric; + } + return currInst; + } + IRGeneric* findSpecializedGeneric(IRSpecialize* specialize) { return as(specialize->getBase()); -- cgit v1.2.3