summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-link.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-09 19:19:17 -0800
committerGitHub <noreply@github.com>2022-11-09 19:19:17 -0800
commit004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch)
treecbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-ir-link.cpp
parentcedd93690c63188cf98e452c9d104cf51aad6c4e (diff)
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-link.cpp')
-rw-r--r--source/slang/slang-ir-link.cpp114
1 files changed, 83 insertions, 31 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index ad4f691f1..cf0293f0d 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -6,7 +6,7 @@
#include "slang-ir-insts.h"
#include "slang-mangle.h"
#include "slang-ir-string-hash.h"
-
+#include "slang-ir-diff-jvp.h"
#include "slang-module-library.h"
#include "../compiler-core/slang-artifact.h"
@@ -412,7 +412,43 @@ IRGlobalVar* cloneGlobalVarImpl(
/// For a given decoration opcode, only one such decoration will ever be copied, and nothing
/// will be copied if the instruction already has a matching decoration (that was cloned
/// from the "best" definition).
- ///
+ ///
+static void cloneExtraDecorationsFromInst(
+ IRSpecContextBase* context,
+ IRBuilder* builder,
+ IRInst* clonedInst,
+ IRInst* originalInst)
+{
+ for (auto decoration : originalInst->getDecorations())
+ {
+ switch (decoration->getOp())
+ {
+ default:
+ break;
+
+ case kIROp_HLSLExportDecoration:
+ case kIROp_BindExistentialSlotsDecoration:
+ case kIROp_LayoutDecoration:
+ case kIROp_PublicDecoration:
+ case kIROp_SequentialIDDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ if (!clonedInst->findDecorationImpl(decoration->getOp()))
+ {
+ cloneInst(context, builder, decoration);
+ }
+ break;
+ }
+ }
+
+ // We will also copy over source location information from the alternative
+ // values, in case any of them has it available.
+ //
+ if (originalInst->sourceLoc.isValid() && !clonedInst->sourceLoc.isValid())
+ {
+ clonedInst->sourceLoc = originalInst->sourceLoc;
+ }
+}
+
static void cloneExtraDecorations(
IRSpecContextBase* context,
IRInst* clonedInst,
@@ -435,34 +471,7 @@ static void cloneExtraDecorations(
for(auto sym = originalValues.sym; sym; sym = sym->nextWithSameName)
{
- for(auto decoration : sym->irGlobalValue->getDecorations())
- {
- switch(decoration->getOp())
- {
- default:
- break;
-
- case kIROp_HLSLExportDecoration:
- case kIROp_BindExistentialSlotsDecoration:
- case kIROp_LayoutDecoration:
- case kIROp_PublicDecoration:
- case kIROp_SequentialIDDecoration:
- case kIROp_ForwardDerivativeDecoration:
- if(!clonedInst->findDecorationImpl(decoration->getOp()))
- {
- cloneInst(context, builder, decoration);
- }
- break;
- }
- }
-
- // We will also copy over source location information from the alternative
- // values, in case any of them has it available.
- //
- if(sym->irGlobalValue->sourceLoc.isValid() && !clonedInst->sourceLoc.isValid())
- {
- clonedInst->sourceLoc = sym->irGlobalValue->sourceLoc;
- }
+ cloneExtraDecorationsFromInst(context, builder, clonedInst, sym->irGlobalValue);
}
}
@@ -547,6 +556,43 @@ IRGeneric* cloneGenericImpl(
originalVal,
originalValues);
+ // We want to clone extra decorations on the
+ // return value from other symbols as well.
+ auto clonedInnerVal = findGenericReturnVal(clonedVal);
+ for (auto originalSym = originalValues.sym; originalSym;
+ originalSym = originalSym->nextWithSameName.get())
+ {
+ auto originalGeneric = as<IRGeneric>(originalSym->irGlobalValue);
+ if (!originalGeneric)
+ continue;
+ auto originalInnerVal = findGenericReturnVal(originalGeneric);
+
+ // Register all generic parameters before cloning the decorations.
+ auto clonedParam = clonedVal->getFirstParam();
+ auto originalParam = originalGeneric->getFirstParam();
+
+ ShortList<KeyValuePair<IRInst*, IRInst*>> paramMapping;
+ for (; clonedParam && originalParam; (clonedParam = as<IRParam>(clonedParam->next)), (originalParam = as<IRParam>(originalParam->next)))
+ {
+ paramMapping.add(KeyValuePair<IRInst*, IRInst*>(clonedParam, originalParam));
+ }
+ // Generic parameter list does not match, bail.
+ if (clonedParam || originalParam)
+ continue;
+ for (auto kv : paramMapping)
+ {
+ registerClonedValue(context, kv.Key, kv.Value);
+ }
+
+ IRBuilder builderStorage = *builder;
+ IRBuilder* decorBuilder = &builderStorage;
+ decorBuilder->setInsertInto(clonedInnerVal);
+ if (auto firstChild = clonedInnerVal->getFirstChild())
+ {
+ decorBuilder->setInsertBefore(firstChild);
+ }
+ cloneExtraDecorationsFromInst(context, decorBuilder, clonedInnerVal, originalInnerVal);
+ }
return clonedVal;
}
@@ -694,7 +740,6 @@ void cloneGlobalValueWithCodeCommon(
cb = cb->getNextBlock();
}
}
-
}
void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const& mangledName)
@@ -1405,6 +1450,13 @@ LinkedIR linkIR(
//
List<IRModule*> irModules;
+
+ // Link stdlib modules.
+ auto builtinLinkage = static_cast<Session*>(linkage->getGlobalSession())->getBuiltinLinkage();
+ for (auto& m : builtinLinkage->mapNameToLoadedModules)
+ irModules.add(m.Value->getIRModule());
+
+ // Link modules in the program.
program->enumerateIRModules([&](IRModule* irModule)
{
irModules.add(irModule);