summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-ir.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp150
1 files changed, 150 insertions, 0 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 46d6d445d..2aaeb4ac3 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3547,6 +3547,125 @@ namespace Slang
}
}
+
+ IRInst* IRBuilder::emitDifferentiableTypeDictionary()
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_DifferentiableTypeDictionary,
+ nullptr);
+
+ addGlobalValue(this, inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::findOrEmitDifferentiableTypeDictionary()
+ {
+ auto currentLoc = this->getInsertLoc();
+ auto currentInst = currentLoc.getInst();
+
+ if (auto diffTypeDictionary = findDifferentiableTypeDictionary(currentInst))
+ return diffTypeDictionary;
+
+ return emitDifferentiableTypeDictionary();
+ }
+
+ IRInst* IRBuilder::findDifferentiableTypeDictionary(IRInst* parent)
+ {
+ //auto parent = inst->getParent();
+ while (parent)
+ {
+ // Inserting into the top level of a module?
+ // That is fine, and we can stop searching.
+ if (as<IRModuleInst>(parent))
+ break;
+
+ // Inserting into a basic block inside of
+ // a generic? That is okay too.
+ if (auto block = as<IRBlock>(parent))
+ {
+ if (as<IRGeneric>(block->parent))
+ break;
+ }
+
+ // Otherwise, move up the chain.
+ parent = parent->parent;
+ }
+
+ for (auto child = parent->getFirstChild(); child; child = child->getNextInst())
+ {
+ if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ return child;
+ }
+
+ return nullptr;
+ }
+
+ IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness)
+ {
+ auto oldLoc = this->getInsertLoc();
+
+ IRDifferentiableTypeDictionaryItem* item = nullptr;
+
+ if (auto diffTypeDictionary = findOrEmitDifferentiableTypeDictionary())
+ {
+ this->setInsertInto(diffTypeDictionary);
+
+ IRInst* args[2] = {irType, conformanceWitness};
+ item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>(
+ this,
+ kIROp_DifferentiableTypeDictionaryItem,
+ nullptr,
+ 2,
+ args);
+
+ addInst(item);
+ }
+
+ this->setInsertLoc(oldLoc);
+
+ return item;
+ }
+
+ IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope)
+ {
+ for (auto child = scope->getFirstChild(); child; child = child->getNextInst())
+ {
+ if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ {
+ for (auto entry = child->getFirstChild(); entry; entry = entry->getNextInst())
+ {
+ IRInst* entryType = entry->getOperand(0);
+ IRInst* entryConformanceWitness = entry->getOperand(1);
+
+ if (irType == entryType)
+ {
+ return entryConformanceWitness;
+ }
+ }
+ }
+ }
+
+ return nullptr;
+ }
+
+ IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType)
+ {
+ auto instScope = this->getInsertLoc().getInst();
+
+ while (instScope)
+ {
+ if (auto witness = findDifferentiableTypeEntry(irType, instScope))
+ {
+ return witness;
+ }
+ instScope = instScope->getParent();
+ }
+
+ return nullptr;
+ }
+
+
IRFunc* IRBuilder::createFunc()
{
IRFunc* rsFunc = createInst<IRFunc>(
@@ -6322,6 +6441,37 @@ namespace Slang
return inst;
}
+ IRInst* findOuterGeneric(IRInst* inst)
+ {
+ if (inst)
+ {
+ inst = inst->getParent();
+ }
+ else
+ {
+ return nullptr;
+ }
+
+ while(inst)
+ {
+ if (as<IRGeneric>(inst))
+ return inst;
+
+ inst = inst->getParent();
+ }
+ return nullptr;
+ }
+
+ IRInst* findOuterMostGeneric(IRInst* inst)
+ {
+ IRInst* currInst = inst;
+ while(auto outerGeneric = findOuterGeneric(currInst))
+ {
+ currInst = outerGeneric;
+ }
+ return currInst;
+ }
+
IRGeneric* findSpecializedGeneric(IRSpecialize* specialize)
{
return as<IRGeneric>(specialize->getBase());