summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff.cpp252
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-link.cpp2
-rw-r--r--source/slang/slang-ir.cpp141
-rw-r--r--source/slang/slang-ir.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp154
-rw-r--r--source/slang/slang-mangle.cpp31
-rw-r--r--source/slang/slang-mangle.h5
-rw-r--r--tests/autodiff/deduplicate-witness-table.slang33
10 files changed, 398 insertions, 225 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f3f32add2..afd698e8b 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1862,17 +1862,21 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
sharedContext->differentiableInterfaceType,
(IRType*)pairType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeStructKey,
- diffDiffPairType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeWitnessStructKey,
- table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeStructKey,
+ diffDiffPairType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ }
bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
@@ -1944,15 +1948,19 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)pairType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeStructKey,
- diffDiffPairType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeWitnessStructKey,
- table);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeStructKey,
+ diffDiffPairType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ table);
+ }
}
return table;
@@ -1987,17 +1995,21 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
sharedContext->differentiableInterfaceType,
(IRType*)arrayType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeStructKey,
- diffArrayType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeWitnessStructKey,
- table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeStructKey,
+ diffArrayType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ }
auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
@@ -2066,15 +2078,19 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)arrayType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeStructKey,
- diffArrayType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeWitnessStructKey,
- table);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeStructKey,
+ diffArrayType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ table);
+ }
}
else
{
@@ -2105,17 +2121,21 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
sharedContext->differentiableInterfaceType,
(IRType*)inTupleType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeStructKey,
- diffTupleType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocTypeWitnessStructKey,
- table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeStructKey,
+ diffTupleType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ }
// Fill in differential method implementations.
{
@@ -2219,15 +2239,19 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)inTupleType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeStructKey,
- diffTupleType);
- builder->createWitnessTableEntry(
- table,
- sharedContext->differentialAssocRefTypeWitnessStructKey,
- table);
+ // Add WitnessTableEntry only once
+ if (!table->hasDecorationOrChild())
+ {
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeStructKey,
+ diffTupleType);
+ builder->createWitnessTableEntry(
+ table,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ table);
+ }
}
return table;
@@ -3078,39 +3102,47 @@ struct AutoDiffPass : public InstPassBase
builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType);
result.diffWitness = origTypeIsDiffWitness;
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeStructKey,
- diffType);
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeWitnessStructKey,
- diffTypeIsDiffWitness);
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->zeroMethodStructKey,
- zeroMethod);
- builder.createWitnessTableEntry(
- origTypeIsDiffWitness,
- autodiffContext->addMethodStructKey,
- addMethod);
-
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeStructKey,
- diffType);
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->differentialAssocTypeWitnessStructKey,
- diffTypeIsDiffWitness);
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->zeroMethodStructKey,
- zeroMethod);
- builder.createWitnessTableEntry(
- diffTypeIsDiffWitness,
- autodiffContext->addMethodStructKey,
- addMethod);
+ // Add WitnessTableEntry only once
+ if (!origTypeIsDiffWitness->hasDecorationOrChild())
+ {
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeStructKey,
+ diffType);
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeWitnessStructKey,
+ diffTypeIsDiffWitness);
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->zeroMethodStructKey,
+ zeroMethod);
+ builder.createWitnessTableEntry(
+ origTypeIsDiffWitness,
+ autodiffContext->addMethodStructKey,
+ addMethod);
+ }
+
+ // Add WitnessTableEntry only once
+ if (!diffTypeIsDiffWitness->hasDecorationOrChild())
+ {
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeStructKey,
+ diffType);
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->differentialAssocTypeWitnessStructKey,
+ diffTypeIsDiffWitness);
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->zeroMethodStructKey,
+ zeroMethod);
+ builder.createWitnessTableEntry(
+ diffTypeIsDiffWitness,
+ autodiffContext->addMethodStructKey,
+ addMethod);
+ }
return result;
}
@@ -3177,14 +3209,34 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> args;
for (auto param : genType->getParams())
args.add(param);
- as<IRWitnessTable>(innerResult.diffWitness)
- ->setConcreteType((IRType*)builder.emitSpecializeInst(
- builder.getTypeKind(),
- originalType,
- (UInt)args.getCount(),
- args.getBuffer()));
+
+ // Create a new WitnessTable with a different concreteType.
+ auto concreteType = as<IRType>(builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ originalType,
+ (UInt)args.getCount(),
+ args.getBuffer()));
+
+ auto witnessTableType =
+ cast<IRWitnessTableType>(innerResult.diffWitness->getFullType());
+ auto conformanceType = cast<IRType>(witnessTableType->getConformanceType());
+ auto newWitnessTable = builder.createWitnessTable(conformanceType, concreteType);
+
+ // Add WitnessTableEntry only once
+ if (!newWitnessTable->hasDecorationOrChild())
+ {
+ builder.setInsertInto(newWitnessTable);
+ for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
+ {
+ builder.createWitnessTableEntry(
+ newWitnessTable,
+ entry->getRequirementKey(),
+ entry->getSatisfyingVal());
+ }
+ }
+
result.diffWitness =
- hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true);
+ hoistValueFromGeneric(builder, newWitnessTable, specInst, true);
}
return result;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 574f45243..8019fdd08 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -294,7 +294,7 @@ INST(GlobalConstant, globalConstant, 0, GLOBAL)
INST(StructKey, key, 0, GLOBAL)
INST(GlobalGenericParam, global_generic_param, 0, GLOBAL)
-INST(WitnessTable, witness_table, 0, 0)
+INST(WitnessTable, witness_table, 0, HOISTABLE)
INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fc8788b4d..86596f316 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2980,8 +2980,6 @@ struct IRWitnessTable : IRInst
IRType* getConcreteType() { return (IRType*)getOperand(0); }
- void setConcreteType(IRType* t) { return setOperand(0, t); }
-
IR_LEAF_ISA(WitnessTable)
};
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 125ab71d0..a84390f14 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -732,6 +732,8 @@ IRWitnessTable* cloneWitnessTableImpl(
clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
+ if (clonedTable->hasDecorationOrChild())
+ return clonedTable;
}
else
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6d1d76afc..a74ac58a4 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1692,7 +1692,9 @@ void addHoistableInst(IRBuilder* builder, IRInst* inst)
// any parameters of the parent.
//
while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param)
+ {
insertBeforeInst = insertBeforeInst->getNextInst();
+ }
// For instructions that will be placed at module scope,
// we don't care about relative ordering, but for everything
@@ -2490,6 +2492,72 @@ static void canonicalizeInstOperands(IRBuilder& builder, IROp op, ArrayView<IRIn
}
}
+static void addGlobalValue(IRBuilder* builder, IRInst* value)
+{
+ // If the value is already in the parent, keep it as-is.
+ // Because when the inst is Hoistable, the parent can have
+ // only one instance of the inst. The order among
+ // siblings should remain because the later siblings may
+ // have dependency to the earlier siblings.
+ //
+ if (value->parent)
+ {
+ SLANG_ASSERT(getIROpInfo(value->getOp()).isHoistable());
+ return;
+ }
+
+ // Try to find a suitable parent for the
+ // global value we are emitting.
+ //
+ // We will start out search at the current
+ // parent instruction for the builder, and
+ // possibly work our way up.
+ //
+ auto defaultInsertLoc = builder->getInsertLoc();
+ auto defaultParent = defaultInsertLoc.getParent();
+ auto parent = defaultParent;
+ while (parent)
+ {
+ // Inserting into the top level of a module?
+ // That is fine, and we can stop searching.
+ if (as<IRModuleInst>(parent))
+ break;
+
+ // Inserting into a basic block inside of
+ // a generic? That is okay too.
+ if (auto block = as<IRBlock>(parent))
+ {
+ if (as<IRGeneric>(block->parent))
+ break;
+ }
+
+ // Otherwise, move up the chain.
+ parent = parent->parent;
+ }
+
+ // If we somehow ran out of parents (possibly
+ // because an instruction wasn't linked into
+ // the full hierarchy yet), then we will
+ // fall back to inserting into the overall module.
+ if (!parent)
+ {
+ parent = builder->getModule()->getModuleInst();
+ }
+
+ // If it turns out that we are inserting into the
+ // current "insert into" parent for the builder, then
+ // we need to respect its "insert before" setting
+ // as well.
+ if (parent == defaultParent)
+ {
+ value->insertAt(defaultInsertLoc);
+ }
+ else
+ {
+ value->insertAtEnd(parent);
+ }
+}
+
IRInst* IRBuilder::_findOrEmitHoistableInst(
IRType* type,
IROp op,
@@ -2613,7 +2681,16 @@ IRInst* IRBuilder::_findOrEmitHoistableInst(
}
}
- addHoistableInst(this, inst);
+ // When an hoistable inst is already a child, skip adding it.
+ if (inst->parent == nullptr)
+ {
+ // In order to de-duplicate them, Witness-table is marked as Hoistable.
+ // But it is not exactly a hoistable type and it should be added as a global value.
+ if (inst->getOp() == kIROp_WitnessTable)
+ addGlobalValue(this, inst);
+ else
+ addHoistableInst(this, inst);
+ }
return inst;
}
@@ -4581,60 +4658,6 @@ IRDominatorTree* IRModule::findOrCreateDominatorTree(IRGlobalValueWithCode* func
return analysis->getDominatorTree();
}
-void addGlobalValue(IRBuilder* builder, IRInst* value)
-{
- // Try to find a suitable parent for the
- // global value we are emitting.
- //
- // We will start out search at the current
- // parent instruction for the builder, and
- // possibly work our way up.
- //
- auto defaultInsertLoc = builder->getInsertLoc();
- auto defaultParent = defaultInsertLoc.getParent();
- auto parent = defaultParent;
- while (parent)
- {
- // Inserting into the top level of a module?
- // That is fine, and we can stop searching.
- if (as<IRModuleInst>(parent))
- break;
-
- // Inserting into a basic block inside of
- // a generic? That is okay too.
- if (auto block = as<IRBlock>(parent))
- {
- if (as<IRGeneric>(block->parent))
- break;
- }
-
- // Otherwise, move up the chain.
- parent = parent->parent;
- }
-
- // If we somehow ran out of parents (possibly
- // because an instruction wasn't linked into
- // the full hierarchy yet), then we will
- // fall back to inserting into the overall module.
- if (!parent)
- {
- parent = builder->getModule()->getModuleInst();
- }
-
- // If it turns out that we are inserting into the
- // current "insert into" parent for the builder, then
- // we need to respect its "insert before" setting
- // as well.
- if (parent == defaultParent)
- {
- value->insertAt(defaultInsertLoc);
- }
- else
- {
- value->insertAtEnd(parent);
- }
-}
-
IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target)
{
return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration);
@@ -7985,7 +8008,11 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other)
auto user = uu->getUser();
bool userIsHoistable = getIROpInfo(user->getOp()).isHoistable();
- if (userIsHoistable)
+
+ // We want to de-duplicate WitnessTable but we don't really want to hoist them.
+ bool userNeedToBeHoisted = userIsHoistable && (user->getOp() != kIROp_WitnessTable);
+
+ if (userNeedToBeHoisted)
{
if (!dedupContext)
{
@@ -8002,7 +8029,7 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other)
// to a point before `user`, if it is not already so.
_maybeHoistOperand(uu);
- if (userIsHoistable)
+ if (userNeedToBeHoisted)
{
// Is the updated inst already exists in the global numbering map?
// If so, we need to continue work on replacing the updated inst with the existing
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 64125be9a..dbc66c6a3 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -690,6 +690,7 @@ struct IRInst
m_decorationsAndChildren.last);
}
void removeAndDeallocateAllDecorationsAndChildren();
+ bool hasDecorationOrChild() { return m_decorationsAndChildren.first != nullptr; }
#ifdef SLANG_ENABLE_IR_BREAK_ALLOC
// Unique allocation ID for this instruction since start of current process.
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 3395224ec..e3c4ddf05 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8042,30 +8042,40 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Need to construct a sub-witness-table
auto irWitnessTableBaseType =
lowerType(subContext, astReqWitnessTable->baseType);
- irSatisfyingWitnessTable = subBuilder->createWitnessTable(
- irWitnessTableBaseType,
- irWitnessTable->getConcreteType());
- auto mangledName = getMangledNameForConformanceWitness(
- subContext->astBuilder,
- astReqWitnessTable->witnessedType,
- astReqWitnessTable->baseType);
- subBuilder->addExportDecoration(
- irSatisfyingWitnessTable,
- mangledName.getUnownedSlice());
- if (isExportedType(astReqWitnessTable->witnessedType))
+
+ auto concreteType = irWitnessTable->getConcreteType();
+
+ irSatisfyingWitnessTable =
+ subBuilder->createWitnessTable(irWitnessTableBaseType, concreteType);
+
+ // Avoid adding same decorations and child more than once.
+ if (!irSatisfyingWitnessTable->hasDecorationOrChild())
{
- subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable);
- subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
- }
+ auto mangledName = getMangledNameForConformanceWitness(
+ subContext->astBuilder,
+ astReqWitnessTable->witnessedType,
+ astReqWitnessTable->baseType,
+ concreteType->getOp());
- // Recursively lower the sub-table.
- lowerWitnessTable(
- subContext,
- astReqWitnessTable,
- irSatisfyingWitnessTable,
- mapASTToIRWitnessTable);
+ subBuilder->addExportDecoration(
+ irSatisfyingWitnessTable,
+ mangledName.getUnownedSlice());
- irSatisfyingWitnessTable->moveToEnd();
+ if (isExportedType(astReqWitnessTable->witnessedType))
+ {
+ subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable);
+ subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
+ }
+
+ // Recursively lower the sub-table.
+ lowerWitnessTable(
+ subContext,
+ astReqWitnessTable,
+ irSatisfyingWitnessTable,
+ mapASTToIRWitnessTable);
+
+ irSatisfyingWitnessTable->moveToEnd();
+ }
}
irSatisfyingVal = irSatisfyingWitnessTable;
}
@@ -8148,14 +8158,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
- // Construct the mangled name for the witness table, which depends
- // on the type that is conforming, and the type that it conforms to.
- //
- // TODO: This approach doesn't really make sense for generic `extension` conformances.
- auto mangledName =
- getMangledNameForConformanceWitness(context->astBuilder, subType, superType);
-
-
// A witness table may need to be generic, if the outer
// declaration (either a type declaration or an `extension`)
// is generic.
@@ -8174,59 +8176,81 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
auto irWitnessTableBaseType = lowerType(subContext, superType);
- // Create the IR-level witness table
- auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr);
-
- // Register the value now, rather than later, to avoid any possible infinite recursion.
+ // Register a dummy value to avoid infinite recursions.
+ // Without this, the call to lowerType() can get into an infinite recursion.
+ //
context->setGlobalValue(
inheritanceDecl,
- LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
+ LoweredValInfo::simple(findOuterMostGeneric(subBuilder->getInsertLoc().getParent())));
auto irSubType = lowerType(subContext, subType);
- irWitnessTable->setConcreteType(irSubType);
- // TODO(JS):
- // Should the mangled name take part in obfuscation if enabled?
+ // Create the IR-level witness table
+ auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType);
- addLinkageDecoration(
- context,
- irWitnessTable,
+ // Override with the correct witness-table
+ context->setGlobalValue(
inheritanceDecl,
- mangledName.getUnownedSlice());
+ LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
- // If the witness table is for a COM interface, always keep it alive.
- if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
+ // Avoid adding same decorations and child more than once.
+ if (!irWitnessTable->hasDecorationOrChild())
{
- subBuilder->addHLSLExportDecoration(irWitnessTable);
- }
+ // Construct the mangled name for the witness table, which depends
+ // on the type that is conforming, and the type that it conforms to.
+ //
+ // TODO: This approach doesn't really make sense for generic `extension`
+ // conformances.
+ auto mangledName = getMangledNameForConformanceWitness(
+ context->astBuilder,
+ subType,
+ superType,
+ irSubType->getOp());
- for (auto mod : parentDecl->modifiers)
- {
- if (as<HLSLExportModifier>(mod))
+ // TODO(JS):
+ // Should the mangled name take part in obfuscation if enabled?
+
+ addLinkageDecoration(
+ context,
+ irWitnessTable,
+ inheritanceDecl,
+ mangledName.getUnownedSlice());
+
+ // If the witness table is for a COM interface, always keep it alive.
+ if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
{
subBuilder->addHLSLExportDecoration(irWitnessTable);
- subBuilder->addKeepAliveDecoration(irWitnessTable);
}
- else if (as<AutoDiffBuiltinAttribute>(mod))
+
+ for (auto mod : parentDecl->modifiers)
{
- subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
+ if (as<HLSLExportModifier>(mod))
+ {
+ subBuilder->addHLSLExportDecoration(irWitnessTable);
+ subBuilder->addKeepAliveDecoration(irWitnessTable);
+ }
+ else if (as<AutoDiffBuiltinAttribute>(mod))
+ {
+ subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
+ }
}
- }
- // Make sure that all the entries in the witness table have been filled in,
- // including any cases where there are sub-witness-tables for conformances
- bool isExplicitExtern = false;
- bool isImported = isImportedDecl(context, parentDecl, isExplicitExtern);
- if (!isImported || isExplicitExtern)
- {
- Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
- lowerWitnessTable(
- subContext,
- inheritanceDecl->witnessTable,
- irWitnessTable,
- mapASTToIRWitnessTable);
+ // Make sure that all the entries in the witness table have been filled in,
+ // including any cases where there are sub-witness-tables for conformances
+ bool isExplicitExtern = false;
+ bool isImported = isImportedDecl(context, parentDecl, isExplicitExtern);
+ if (!isImported || isExplicitExtern)
+ {
+ Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
+ lowerWitnessTable(
+ subContext,
+ inheritanceDecl->witnessTable,
+ irWitnessTable,
+ mapASTToIRWitnessTable);
+ }
+
+ irWitnessTable->moveToEnd();
}
- irWitnessTable->moveToEnd();
return LoweredValInfo::simple(
finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric));
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index d51fafb6b..12e185c8b 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -824,6 +824,37 @@ String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Ty
return context.sb.produceString();
}
+// This function takes an additional parameter to get a simplified
+// mangled name when the witness-table is for enum-type.
+//
+// In order to deduplicate the witness-tables, we need to apply a little different
+// rule for the mangled name when the `superType` is `enum` type.
+// All witness-table for enum types whose underlying type is same should get the same
+// manged name.
+//
+// TODO: We should remove this function and have a new IR for enum-type. The "option 2"
+// described on the issue 6364 is more proper and ideal solution for the issue.
+//
+String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Type* sup, IROp subOp)
+{
+ SLANG_AST_BUILDER_RAII(astBuilder);
+
+ ManglingContext context(astBuilder);
+ emitRaw(&context, "_SW");
+
+ if (as<EnumTypeType>(sup))
+ {
+ emitRaw(&context, getIROpInfo(subOp).name);
+ }
+ else
+ {
+ emitType(&context, sub);
+ }
+
+ emitType(&context, sup);
+ return context.sb.produceString();
+}
+
String getMangledTypeName(ASTBuilder* astBuilder, Type* type)
{
SLANG_AST_BUILDER_RAII(astBuilder);
diff --git a/source/slang/slang-mangle.h b/source/slang/slang-mangle.h
index cfbe4fb25..cfdbe461b 100644
--- a/source/slang/slang-mangle.h
+++ b/source/slang/slang-mangle.h
@@ -19,6 +19,11 @@ String getHashedName(const UnownedStringSlice& mangledName);
String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Type* sup);
String getMangledNameForConformanceWitness(
ASTBuilder* astBuilder,
+ Type* sub,
+ Type* sup,
+ IROp subOp);
+String getMangledNameForConformanceWitness(
+ ASTBuilder* astBuilder,
DeclRef<Decl> sub,
DeclRef<Decl> sup);
String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, DeclRef<Decl> sub, Type* sup);
diff --git a/tests/autodiff/deduplicate-witness-table.slang b/tests/autodiff/deduplicate-witness-table.slang
new file mode 100644
index 000000000..ea4c4e730
--- /dev/null
+++ b/tests/autodiff/deduplicate-witness-table.slang
@@ -0,0 +1,33 @@
+//TEST:SIMPLE(filecheck=CHK):-stage compute -entry computeMain -target hlsl
+
+//CHK: struct DiffPair_1
+//CHK-NOT: struct DiffPair_2
+
+RWTexture2D<float> gOutputColor;
+
+struct ShadingFrame : IDifferentiable
+{
+ float3 T;
+}
+
+[Differentiable]
+float computeRay()
+{
+ float3 dir = 1.f;
+ return dot(dir, dir);
+}
+
+[Differentiable]
+float paramRay()
+{
+ DifferentialPair<float> dpDir = fwd_diff(computeRay)();
+ return dpDir.p;
+}
+
+[Shader("compute")]
+[NumThreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DifferentialPair<float> dpColor = fwd_diff(paramRay)();
+ gOutputColor[0] = dpColor.p;
+}