summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-07-17 16:04:20 -0700
committerGitHub <noreply@github.com>2025-07-17 23:04:20 +0000
commit094d1ba7cd1eb5f09be05b2e57b5fbd3041cca38 (patch)
treef9768d9608ae27ac56aef641fbf9c1cac651711a /source
parented1a0b8b53c7556fbf0ccab4f3496078eea4c8a2 (diff)
Prelink ForceInlined functions during lowering. (#7812)
* Prelink ForceInlined functions during lowering. * Fixes and cleanups. * Fix warning. * Fix crash.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h2
-rw-r--r--source/slang/slang-ir-link.cpp185
-rw-r--r--source/slang/slang-ir-link.h9
-rw-r--r--source/slang/slang-lower-to-ir.cpp61
4 files changed, 223 insertions, 34 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 7b2cc4007..3fa929d6f 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -616,7 +616,7 @@ FIDDLE(abstract)
class FunctionDeclBase : public CallableDecl
{
FIDDLE(...)
- FIDDLE() Stmt* body = nullptr;
+ Stmt* body = nullptr;
};
// A constructor/initializer to create instances of a type
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index d8bc041fb..b874b9f28 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -40,9 +40,6 @@ struct IRSpecEnv
struct IRSharedSpecContext
{
- // The code-generation target in use
- CodeGenTarget target;
-
// The API-level target request
TargetRequest* targetReq = nullptr;
@@ -1224,6 +1221,9 @@ bool isBetterForTarget(IRSpecContext* context, IRInst* newVal, IRInst* oldVal)
return true;
}
+ if (!context->getShared()->targetReq)
+ return false;
+
// For right now every declaration might have zero or more
// decorations, representing the capabilities for which it is specialized.
// Each decorations has a `CapabilitySet` to represent what it requires of a target.
@@ -1605,7 +1605,6 @@ void initializeSharedSpecContext(
IRSharedSpecContext* sharedContext,
Session* session,
IRModule* inModule,
- CodeGenTarget target,
TargetRequest* targetReq)
{
RefPtr<IRModule> module = inModule;
@@ -1617,7 +1616,6 @@ void initializeSharedSpecContext(
sharedContext->builderStorage = IRBuilder(module);
sharedContext->module = module;
- sharedContext->target = target;
sharedContext->targetReq = targetReq;
}
@@ -2040,7 +2038,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
auto& irModules = stateStorage.contextStorage.irModules;
auto sharedContext = state->getSharedContext();
- initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq);
+ initializeSharedSpecContext(sharedContext, session, nullptr, targetReq);
state->irModule = sharedContext->module;
@@ -2290,6 +2288,181 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
return linkedIR;
}
+
+struct IRPrelinkContext : IRSpecContext
+{
+    // The overriding logic for cloning an external symbol during the prelinking stage.
+    // We only want to clone the body of a function if it is marked as
+    // [unsafeForceInlineEarly]. For anything else, we just clone a declaration without
+    // a body, and mark it as [Import].
+    //
+    // Returns the instruction that stands in for `originalVal` in the module being
+    // prelinked: either a value already registered under the same mangled name, or a
+    // freshly created clone.
+    //
+    virtual IRInst* maybeCloneValue(IRInst* originalVal) override
+    {
+        // If `originalVal` has a linkage, and the current module already contains
+        // a symbol with the same mangled name, then we will skip and return that
+        // preexisting val.
+        if (auto linkage = originalVal->findDecoration<IRLinkageDecoration>())
+        {
+            RefPtr<IRSpecSymbol> symbol;
+            // `symbol` must be passed as the out-parameter of the lookup. Writing
+            // `tryGetValue(name), symbol` instead would test the still-null `symbol`
+            // through the comma operator and never take this branch.
+            if (shared->symbols.tryGetValue(linkage->getMangledName(), symbol))
+            {
+                return symbol->irGlobalValue;
+            }
+        }
+
+        // If this is referencing a global value with linkage but that global value does not
+        // exist in the current module, then we will clone a declaration of it and mark it
+        // [Import].
+        //
+        // `completeClonedInst` post-processes a freshly cloned inst:
+        //  - every [Export] decoration is removed, and an [Import] decoration with the
+        //    same mangled name is added unless the inst already carries one;
+        //  - the inst is registered in `shared->symbols` under that mangled name.
+        // Insts without an [Export] decoration are returned untouched.
+        //
+        auto completeClonedInst = [&](IRInst* inst)
+        {
+            String mangledName;
+            // Removal is deferred so the decoration list is not mutated while iterating it.
+            ShortList<IRInst*> decorsToRemove;
+            bool hasImportDecor = false;
+            for (auto decor : inst->getDecorations())
+            {
+                if (auto exportDecor = as<IRExportDecoration>(decor))
+                {
+                    mangledName = exportDecor->getMangledName();
+                    decorsToRemove.add(exportDecor);
+                }
+                else if (as<IRImportDecoration>(decor))
+                {
+                    hasImportDecor = true;
+                }
+            }
+            if (mangledName.getLength() && !hasImportDecor)
+            {
+                builder->addImportDecoration(inst, mangledName.getUnownedSlice());
+            }
+            for (auto decor : decorsToRemove)
+            {
+                decor->removeFromParent();
+            }
+            if (mangledName.getLength())
+            {
+                // Register the symbol in the shared context, so we don't
+                // clone any symbols with the same mangled name again.
+                RefPtr<IRSpecSymbol> symbol = new IRSpecSymbol();
+                symbol->nextWithSameName = nullptr;
+                symbol->irGlobalValue = inst;
+                shared->symbols[mangledName] = symbol;
+            }
+            return inst;
+        };
+
+        auto builderForClone = builder;
+        if (as<IRModuleInst>(originalVal->getParent()))
+        {
+            // If we are cloning a global value, we will use the module builder.
+            builderForClone = &shared->builderStorage;
+        }
+        IRInst* clonedInst = nullptr;
+        switch (originalVal->getOp())
+        {
+        case kIROp_Generic:
+        case kIROp_GlobalVar:
+        case kIROp_GlobalParam:
+        case kIROp_GlobalConstant:
+        case kIROp_StructKey:
+        case kIROp_InterfaceRequirementEntry:
+        case kIROp_GlobalGenericParam:
+        case kIROp_InterfaceType:
+            // These kinds are cloned through the regular global-value path.
+            return completeClonedInst(
+                cloneGlobalValueImpl(this, originalVal, IROriginalValuesForClone(originalVal)));
+        case kIROp_WitnessTable:
+            {
+                // NOTE(review): this case creates the table with `builder`, while the
+                // func/struct/class cases below use `builderForClone` -- confirm that
+                // the difference is intentional.
+                auto witnessTable = as<IRWitnessTable>(originalVal);
+                clonedInst = builder->createWitnessTable(
+                    cloneType(this, (IRType*)witnessTable->getConformanceType()),
+                    cloneType(this, witnessTable->getConcreteType()));
+                break;
+            }
+        case kIROp_Func:
+            // For functions, we will clone the full body only if it is [unsafeForceInlineEarly].
+            if (originalVal->findDecoration<IRUnsafeForceInlineEarlyDecoration>())
+            {
+                return completeClonedInst(
+                    cloneGlobalValueImpl(this, originalVal, IROriginalValuesForClone(originalVal)));
+            }
+            else
+            {
+                clonedInst = builderForClone->createFunc();
+            }
+            break;
+        case kIROp_StructType:
+            clonedInst = builderForClone->createStructType();
+            break;
+        case kIROp_ClassType:
+            clonedInst = builderForClone->createClassType();
+            break;
+        default:
+            // Everything else falls back to the base-class cloning logic.
+            return completeClonedInst(IRSpecContext::maybeCloneValue(originalVal));
+        }
+
+        // Clone without body: only the mapping from the original, the type and the
+        // decorations are carried over to the empty shell created above.
+        registerClonedValue(this, clonedInst, IROriginalValuesForClone(originalVal));
+        clonedInst->setFullType(cloneType(this, originalVal->getFullType()));
+
+        // Clone decorations
+        cloneDecorations(this, clonedInst, originalVal);
+        completeClonedInst(clonedInst);
+        return clonedInst;
+    }
+};
+
+// Pulls the definitions of `externalSymbolsToLink` into `irModule`.
+//
+// `module` is the front-end module being lowered; its own IR module and the IR
+// modules of its dependencies are made available to the cloning context.
+// `irModule` is the IR module that receives the clones.
+// `externalSymbolsToLink` lists the insts, defined in imported modules, to clone.
+//
+// For each listed inst, the inst registered in `irModule` under the same mangled
+// name is detached, the external inst is cloned in its place, and finally all
+// uses of the detached inst are redirected to the clone.
+//
+void prelinkIR(Module* module, IRModule* irModule, const List<IRInst*>& externalSymbolsToLink)
+{
+    // Setup environment.
+    IRSharedSpecContext sharedContext;
+    sharedContext.builderStorage = IRBuilder(irModule->getModuleInst());
+    sharedContext.module = irModule;
+
+    IRPrelinkContext specContext;
+    specContext.builder = &sharedContext.builderStorage;
+    specContext.env = &sharedContext.globalEnv;
+    specContext.shared = &sharedContext;
+    specContext.irModules.add(module->getIRModule());
+    for (auto importedModule : module->getModuleDependencies())
+    {
+        if (importedModule->getIRModule())
+            specContext.irModules.add(importedModule->getIRModule());
+    }
+
+    // First, register all external symbols in the current module.
+    insertGlobalValueSymbols(&sharedContext, irModule);
+
+    List<KeyValuePair<IRInst*, IRInst*>> pendingReplacements;
+    for (auto originalInst : externalSymbolsToLink)
+    {
+        // originalInst is the function in the imported module to clone.
+        // We should lookup the inst in the current module with the same mangled name,
+        // that's the inst we want to remove and replace with the cloned inst.
+        auto mangledName = getMangledName(originalInst);
+        auto existingSymbol = specContext.findSymbols(mangledName);
+        if (!existingSymbol)
+        {
+            // Nothing in the current module is registered under this name, so there
+            // is no placeholder to replace. Skip instead of dereferencing null.
+            continue;
+        }
+        auto existingInst = existingSymbol->irGlobalValue;
+
+        // Unregister the placeholder so that cloning does not resolve the name back to it.
+        specContext.shared->symbols.remove(mangledName);
+        specContext.builder->setInsertBefore(existingInst);
+
+        // Remove existing inst from the module before cloning so our duplication-check
+        // (`checkIRDuplicate`) doesn't complain.
+        existingInst->removeFromParent();
+
+        auto cloned = cloneValue(&specContext, originalInst);
+        pendingReplacements.add(KeyValuePair<IRInst*, IRInst*>(existingInst, cloned));
+    }
+
+    // Now we can replace all the inlined extern symbols with the cloned values.
+    for (const auto& kv : pendingReplacements)
+    {
+        kv.key->replaceUsesWith(kv.value);
+        kv.key->removeAndDeallocate();
+    }
+}
+
struct ReplaceGlobalConstantsPass
{
void process(IRModule* module)
diff --git a/source/slang/slang-ir-link.h b/source/slang/slang-ir-link.h
index ee0312ca3..c8b96b961 100644
--- a/source/slang/slang-ir-link.h
+++ b/source/slang/slang-ir-link.h
@@ -25,6 +25,15 @@ struct LinkedIR
//
LinkedIR linkIR(CodeGenContext* codeGenContext);
+// Prelinking is a step that happens immediately after visiting all AST nodes during IR lowering,
+// and before any IR validation steps. Prelinking copies all extern symbols that are
+// [unsafeForceInlineEarly] into the current module being lowered, so that we can perform necessary
+// inlining passes before any data-flow analysis.
+// `externalSymbolsToLink` is a list of IRInsts defined in the imported modules that we need to pull
+// into the current module.
+//
+void prelinkIR(Module* module, IRModule* irModule, const List<IRInst*>& externalSymbolsToLink);
+
// Replace any global constants in the IR module with their
// definitions, if possible.
//
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 1f5909e94..ab9f85b21 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -18,6 +18,7 @@
#include "slang-ir-inline.h"
#include "slang-ir-insert-debug-value-store.h"
#include "slang-ir-insts.h"
+#include "slang-ir-link.h"
#include "slang-ir-loop-inversion.h"
#include "slang-ir-lower-defer.h"
#include "slang-ir-lower-error-handling.h"
@@ -494,6 +495,10 @@ struct SharedIRGenContext
Dictionary<IntVal*, IRInst*> mapSpecConstValToIRInst;
+ // External (imported) [unsafeForceInlineEarly] functions that need to
+ // be prelinked into the current module after lowering.
+ List<IRInst*> externalSymbolsToPrelink;
+
void setGlobalValue(Decl* decl, LoweredValInfo value)
{
globalEnv.mapDeclToValue[decl] = value;
@@ -10636,6 +10641,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
FunctionDeclBase* decl,
bool emitBody = true)
{
+ bool isFromDifferentModule = isDeclInDifferentModule(context, decl);
+ if (isFromDifferentModule && isForceInlineEarly(decl))
+ {
+ // If a function is imported from another module then
+ // we usually don't want to emit it as a definition, and
+ // will instead only emit a declaration for it with an
+ // appropriate `[import(...)]` linkage decoration.
+ //
+ // However, if the function is marked with `[__unsafeForceInlineEarly]`
+ // then we need to make sure the IR for its definition is available
+ // to the mandatory optimization passes.
+ //
+ // We do so by finding the IR function from the imported module, and clone
+ // the body of the IRFunc from the imported module to the current module.
+ //
+ auto importedModule = getModule(decl);
+ auto irModule = importedModule->getIRModule();
+ SLANG_ASSERT(irModule && "Module containing imported decl does not have an IRModule.");
+ String mangledName = getMangledName(context->astBuilder, decl);
+ auto importedFunc = irModule->findSymbolByMangledName(mangledName);
+ SLANG_ASSERT(importedFunc.getCount() > 0);
+ subContext->shared->externalSymbolsToPrelink.add(importedFunc[0]);
+ }
+
IRGeneric* outerGeneric = nullptr;
subContext->funcDecl = decl;
@@ -10726,32 +10755,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subBuilder->setInsertInto(irFunc);
- // If a function is imported from another module then
- // we usually don't want to emit it as a definition, and
- // will instead only emit a declaration for it with an
- // appropriate `[import(...)]` linkage decoration.
- //
- // However, if the function is marked with `[__unsafeForceInlineEarly]`
- // then we need to make sure the IR for its definition is available
- // to the mandatory optimization passes.
- //
- // TODO: The design here means that we will re-emit the inline
- // function from its AST in every module that uses it. We should
- // instead have logic to clone the target function in from the
- // pre-generated IR for the module that defines it (or do some kind
- // of minimal linking to bring in the inline functions).
- //
- if (!decl->body)
- {
- // This is a function declaration without a body.
- // In Slang we currently try not to support forward declarations
- // (although we might have to give in eventually), so
- // this case should really only occur for builtin declarations.
- }
- else if (isDeclInDifferentModule(context, decl) && !isForceInlineEarly(decl))
- {
- }
- else if (emitBody)
+ if (emitBody && decl->body && !isFromDifferentModule)
{
// This is a function definition, so we need to actually
// construct IR for the body...
@@ -12103,7 +12107,6 @@ RefPtr<IRModule> generateIRForTranslationUnit(
validateIRModuleIfEnabled(compileRequest, module);
-
// We will perform certain "mandatory" optimization passes now.
// These passes serve two purposes:
//
@@ -12122,6 +12125,10 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// dumpIR(module);
+ // Before we can do any validation, we need to prelink [unsafeForceInlineEarly]
+ // functions.
+ prelinkIR(translationUnit->module, module, context->shared->externalSymbolsToPrelink);
+
// First, lower error handling logic into normal control flow.
// This includes lowering throwing functions into functions that
// returns a `Result<T,E>` value, translating `tryCall` into