summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-link.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-link.cpp')
-rw-r--r--source/slang/slang-ir-link.cpp390
1 files changed, 340 insertions, 50 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 53bbbda9e..125ab71d0 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -52,15 +52,28 @@ struct IRSharedSpecContext
// A map from mangled symbol names to zero or
// more global IR values that have that name,
// in the *original* module.
- typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary;
+ typedef Dictionary<ImmutableHashedString, RefPtr<IRSpecSymbol>> SymbolDictionary;
SymbolDictionary symbols;
+ Dictionary<ImmutableHashedString, bool> isImportedSymbol;
+
+ bool useAutodiff = false;
+
IRBuilder builderStorage;
// The "global" specialization environment.
IRSpecEnv globalEnv;
};
+void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv);
+
+struct WitnessTableCloneInfo : RefObject
+{
+ IRWitnessTable* clonedTable;
+ IRWitnessTable* originalTable;
+ Dictionary<UnownedStringSlice, IRWitnessTableEntry*> deferredEntries;
+};
+
struct IRSpecContextBase
{
IRSharedSpecContext* shared;
@@ -69,7 +82,27 @@ struct IRSpecContextBase
IRModule* getModule() { return getShared()->module; }
- IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; }
+ List<IRModule*> irModules;
+
+ HashSet<UnownedStringSlice> deferredWitnessTableEntryKeys;
+ List<RefPtr<WitnessTableCloneInfo>> witnessTables;
+
+ IRSpecSymbol* findSymbols(UnownedStringSlice mangledName)
+ {
+ ImmutableHashedString hashedName(mangledName);
+ RefPtr<IRSpecSymbol> symbol;
+ if (shared->symbols.tryGetValue(hashedName, symbol))
+ return symbol;
+ for (auto m : irModules)
+ {
+ for (auto inst : m->findSymbolByMangledName(hashedName))
+ insertGlobalValueSymbol(shared, inst);
+ }
+ if (shared->symbols.tryGetValue(hashedName, symbol))
+ return symbol;
+ shared->symbols[hashedName] = nullptr;
+ return nullptr;
+ }
// The current specialization environment to use.
IRSpecEnv* env = nullptr;
@@ -105,6 +138,26 @@ void registerClonedValue(IRSpecContextBase* context, IRInst* clonedValue, IRInst
// an `Add()` call.
//
context->getEnv()->clonedValues[originalValue] = clonedValue;
+
+ switch (clonedValue->getOp())
+ {
+ case kIROp_LookupWitness:
+
+ // If `originalVal` represents a witness table entry key, add the key
+ // to witnessTableEntryWorkList.
+ context->deferredWitnessTableEntryKeys.add(
+ getMangledName(as<IRLookupWitnessMethod>(clonedValue)->getRequirementKey()));
+ break;
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ if (context->getShared()->useAutodiff)
+ {
+ if (auto key = as<IRStructKey>(clonedValue->getOperand(0)))
+ context->deferredWitnessTableEntryKeys.add(getMangledName(key));
+ }
+ break;
+ }
}
// Information on values to use when registering a cloned value
@@ -149,6 +202,28 @@ IRInst* cloneInst(IRSpecContextBase* context, IRBuilder* builder, IRInst* origin
return cloneInst(context, builder, originalInst, originalInst);
}
+bool isAutoDiffDecoration(IRInst* decor)
+{
+ switch (decor->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePrimalDecoration:
+ case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_BackwardDerivativePrimalContextDecoration:
+ case kIROp_BackwardDerivativePrimalReturnDecoration:
+ case kIROp_PrimalSubstituteDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_DifferentiableTypeDictionaryDecoration:
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ default:
+ return false;
+ }
+}
+
/// Clone any decorations from `originalValue` onto `clonedValue`
void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* originalValue)
{
@@ -165,6 +240,8 @@ void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* o
SLANG_UNUSED(context);
for (auto originalDecoration : originalValue->getDecorations())
{
+ if (!context->shared->useAutodiff && isAutoDiffDecoration(originalDecoration))
+ continue;
cloneInst(context, builder, originalDecoration);
}
@@ -185,6 +262,8 @@ void cloneDecorationsAndChildren(
SLANG_UNUSED(context);
for (auto originalItem : originalValue->getDecorationsAndChildren())
{
+ if (!context->shared->useAutodiff && isAutoDiffDecoration(originalItem))
+ continue;
cloneInst(context, builder, originalItem);
}
@@ -294,7 +373,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
cloneDecorationsAndChildren(this, clonedValue, originalValue);
addHoistableInst(builder, clonedValue);
-
return clonedValue;
}
break;
@@ -407,15 +485,17 @@ static void cloneExtraDecorationsFromInst(
{
default:
break;
-
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_PrimalSubstituteDecoration:
+ if (!context->getShared()->useAutodiff)
+ break;
+ [[fallthrough]];
case kIROp_HLSLExportDecoration:
case kIROp_BindExistentialSlotsDecoration:
case kIROp_LayoutDecoration:
case kIROp_PublicDecoration:
case kIROp_SequentialIDDecoration:
- case kIROp_ForwardDerivativeDecoration:
- case kIROp_UserDefinedBackwardDerivativeDecoration:
- case kIROp_PrimalSubstituteDecoration:
case kIROp_IntrinsicOpDecoration:
case kIROp_NonCopyableTypeDecoration:
case kIROp_DynamicDispatchWitnessDecoration:
@@ -602,6 +682,40 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl(
return clonedVal;
}
+bool shouldDeepCloneWitnessTable(IRSpecContextBase* context, IRWitnessTable* table)
+{
+ for (auto decor : table->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_HLSLExportDecoration:
+ case kIROp_KeepAliveDecoration:
+ return true;
+ }
+ }
+
+ auto conformanceType = getResolvedInstForDecorations(table->getConformanceType());
+
+ for (auto decor : conformanceType->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_ComInterfaceDecoration:
+ return true;
+ case kIROp_KnownBuiltinDecoration:
+ {
+ auto name = as<IRKnownBuiltinDecoration>(decor)->getName();
+ if (name == toSlice("IDifferentiable") || name == toSlice("IDifferentiablePtr"))
+ return context->getShared()->useAutodiff;
+ break;
+ }
+ default:
+ break;
+ }
+ }
+
+ return false;
+}
IRWitnessTable* cloneWitnessTableImpl(
IRSpecContextBase* context,
@@ -612,13 +726,56 @@ IRWitnessTable* cloneWitnessTableImpl(
bool registerValue = true)
{
IRWitnessTable* clonedTable = dstTable;
+ IRType* clonedBaseType = nullptr;
if (!clonedTable)
{
- auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
+ clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
}
- cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
+ else
+ {
+ clonedBaseType = (IRType*)clonedTable->getConformanceType();
+ }
+ if (registerValue)
+ registerClonedValue(context, clonedTable, originalValues);
+
+ // Set up an IR builder for inserting into the witness table
+ IRBuilder builderStorage = *context->builder;
+ IRBuilder* entryBuilder = &builderStorage;
+ entryBuilder->setInsertInto(clonedTable);
+
+ // Clone decorations first
+ for (auto decoration : originalTable->getDecorations())
+ {
+ cloneInst(context, entryBuilder, decoration);
+ }
+ cloneExtraDecorations(context, clonedTable, originalValues);
+
+ RefPtr<WitnessTableCloneInfo> witnessInfo = new WitnessTableCloneInfo();
+ witnessInfo->clonedTable = clonedTable;
+ witnessInfo->originalTable = originalTable;
+
+ bool shouldDeepClone = shouldDeepCloneWitnessTable(context, originalTable);
+
+ // Clone only the witness table entries that are actually used
+ for (auto child : originalTable->getDecorationsAndChildren())
+ {
+ if (auto entry = as<IRWitnessTableEntry>(child))
+ {
+ if (!shouldDeepClone)
+ {
+ // Skip witness table entries during the first pass,
+ // and just add them to the deferred work list.
+ witnessInfo->deferredEntries.add(getMangledName(entry->getRequirementKey()), entry);
+ continue;
+ }
+ }
+ // Clone any non-entry children as is
+ cloneInst(context, entryBuilder, child);
+ }
+ context->witnessTables.add(witnessInfo);
+
return clonedTable;
}
@@ -740,6 +897,9 @@ void cloneGlobalValueWithCodeCommon(
}
else
{
+ if (oi->getOp() == kIROp_DifferentiableTypeAnnotation &&
+ !context->getShared()->useAutodiff)
+ continue;
cloneInst(context, builder, oi);
}
}
@@ -886,12 +1046,12 @@ IRFunc* specializeIRForEntryPoint(
// so that the mangled name of the decl-ref is
// not the same as the mangled name of the decl.
//
- RefPtr<IRSpecSymbol> sym;
- if (!context->getSymbols().tryGetValue(mangledName, sym))
+ IRSpecSymbol* sym = context->findSymbols(mangledName.getUnownedSlice());
+ if (!sym)
{
String hashedName = getHashedName(mangledName.getUnownedSlice());
-
- if (!context->getSymbols().tryGetValue(hashedName, sym))
+ sym = context->findSymbols(hashedName.getUnownedSlice());
+ if (!sym)
{
SLANG_UNEXPECTED("no matching IR symbol");
return nullptr;
@@ -967,7 +1127,7 @@ IRFunc* specializeIRForEntryPoint(
if (!clonedFunc)
{
SLANG_UNEXPECTED("expected entry point to be a function");
- return nullptr;
+ UNREACHABLE_RETURN(nullptr);
}
if (!clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration))
@@ -1022,7 +1182,8 @@ CapabilitySet getTargetCapabilities(IRSpecContext* context)
return context->getShared()->targetReq->getTargetCaps();
}
-/// Get the most appropriate ("best") capability requirements for `inVal` based on the `targetCaps`.
+/// Get the most appropriate ("best") capability requirements for `inVal` based on the
+/// `targetCaps`.
static CapabilitySet _getBestSpecializationCaps(IRInst* inVal, CapabilitySet const& targetCaps)
{
IRInst* val = getResolvedInstForDecorations(inVal);
@@ -1335,9 +1496,9 @@ IRInst* cloneGlobalValueWithLinkage(
// with the same mangled name as `originalVal` and try
// to pick the "best" one for our target.
- auto mangledName = String(originalLinkage->getMangledName());
- RefPtr<IRSpecSymbol> sym;
- if (!context->getSymbols().tryGetValue(mangledName, sym))
+ auto mangledName = originalLinkage->getMangledName();
+ IRSpecSymbol* sym = context->findSymbols(mangledName);
+ if (!sym)
{
if (!originalVal)
return nullptr;
@@ -1419,6 +1580,11 @@ void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv)
{
sharedContext->symbols.add(mangledName, sym);
}
+
+ if (as<IRImportDecoration>(linkage))
+ sharedContext->isImportedSymbol.tryGetValueOrAdd(mangledName, true);
+ else
+ sharedContext->isImportedSymbol.set(mangledName, false);
}
void insertGlobalValueSymbols(IRSharedSpecContext* sharedContext, IRModule* originalModule)
@@ -1605,8 +1771,8 @@ void convertAtomicToStorageBuffer(
IRSpecContext* context,
Dictionary<int, List<IRInst*>>& bindingToInstMapUnsorted)
{
- // Atomic_uint definitions needs to become a storage buffer to follow GL_EXT_vulkan_glsl_relaxed
- // and to allow translation of atomic_uint into SPIRV
+ // Atomic_uint definitions needs to become a storage buffer to follow
+ // GL_EXT_vulkan_glsl_relaxed and to allow translation of atomic_uint into SPIRV
IRBuilder builder = *context->builder;
@@ -1653,8 +1819,8 @@ void convertAtomicToStorageBuffer(
instToSwitch->setFullType(storageBuffer);
// All references to a atomic_uint need to be an element ref. to emulate storage buffer
- // usage All function calls must be inlined since storage buffers cannot pass as parameters
- // to atomic methods
+ // usage All function calls must be inlined since storage buffers cannot pass as
+ // parameters to atomic methods
for (auto& i : bindingToInstList.second)
{
int64_t currOffset =
@@ -1722,12 +1888,13 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
{
case kIROp_GLSLAtomicUintType:
{
- // atomic_uint are supported by GLSL->VK through converting to a different type
- // (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by SPIR-V->VK;
- // this means that to get SPIR-V to work we must convert the type ourselves to
- // an equivlent representation (storage buffer); the added benifit is that then
- // HLSL is possible to emit as a target as well since atomic_uint is not an HLSL
- // concept, but storageBuffer->RWBuffer is and HLSL concept
+ // atomic_uint are supported by GLSL->VK through converting to a different
+ // type (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by
+ // SPIR-V->VK; this means that to get SPIR-V to work we must convert the
+ // type ourselves to an equivlent representation (storage buffer); the added
+ // benifit is that then HLSL is possible to emit as a target as well since
+ // atomic_uint is not an HLSL concept, but storageBuffer->RWBuffer is and
+ // HLSL concept
auto layout = inst->findDecoration<IRLayoutDecoration>()->getLayout();
auto layoutVal = as<IRVarOffsetAttr>(layout->getOperand(1));
assert(layoutVal != nullptr);
@@ -1743,6 +1910,110 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted);
}
+bool isDiffPairType(IRInst* type)
+{
+ for (;;)
+ {
+ auto type1 = (IRType*)unwrapAttributedType(type);
+ auto type2 = unwrapArray(type1);
+ if (type2 == type)
+ break;
+ type = type2;
+ }
+ return as<IRDifferentialPairTypeBase>(type) != nullptr;
+}
+
+bool doesModuleUseAutodiff(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ if (auto callee = getResolvedInstForDecorations(inst->getOperand(0)))
+ {
+ switch (callee->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ return true;
+ }
+ }
+ return false;
+ case kIROp_DifferentialPairGetDifferentialUserCode:
+ case kIROp_DifferentialPairGetPrimalUserCode:
+ case kIROp_DifferentialPtrPairGetPrimal:
+ case kIROp_DifferentialPtrPairGetDifferential:
+ return true;
+ case kIROp_StructField:
+ return isDiffPairType(as<IRStructField>(inst)->getFieldType());
+ case kIROp_Param:
+ return isDiffPairType(inst->getDataType());
+ default:
+ for (auto child : inst->getChildren())
+ {
+ bool isImported = false;
+ for (auto decor : child->getDecorations())
+ {
+ if (as<IRImportDecoration>(decor))
+ {
+ isImported = true;
+ break;
+ }
+ else if (as<IRAutoPyBindCudaDecoration>(decor))
+ {
+ return true;
+ }
+ else if (as<IRAutoPyBindExportInfoDecoration>(decor))
+ {
+ return true;
+ }
+ }
+ if (isImported)
+ continue;
+ for (auto decor : child->getDecorations())
+ {
+ if (isAutoDiffDecoration(decor))
+ return true;
+ }
+ if (doesModuleUseAutodiff(child))
+ return true;
+ }
+ return false;
+ }
+}
+
+void cloneUsedWitnessTableEntries(IRSpecContext* context)
+{
+ bool changed = true;
+ while (changed)
+ {
+ changed = false;
+ for (Index i = 0; i < context->witnessTables.getCount(); i++)
+ {
+ auto table = context->witnessTables[i].get();
+ ShortList<UnownedStringSlice> entriesToRemove;
+ for (auto entry : table->deferredEntries)
+ {
+ if (context->deferredWitnessTableEntryKeys.contains(entry.first))
+ {
+ IRBuilder builder(table->clonedTable);
+ builder.setInsertInto(table->clonedTable);
+ auto deferredKeyCount = context->deferredWitnessTableEntryKeys.getCount();
+ cloneInst(context, &builder, entry.second);
+ entriesToRemove.add(entry.first);
+ if (deferredKeyCount != context->deferredWitnessTableEntryKeys.getCount())
+ changed = true;
+ }
+ }
+ for (auto entry : entriesToRemove)
+ {
+ table->deferredEntries.remove(entry);
+ }
+ }
+ }
+}
+
LinkedIR linkIR(CodeGenContext* codeGenContext)
{
SLANG_PROFILE;
@@ -1763,50 +2034,61 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
state->target = target;
state->targetReq = targetReq;
-
+ auto& irModules = stateStorage.contextStorage.irModules;
auto sharedContext = state->getSharedContext();
initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq);
state->irModule = sharedContext->module;
-
// We need to be able to look up IR definitions for any symbols in
// modules that the program depends on (transitively). To
// accelerate lookup, we will create a symbol table for looking
// up IR definitions by their mangled name.
//
-
- List<IRModule*> irModules;
-
- // Link the core modules.
- auto& coreModules = static_cast<Session*>(linkage->getGlobalSession())->coreModules;
- for (auto& m : coreModules)
- irModules.add(m->getIRModule());
+ auto globalSession = static_cast<Session*>(linkage->getGlobalSession());
+ List<IRModule*> builtinModules;
+ for (auto& m : globalSession->coreModules)
+ builtinModules.add(m->getIRModule());
// Link modules in the program.
- program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); });
-
- // Add any modules that were loaded as libraries
- for (IRModule* irModule : irModules)
- {
- insertGlobalValueSymbols(sharedContext, irModule);
- }
+ program->enumerateIRModules(
+ [&](IRModule* module)
+ {
+ if (module->getName() == globalSession->glslModuleName)
+ builtinModules.add(module);
+ else
+ irModules.add(module);
+ });
- // We will also insert the IR global symbols from the IR module
+ // We will also consider the IR global symbols from the IR module
// attached to the `TargetProgram`, since this module is
// responsible for associating layout information to those
// global symbols via decorations.
//
auto irModuleForLayout = targetProgram->getExistingIRModuleForLayout();
- insertGlobalValueSymbols(sharedContext, irModuleForLayout);
+ if (irModuleForLayout)
+ irModules.add(irModuleForLayout);
+
+ Index userModuleCount = irModules.getCount();
+ irModules.addRange(builtinModules);
+ ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount);
+
+ // Check if any user module uses auto-diff, if so we will need to link
+ // additional witnesses and decorations.
+ for (IRModule* irModule : userModules)
+ {
+ if (sharedContext->useAutodiff)
+ break;
+ sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst());
+ }
auto context = state->getContext();
// Combine all of the contents of IRGlobalHashedStringLiterals
{
StringSlicePool pool(StringSlicePool::Style::Empty);
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
findGlobalHashedStringLiterals(irModule, pool);
}
@@ -1875,7 +2157,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// instructions in all the input modules.
//
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
for (auto inst : irModule->getGlobalInsts())
{
@@ -1896,7 +2178,9 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// We need to copy over exported symbols,
// and any global parameters if preserve-params option is set.
if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) ||
- as<IRDifferentiableTypeAnnotation>(inst))
+ sharedContext->useAutodiff &&
+ (as<IRDifferentiableTypeAnnotation>(inst) ||
+ inst->findDecorationImpl(kIROp_AutoDiffBuiltinDecoration) != nullptr))
{
auto cloned = cloneValue(context, inst);
if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration))
@@ -1907,6 +2191,12 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
}
}
+ // In previous steps, we have skipped cloning the witness table entries, and
+ // registered any used witness table entry keys to context->deferredWitnessTableEntryKeys
+ // for on-demand cloning. Now we will use the deferred keys to clone the witness table
+ // entries that are referenced.
+ cloneUsedWitnessTableEntries(context);
+
// It is possible that metadata has been attached to the input modules
// themselves, which should be copied over to the output module.
//
@@ -1917,7 +2207,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// `[assumedWaveSize(...)]` decoration might require that all specified
// values match exactly).
//
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
for (auto decoration : irModule->getModuleInst()->getDecorations())
{