author Jay Kwak <82421531+jkwak-work@users.noreply.github.com> 2025-04-01 01:31:33 -0700
committer GitHub <noreply@github.com> 2025-04-01 01:31:33 -0700
commit 0c5eee31c71372f1886c287d386a9618d845f316 (patch)
tree e6095f2bfdb45d57cc059bcf58eca23ce641e69e
parent 1fa4e486f598c5a7eed0db65f187ab95f890133c (diff)
Make IRWitnessTable HOISTABLE (#6417)
# Make `IRWitnessTable` Hoistable

## Intention of the PR

This commit makes `IRWitnessTable` Hoistable so that we can avoid duplicated instances of `IRWitnessTable`.

## Problems

This commit addresses the following issues, which arise after turning `IRWitnessTable` into a Hoistable:

1. A Hoistable instance is immutable.
2. When you try to create a duplicated child, you get a previously created instance of `IRWitnessTable` instead of a new one.
3. We don't actually want to hoist `IRWitnessTable`.
4. There can be only one instance of a Hoistable, and it cannot appear as a child multiple times.
5. Different import/export mangled names were used for the same witness table when its type is the "enum" interface.

## Implementation

### Solution for "1. A Hoistable instance is immutable."

`IRWitnessTable::setConcreteType()` is removed, because when an `IRInst` is Hoistable it is treated as immutable and the `IRInst::setXXX()` methods no longer work on it. There were two places calling `setConcreteType()`, and their logic had to change a little.

`DeclLoweringVisitor::visitInheritanceDecl()` in `source/slang/slang-lower-to-ir.cpp` was calling `setConcreteType()`. It had somewhat strange logic around `lowerType()`: the `IRWitnessTable` was registered with `context->setGlobalValue()` first, and its `concreteType` was changed later. This commit works around that by temporarily registering the parent of the `IRWitnessTable` as the global value and then resetting it with the correct `IRWitnessTable` once it has been created. Without this logic, lowering went into an infinite recursion.

`AutoDiffPass::fillDifferentialTypeImplementation()` in `source/slang/slang-ir-autodiff.cpp` was calling `setConcreteType()` to change the concrete type of `innerResult.diffWitness`. This commit instead creates a new `IRWitnessTable` and copies the `IRWitnessTableEntry` children into it.

### Solution for "2. When tries to create a duplicated child, you will get a previously created instance of IRWitnessTable, instead of a new one"

After a call to `IRBuilder::createWitnessTable()`, this commit checks whether the returned `IRWitnessTable` is brand new. If it is not, we have to avoid adding the decorations and children again.

This commit decides whether to add decorations and children based on whether the `IRWitnessTable` already has any decorations or children. This does not seem like a proper way to check, but when I tried, it was difficult to find a single choke point where the decorations and children are added to an `IRWitnessTable` for the first time. Note that we are not trying to find when the `IRWitnessTable` is created for the first time; we need to find out whether the decorations and children have already been added once.

It might be fine to have duplicated `IRWitnessTableEntry` children in most cases, but I noticed that it fails an assertion check when `shouldDeepCloneWitnessTable()` returns false in `cloneWitnessTableImpl()`.

### Solution for "3. We don't actually want to hoist IRWitnessTable."

The reason this commit makes `IRWitnessTable` Hoistable is to prevent duplicated instances of `IRInst`. But we don't really want to "hoist" them. When an `IRWitnessTable` gets hoisted out, it causes unexpected problems, and the specialization process fails because the `IRWitnessTable` is missing from the input.

This commit prevents `IRWitnessTable` from being hoisted in `_replaceInstUsesWith()`. The way this is implemented feels a little hacky, but we discussed it on Slack and decided to go with this. A more proper approach could be to add a new flag to `IROpFlags`, such as `kIROpFlag_Deduplicate`, that is distinct from `kIROpFlag_Hoistable`.

### Solution for "4. There can be only one instance of Hoistable and it cannot appear as childs multiple times."

When `IRWitnessTable` is Hoistable, there can be only a unique set of instances, and an instance cannot appear as a child more than once. This is because `IRInst` has only one pair of `IRInst* next` and `IRInst* prev` pointers.

Before this commit, an instance of `IRGeneric` could contain duplicated instances of `IRWitnessTable`. As an example, the `IInteger` interface inherits from two other interfaces, `IArithmetic` and `ILogical`, and both of them inherit from `IComparable`.

```
interface IInteger : IArithmetic, ILogical {}
interface IArithmetic : IComparable {}
interface ILogical : IComparable {}
```

When we specialize it in `specializeGenericImpl()`, an `IRBlock` gets the following list of children:

- IRWitnessTable for IComparable,
- IRWitnessTable for IArithmetic,
- IRWitnessTable for IComparable,
- IRWitnessTable for ILogical,

During the cloning done by specialization, "IRWitnessTable for `IComparable`" must be cloned before "IRWitnessTable for `IArithmetic`", because "IRWitnessTable for `IArithmetic`" refers to "IRWitnessTable for `IComparable`" in one of its `IRWitnessTableEntry` children. The order in which the instances appear as children of the `IRBlock` decides which ones are cloned first, so "IRWitnessTable for `IComparable`" must appear before "IRWitnessTable for `IArithmetic`".

Note that "IRWitnessTable for `IComparable`" appears twice. The first one was added for "IRWitnessTable for `IArithmetic`", and the second one was added for "IRWitnessTable for `ILogical`".

With this commit, "IRWitnessTable for `IComparable`" can appear as a child only once in the `IRBlock`. So it causes an error if the block gets the following list:

- IRWitnessTable for IArithmetic,
- IRWitnessTable for IComparable,
- IRWitnessTable for ILogical,

To resolve the problem, "IRWitnessTable for `IComparable`" must appear before both "IRWitnessTable for `IArithmetic`" and "IRWitnessTable for `ILogical`", as follows:

- IRWitnessTable for IComparable,
- IRWitnessTable for IArithmetic,
- IRWitnessTable for ILogical,

To address the problem, instances of `IRWitnessTable` are always added to the end of the children list. If an instance is already in the list, we don't move it. This works out because the AST is built based on the dependencies.

### Solution for "5. Different import/export mangled names were used for the same Witness-table when its type is "enum" interface."

This issue was found while running the Falcor tests, which use the conformance-type feature of Slang. We were using different import and export mangled names for the same witness table when the witness table is for the "Enum" interface. The way we simplify the implementation of "Enum" causes a problem when generating the export/import for the witness table. The exact repro steps are still unclear.

There were two suggested solutions for the problem, and this PR adopts the first option for now. We may want to improve it with the second option later.

- Option 1: when we produce mangled names for those witness tables, we can use a mangled name with the underlying "int" type instead of the name of the enum type. This way, all witness tables for enum types whose underlying type is the same get the same mangled name, which allows us to deduplicate the witness tables during linking.
- Option 2: we can preserve type info for enum types when generating IR. We can still erase all other uses of the type info of enum types for now. But when we generate the witness table, instead of filling the conforming type operand with IntType, we fill it with EnumType(IntType), where EnumType is a new global IROp code that represents all enum types (like InterfaceType/StructType). This way the operands of the two witness tables will be different.

"Option 1" is quicker and dirtier, and "option 2" is the more proper way to address the issue. I will go with "option 1" and improve it with the "option 2" approach later.
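As a rough illustration of "option 1", the sketch below builds a witness-table name from the name of the IR op that the conforming type lowers to whenever the table is for the enum interface. `TypeInfo`, `mangleConformanceWitness`, and the mangled spellings used with them are hypothetical stand-ins chosen for this example; only the `_SW` prefix and the sub-then-super order are taken from the diff.

```cpp
#include <string>

// Hypothetical, simplified description of a conforming type.
// Slang's real mangler works on AST `Type*` values instead.
struct TypeInfo
{
    std::string mangled;  // mangled spelling of the type itself (made up here)
    std::string irOpName; // name of the IR op the type lowers to, e.g. "Int"
};

// Sketch of option 1: for the enum interface, key the name on the lowered
// IR op rather than on the enum's own name, so that all enums with the same
// underlying type produce the same witness-table name.
std::string mangleConformanceWitness(
    const TypeInfo& sub,
    const std::string& supMangled,
    bool supIsEnumInterface)
{
    std::string name = "_SW";
    name += supIsEnumInterface ? sub.irOpName : sub.mangled;
    name += supMangled;
    return name;
}
```

With two int-backed enums, both calls with `supIsEnumInterface == true` yield the same string, which is what lets the linker treat the two witness tables as one. For any other interface the enum's own name still takes part, so the names stay distinct.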
-rw-r--r-- source/slang/slang-ir-autodiff.cpp 252
-rw-r--r-- source/slang/slang-ir-inst-defs.h 2
-rw-r--r-- source/slang/slang-ir-insts.h 2
-rw-r--r-- source/slang/slang-ir-link.cpp 2
-rw-r--r-- source/slang/slang-ir.cpp 141
-rw-r--r-- source/slang/slang-ir.h 1
-rw-r--r-- source/slang/slang-lower-to-ir.cpp 154
-rw-r--r-- source/slang/slang-mangle.cpp 31
-rw-r--r-- source/slang/slang-mangle.h 5
-rw-r--r-- tests/autodiff/deduplicate-witness-table.slang 33
10 files changed, 398 insertions, 225 deletions
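The commit message describes two behaviors: populating a deduplicated table only once (the `hasDecorationOrChild()` guards), and keeping each table after the tables it depends on. A minimal, self-contained sketch of how they interact follows. It is a toy model and not Slang code: `Table`, `Builder`, `Hierarchy`, `lowerTable`, and `lowerIIntegerOrder` are hypothetical names, and the placeholder entry added to every table is an assumption that keeps the "has entries" check meaningful.

```cpp
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for IRWitnessTable; this is not the Slang type.
struct Table
{
    std::string conformance;          // interface being witnessed
    std::string concrete;             // conforming type
    std::vector<std::string> entries; // stand-in for IRWitnessTableEntry children

    // Plays the role of IRInst::hasDecorationOrChild() in the diff.
    bool hasEntries() const { return !entries.empty(); }
};

// Hypothetical builder that de-duplicates tables by (conformance, concrete).
struct Builder
{
    std::map<std::pair<std::string, std::string>, std::unique_ptr<Table>> pool;
    std::vector<Table*> children; // sibling order inside the parent block

    Table* createWitnessTable(const std::string& conformance, const std::string& concrete)
    {
        auto key = std::make_pair(conformance, concrete);
        auto found = pool.find(key);
        if (found != pool.end())
            return found->second.get(); // same operands: reuse, keep its position

        auto owned = std::make_unique<Table>();
        owned->conformance = conformance;
        owned->concrete = concrete;
        Table* table = owned.get();
        pool.emplace(key, std::move(owned));
        children.push_back(table);
        return table;
    }

    void moveToEnd(Table* table)
    {
        children.erase(std::remove(children.begin(), children.end(), table), children.end());
        children.push_back(table);
    }
};

// Interface name -> base interfaces.
using Hierarchy = std::map<std::string, std::vector<std::string>>;

Table* lowerTable(Builder& b, const Hierarchy& h, const std::string& iface, const std::string& concrete)
{
    Table* table = b.createWitnessTable(iface, concrete);
    if (table->hasEntries())
        return table; // already populated: add nothing and do not move it

    table->entries.push_back("method:" + iface); // placeholder requirement
    auto bases = h.find(iface);
    if (bases != h.end())
    {
        for (const std::string& base : bases->second)
        {
            lowerTable(b, h, base, concrete); // lower the dependency first
            table->entries.push_back("base:" + base);
        }
    }
    b.moveToEnd(table); // a populated table ends up after its dependencies
    return table;
}

// Returns the sibling order produced by lowering IInteger for `int`.
std::vector<std::string> lowerIIntegerOrder()
{
    Hierarchy h = {
        {"IInteger", {"IArithmetic", "ILogical"}},
        {"IArithmetic", {"IComparable"}},
        {"ILogical", {"IComparable"}},
    };
    Builder b;
    lowerTable(b, h, "IInteger", "int");

    std::vector<std::string> order;
    for (Table* t : b.children)
        order.push_back(t->conformance);
    return order;
}
```

In this model the table for `IComparable` is created and populated while lowering `IArithmetic`. The later request from `ILogical` gets the same instance back and leaves it where it is, so it stays ahead of both dependents. The guard shares the weakness the commit message admits: a table that legitimately has no entries would look unpopulated on every visit.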
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f3f32add2..afd698e8b 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1862,17 +1862,21 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
sharedContext->differentiableInterfaceType,
(IRType*)pairType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeStructKey,
- diffDiffPairType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeWitnessStructKey,
- table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeStructKey,
+ diffDiffPairType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ }
bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
@@ -1944,15 +1948,19 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)pairType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeStructKey,
- diffDiffPairType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeWitnessStructKey,
- table);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeStructKey,
+ diffDiffPairType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ table);
+ }
}
return table;
@@ -1987,17 +1995,21 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
sharedContext->differentiableInterfaceType,
(IRType*)arrayType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeStructKey,
- diffArrayType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeWitnessStructKey,
- table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeStructKey,
+ diffArrayType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ }
auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
@@ -2066,15 +2078,19 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)arrayType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeStructKey,
- diffArrayType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeWitnessStructKey,
- table);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeStructKey,
+ diffArrayType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ table);
+ }
}
else
{
@@ -2105,17 +2121,21 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
sharedContext->differentiableInterfaceType,
(IRType*)inTupleType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeStructKey,
- diffTupleType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeWitnessStructKey,
- table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeStructKey,
+ diffTupleType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ }
// Fill in differential method implementations.
{
@@ -2219,15 +2239,19 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)inTupleType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeStructKey,
- diffTupleType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeWitnessStructKey,
- table);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeStructKey,
+ diffTupleType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ table);
+ }
}
return table;
@@ -3078,39 +3102,47 @@ struct AutoDiffPass : public InstPassBase
builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType);
result.diffWitness = origTypeIsDiffWitness;
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeStructKey,
- diffType);
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeWitnessStructKey,
- diffTypeIsDiffWitness);
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->zeroMethodStructKey,
- zeroMethod);
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->addMethodStructKey,
- addMethod);
-
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeStructKey,
- diffType);
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeWitnessStructKey,
- diffTypeIsDiffWitness);
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->zeroMethodStructKey,
- zeroMethod);
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->addMethodStructKey,
- addMethod);
+ // Add WitnessTableEntry only once
+ if (!origTypeIsDiffWitness->hasDecorationOrChild())
+ {
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeStructKey,
+ diffType);
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeWitnessStructKey,
+ diffTypeIsDiffWitness);
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->zeroMethodStructKey,
+ zeroMethod);
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->addMethodStructKey,
+ addMethod);
+ }
+
+ // Add WitnessTableEntry only once
+ if (!diffTypeIsDiffWitness->hasDecorationOrChild())
+ {
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeStructKey,
+ diffType);
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeWitnessStructKey,
+ diffTypeIsDiffWitness);
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->zeroMethodStructKey,
+ zeroMethod);
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->addMethodStructKey,
+ addMethod);
+ }
return result;
}
@@ -3177,14 +3209,34 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> args;
for (auto param : genType->getParams())
args.add(param);
- as<IRWitnessTable>(innerResult.diffWitness)
- ->setConcreteType((IRType*)builder.emitSpecializeInst(
- builder.getTypeKind(),
- originalType,
- (UInt)args.getCount(),
- args.getBuffer()));
+
+ // Create a new WitnessTable with a different concreteType.
+ auto concreteType = as<IRType>(builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ originalType,
+ (UInt)args.getCount(),
+ args.getBuffer()));
+
+ auto witnessTableType =
+ cast<IRWitnessTableType>(innerResult.diffWitness->getFullType());
+ auto conformanceType = cast<IRType>(witnessTableType->getConformanceType());
+ auto newWitnessTable = builder.createWitnessTable(conformanceType, concreteType);
+
+ // Add WitnessTableEntry only once
+ if (!newWitnessTable->hasDecorationOrChild())
+ {
+ builder.setInsertInto(newWitnessTable);
+ for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
+ {
+ builder.createWitnessTableEntry(
+ newWitnessTable,
+ entry->getRequirementKey(),
+ entry->getSatisfyingVal());
+ }
+ }
+
result.diffWitness =
- hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true);
+ hoistValueFromGeneric(builder, newWitnessTable, specInst, true);
}
return result;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 574f45243..8019fdd08 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -294,7 +294,7 @@ INST(GlobalConstant, globalConstant, 0, GLOBAL)
INST(StructKey, key, 0, GLOBAL)
INST(GlobalGenericParam, global_generic_param, 0, GLOBAL)
-INST(WitnessTable, witness_table, 0, 0)
+INST(WitnessTable, witness_table, 0, HOISTABLE)
INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fc8788b4d..86596f316 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2980,8 +2980,6 @@ struct IRWitnessTable : IRInst
IRType* getConcreteType() { return (IRType*)getOperand(0); }
- void setConcreteType(IRType* t) { return setOperand(0, t); }
-
IR_LEAF_ISA(WitnessTable)
};
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 125ab71d0..a84390f14 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -732,6 +732,8 @@ IRWitnessTable* cloneWitnessTableImpl(
clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
+ if (clonedTable->hasDecorationOrChild())
+ return clonedTable;
}
else
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6d1d76afc..a74ac58a4 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1692,7 +1692,9 @@ void addHoistableInst(IRBuilder* builder, IRInst* inst)
// any parameters of the parent.
//
while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param)
+ {
insertBeforeInst = insertBeforeInst->getNextInst();
+ }
// For instructions that will be placed at module scope,
// we don't care about relative ordering, but for everything
@@ -2490,6 +2492,72 @@ static void canonicalizeInstOperands(IRBuilder& builder, IROp op, ArrayView<IRIn
}
}
+static void addGlobalValue(IRBuilder* builder, IRInst* value)
+{
+ // If the value is already in the parent, keep it as-is.
+ // Because when the inst is Hoistable, the parent can have
+ // only one instance of the inst. The order among
+ // siblings should remain because the later siblings may
+ // have dependency to the earlier siblings.
+ //
+ if (value->parent)
+ {
+ SLANG_ASSERT(getIROpInfo(value->getOp()).isHoistable());
+ return;
+ }
+
+ // Try to find a suitable parent for the
+ // global value we are emitting.
+ //
+ // We will start out search at the current
+ // parent instruction for the builder, and
+ // possibly work our way up.
+ //
+ auto defaultInsertLoc = builder->getInsertLoc();
+ auto defaultParent = defaultInsertLoc.getParent();
+ auto parent = defaultParent;
+ while (parent)
+ {
+ // Inserting into the top level of a module?
+ // That is fine, and we can stop searching.
+ if (as<IRModuleInst>(parent))
+ break;
+
+ // Inserting into a basic block inside of
+ // a generic? That is okay too.
+ if (auto block = as<IRBlock>(parent))
+ {
+ if (as<IRGeneric>(block->parent))
+ break;
+ }
+
+ // Otherwise, move up the chain.
+ parent = parent->parent;
+ }
+
+ // If we somehow ran out of parents (possibly
+ // because an instruction wasn't linked into
+ // the full hierarchy yet), then we will
+ // fall back to inserting into the overall module.
+ if (!parent)
+ {
+ parent = builder->getModule()->getModuleInst();
+ }
+
+ // If it turns out that we are inserting into the
+ // current "insert into" parent for the builder, then
+ // we need to respect its "insert before" setting
+ // as well.
+ if (parent == defaultParent)
+ {
+ value->insertAt(defaultInsertLoc);
+ }
+ else
+ {
+ value->insertAtEnd(parent);
+ }
+}
+
IRInst* IRBuilder::_findOrEmitHoistableInst(
IRType* type,
IROp op,
@@ -2613,7 +2681,16 @@ IRInst* IRBuilder::_findOrEmitHoistableInst(
}
}
- addHoistableInst(this, inst);
+ // When an hoistable inst is already a child, skip adding it.
+ if (inst->parent == nullptr)
+ {
+ // In order to de-duplicate them, Witness-table is marked as Hoistable.
+ // But it is not exactly a hoistable type and it should be added as a global value.
+ if (inst->getOp() == kIROp_WitnessTable)
+ addGlobalValue(this, inst);
+ else
+ addHoistableInst(this, inst);
+ }
return inst;
}
@@ -4581,60 +4658,6 @@ IRDominatorTree* IRModule::findOrCreateDominatorTree(IRGlobalValueWithCode* func
return analysis->getDominatorTree();
}
-void addGlobalValue(IRBuilder* builder, IRInst* value)
-{
- // Try to find a suitable parent for the
- // global value we are emitting.
- //
- // We will start out search at the current
- // parent instruction for the builder, and
- // possibly work our way up.
- //
- auto defaultInsertLoc = builder->getInsertLoc();
- auto defaultParent = defaultInsertLoc.getParent();
- auto parent = defaultParent;
- while (parent)
- {
- // Inserting into the top level of a module?
- // That is fine, and we can stop searching.
- if (as<IRModuleInst>(parent))
- break;
-
- // Inserting into a basic block inside of
- // a generic? That is okay too.
- if (auto block = as<IRBlock>(parent))
- {
- if (as<IRGeneric>(block->parent))
- break;
- }
-
- // Otherwise, move up the chain.
- parent = parent->parent;
- }
-
- // If we somehow ran out of parents (possibly
- // because an instruction wasn't linked into
- // the full hierarchy yet), then we will
- // fall back to inserting into the overall module.
- if (!parent)
- {
- parent = builder->getModule()->getModuleInst();
- }
-
- // If it turns out that we are inserting into the
- // current "insert into" parent for the builder, then
- // we need to respect its "insert before" setting
- // as well.
- if (parent == defaultParent)
- {
- value->insertAt(defaultInsertLoc);
- }
- else
- {
- value->insertAtEnd(parent);
- }
-}
-
IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target)
{
return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration);
@@ -7985,7 +8008,11 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other)
auto user = uu->getUser();
bool userIsHoistable = getIROpInfo(user->getOp()).isHoistable();
- if (userIsHoistable)
+
+ // We want to de-duplicate WitnessTable but we don't really want to hoist them.
+ bool userNeedToBeHoisted = userIsHoistable && (user->getOp() != kIROp_WitnessTable);
+
+ if (userNeedToBeHoisted)
{
if (!dedupContext)
{
@@ -8002,7 +8029,7 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other)
// to a point before `user`, if it is not already so.
_maybeHoistOperand(uu);
- if (userIsHoistable)
+ if (userNeedToBeHoisted)
{
// Is the updated inst already exists in the global numbering map?
// If so, we need to continue work on replacing the updated inst with the existing
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 64125be9a..dbc66c6a3 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -690,6 +690,7 @@ struct IRInst
m_decorationsAndChildren.last);
}
void removeAndDeallocateAllDecorationsAndChildren();
+ bool hasDecorationOrChild() { return m_decorationsAndChildren.first != nullptr; }
#ifdef SLANG_ENABLE_IR_BREAK_ALLOC
// Unique allocation ID for this instruction since start of current process.
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 3395224ec..e3c4ddf05 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8042,30 +8042,40 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Need to construct a sub-witness-table
auto irWitnessTableBaseType =
lowerType(subContext, astReqWitnessTable->baseType);
- irSatisfyingWitnessTable = subBuilder->createWitnessTable(
- irWitnessTableBaseType,
- irWitnessTable->getConcreteType());
- auto mangledName = getMangledNameForConformanceWitness(
- subContext->astBuilder,
- astReqWitnessTable->witnessedType,
- astReqWitnessTable->baseType);
- subBuilder->addExportDecoration(
- irSatisfyingWitnessTable,
- mangledName.getUnownedSlice());
- if (isExportedType(astReqWitnessTable->witnessedType))
+
+ auto concreteType = irWitnessTable->getConcreteType();
+
+ irSatisfyingWitnessTable =
+ subBuilder->createWitnessTable(irWitnessTableBaseType, concreteType);
+
+ // Avoid adding same decorations and child more than once.
+ if (!irSatisfyingWitnessTable->hasDecorationOrChild())
{
- subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable);
- subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
- }
+ auto mangledName = getMangledNameForConformanceWitness(
+ subContext->astBuilder,
+ astReqWitnessTable->witnessedType,
+ astReqWitnessTable->baseType,
+ concreteType->getOp());
- // Recursively lower the sub-table.
- lowerWitnessTable(
- subContext,
- astReqWitnessTable,
- irSatisfyingWitnessTable,
- mapASTToIRWitnessTable);
+ subBuilder->addExportDecoration(
+ irSatisfyingWitnessTable,
+ mangledName.getUnownedSlice());
- irSatisfyingWitnessTable->moveToEnd();
+ if (isExportedType(astReqWitnessTable->witnessedType))
+ {
+ subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable);
+ subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
+ }
+
+ // Recursively lower the sub-table.
+ lowerWitnessTable(
+ subContext,
+ astReqWitnessTable,
+ irSatisfyingWitnessTable,
+ mapASTToIRWitnessTable);
+
+ irSatisfyingWitnessTable->moveToEnd();
+ }
}
irSatisfyingVal = irSatisfyingWitnessTable;
}
@@ -8148,14 +8158,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
- // Construct the mangled name for the witness table, which depends
- // on the type that is conforming, and the type that it conforms to.
- //
- // TODO: This approach doesn't really make sense for generic `extension` conformances.
- auto mangledName =
- getMangledNameForConformanceWitness(context->astBuilder, subType, superType);
-
-
// A witness table may need to be generic, if the outer
// declaration (either a type declaration or an `extension`)
// is generic.
@@ -8174,59 +8176,81 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
auto irWitnessTableBaseType = lowerType(subContext, superType);
- // Create the IR-level witness table
- auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr);
-
- // Register the value now, rather than later, to avoid any possible infinite recursion.
+ // Register a dummy value to avoid infinite recursions.
+ // Without this, the call to lowerType() can get into an infinite recursion.
+ //
context->setGlobalValue(
inheritanceDecl,
- LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
+ LoweredValInfo::simple(findOuterMostGeneric(subBuilder->getInsertLoc().getParent())));
auto irSubType = lowerType(subContext, subType);
- irWitnessTable->setConcreteType(irSubType);
- // TODO(JS):
- // Should the mangled name take part in obfuscation if enabled?
+ // Create the IR-level witness table
+ auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType);
- addLinkageDecoration(
- context,
- irWitnessTable,
+ // Override with the correct witness-table
+ context->setGlobalValue(
inheritanceDecl,
- mangledName.getUnownedSlice());
+ LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
- // If the witness table is for a COM interface, always keep it alive.
- if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
+ // Avoid adding same decorations and child more than once.
+ if (!irWitnessTable->hasDecorationOrChild())
{
- subBuilder->addHLSLExportDecoration(irWitnessTable);
- }
+ // Construct the mangled name for the witness table, which depends
+ // on the type that is conforming, and the type that it conforms to.
+ //
+ // TODO: This approach doesn't really make sense for generic `extension`
+ // conformances.
+ auto mangledName = getMangledNameForConformanceWitness(
+ context->astBuilder,
+ subType,
+ superType,
+ irSubType->getOp());
- for (auto mod : parentDecl->modifiers)
- {
- if (as<HLSLExportModifier>(mod))
+ // TODO(JS):
+ // Should the mangled name take part in obfuscation if enabled?
+
+ addLinkageDecoration(
+ context,
+ irWitnessTable,
+ inheritanceDecl,
+ mangledName.getUnownedSlice());
+
+ // If the witness table is for a COM interface, always keep it alive.
+ if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
{
subBuilder->addHLSLExportDecoration(irWitnessTable);
- subBuilder->addKeepAliveDecoration(irWitnessTable);
}
- else if (as<AutoDiffBuiltinAttribute>(mod))
+
+ for (auto mod : parentDecl->modifiers)
{
- subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
+ if (as<HLSLExportModifier>(mod))
+ {
+ subBuilder->addHLSLExportDecoration(irWitnessTable);
+ subBuilder->addKeepAliveDecoration(irWitnessTable);
+ }
+ else if (as<AutoDiffBuiltinAttribute>(mod))
+ {
+ subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
+ }
}
- }
- // Make sure that all the entries in the witness table have been filled in,
- // including any cases where there are sub-witness-tables for conformances
- bool isExplicitExtern = false;
- bool isImported = isImportedDecl(context, parentDecl, isExplicitExtern);
- if (!isImported || isExplicitExtern)
- {
- Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
- lowerWitnessTable(
- subContext,
- inheritanceDecl->witnessTable,
- irWitnessTable,
- mapASTToIRWitnessTable);
+ // Make sure that all the entries in the witness table have been filled in,
+ // including any cases where there are sub-witness-tables for conformances
+ bool isExplicitExtern = false;
+ bool isImported = isImportedDecl(context, parentDecl, isExplicitExtern);
+ if (!isImported || isExplicitExtern)
+ {
+ Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
+ lowerWitnessTable(
+ subContext,
+ inheritanceDecl->witnessTable,
+ irWitnessTable,
+ mapASTToIRWitnessTable);
+ }
+
+ irWitnessTable->moveToEnd();
}
- irWitnessTable->moveToEnd();
return LoweredValInfo::simple(
finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric));
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index d51fafb6b..12e185c8b 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -824,6 +824,37 @@ String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Ty
return context.sb.produceString();
}
+// This function takes an additional parameter to get a simplified
+// mangled name when the witness-table is for enum-type.
+//
+// In order to deduplicate the witness-tables, we need to apply a little different
+// rule for the mangled name when the `superType` is `enum` type.
+// All witness-table for enum types whose underlying type is same should get the same
+// manged name.
+//
+// TODO: We should remove this function and have a new IR for enum-type. The "option 2"
+// described on the issue 6364 is more proper and ideal solution for the issue.
+//
+String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Type* sup, IROp subOp)
+{
+ SLANG_AST_BUILDER_RAII(astBuilder);
+
+ ManglingContext context(astBuilder);
+ emitRaw(&context, "_SW");
+
+ if (as<EnumTypeType>(sup))
+ {
+ emitRaw(&context, getIROpInfo(subOp).name);
+ }
+ else
+ {
+ emitType(&context, sub);
+ }
+
+ emitType(&context, sup);
+ return context.sb.produceString();
+}
+
String getMangledTypeName(ASTBuilder* astBuilder, Type* type)
{
SLANG_AST_BUILDER_RAII(astBuilder);
diff --git a/source/slang/slang-mangle.h b/source/slang/slang-mangle.h
index cfbe4fb25..cfdbe461b 100644
--- a/source/slang/slang-mangle.h
+++ b/source/slang/slang-mangle.h
@@ -19,6 +19,11 @@ String getHashedName(const UnownedStringSlice& mangledName);
String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Type* sup);
String getMangledNameForConformanceWitness(
ASTBuilder* astBuilder,
+ Type* sub,
+ Type* sup,
+ IROp subOp);
+String getMangledNameForConformanceWitness(
+ ASTBuilder* astBuilder,
DeclRef<Decl> sub,
DeclRef<Decl> sup);
String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, DeclRef<Decl> sub, Type* sup);
diff --git a/tests/autodiff/deduplicate-witness-table.slang b/tests/autodiff/deduplicate-witness-table.slang
new file mode 100644
index 000000000..ea4c4e730
--- /dev/null
+++ b/tests/autodiff/deduplicate-witness-table.slang
@@ -0,0 +1,33 @@
+//TEST:SIMPLE(filecheck=CHK):-stage compute -entry computeMain -target hlsl
+
+//CHK: struct DiffPair_1
+//CHK-NOT: struct DiffPair_2
+
+RWTexture2D<float> gOutputColor;
+
+struct ShadingFrame : IDifferentiable
+{
+ float3 T;
+}
+
+[Differentiable]
+float computeRay()
+{
+ float3 dir = 1.f;
+ return dot(dir, dir);
+}
+
+[Differentiable]
+float paramRay()
+{
+ DifferentialPair<float> dpDir = fwd_diff(computeRay)();
+ return dpDir.p;
+}
+
+[Shader("compute")]
+[NumThreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DifferentialPair<float> dpColor = fwd_diff(paramRay)();
+ gOutputColor[0] = dpColor.p;
+}