diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 252 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 141 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 154 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 31 | ||||
| -rw-r--r-- | source/slang/slang-mangle.h | 5 | ||||
| -rw-r--r-- | tests/autodiff/deduplicate-witness-table.slang | 33 |
10 files changed, 398 insertions, 225 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f3f32add2..afd698e8b 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1862,17 +1862,21 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( sharedContext->differentiableInterfaceType, (IRType*)pairType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocTypeStructKey, - diffDiffPairType); - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocTypeWitnessStructKey, - table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // Add WitnessTableEntry only once + if (!table->hasDecorationOrChild()) + { + // And place it in the synthesized witness table. + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocTypeStructKey, + diffDiffPairType); + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocTypeWitnessStructKey, + table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + } bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false; @@ -1944,15 +1948,19 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( sharedContext->differentiablePtrInterfaceType, (IRType*)pairType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocRefTypeStructKey, - diffDiffPairType); - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocRefTypeWitnessStructKey, - table); + // Add WitnessTableEntry only once + if (!table->hasDecorationOrChild()) + { + // And place it in the synthesized witness table. + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocRefTypeStructKey, + diffDiffPairType); + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocRefTypeWitnessStructKey, + table); + } } return table; @@ -1987,17 +1995,21 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( sharedContext->differentiableInterfaceType, (IRType*)arrayType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocTypeStructKey, - diffArrayType); - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocTypeWitnessStructKey, - table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // Add WitnessTableEntry only once + if (!table->hasDecorationOrChild()) + { + // And place it in the synthesized witness table. + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocTypeStructKey, + diffArrayType); + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocTypeWitnessStructKey, + table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + } auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); @@ -2066,15 +2078,19 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocRefTypeStructKey, - diffArrayType); - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocRefTypeWitnessStructKey, - table); + // Add WitnessTableEntry only once + if (!table->hasDecorationOrChild()) + { + // And place it in the synthesized witness table. + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocRefTypeStructKey, + diffArrayType); + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocRefTypeWitnessStructKey, + table); + } } else { @@ -2105,17 +2121,21 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( sharedContext->differentiableInterfaceType, (IRType*)inTupleType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocTypeStructKey, - diffTupleType); - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocTypeWitnessStructKey, - table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // Add WitnessTableEntry only once + if (!table->hasDecorationOrChild()) + { + // And place it in the synthesized witness table. + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocTypeStructKey, + diffTupleType); + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocTypeWitnessStructKey, + table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + } // Fill in differential method implementations. { @@ -2219,15 +2239,19 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocRefTypeStructKey, - diffTupleType); - builder->createWitnessTableEntry( - table, - sharedContext->differentialAssocRefTypeWitnessStructKey, - table); + // Add WitnessTableEntry only once + if (!table->hasDecorationOrChild()) + { + // And place it in the synthesized witness table. + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocRefTypeStructKey, + diffTupleType); + builder->createWitnessTableEntry( + table, + sharedContext->differentialAssocRefTypeWitnessStructKey, + table); + } } return table; @@ -3078,39 +3102,47 @@ struct AutoDiffPass : public InstPassBase builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType); result.diffWitness = origTypeIsDiffWitness; - builder.createWitnessTableEntry( - origTypeIsDiffWitness, - autodiffContext->differentialAssocTypeStructKey, - diffType); - builder.createWitnessTableEntry( - origTypeIsDiffWitness, - autodiffContext->differentialAssocTypeWitnessStructKey, - diffTypeIsDiffWitness); - builder.createWitnessTableEntry( - origTypeIsDiffWitness, - autodiffContext->zeroMethodStructKey, - zeroMethod); - builder.createWitnessTableEntry( - origTypeIsDiffWitness, - autodiffContext->addMethodStructKey, - addMethod); - - builder.createWitnessTableEntry( - diffTypeIsDiffWitness, - autodiffContext->differentialAssocTypeStructKey, - diffType); - builder.createWitnessTableEntry( - diffTypeIsDiffWitness, - autodiffContext->differentialAssocTypeWitnessStructKey, - diffTypeIsDiffWitness); - builder.createWitnessTableEntry( - diffTypeIsDiffWitness, - autodiffContext->zeroMethodStructKey, - zeroMethod); - builder.createWitnessTableEntry( - diffTypeIsDiffWitness, - autodiffContext->addMethodStructKey, - addMethod); + // Add WitnessTableEntry only once + if (!origTypeIsDiffWitness->hasDecorationOrChild()) + { + builder.createWitnessTableEntry( + origTypeIsDiffWitness, + autodiffContext->differentialAssocTypeStructKey, + diffType); + builder.createWitnessTableEntry( + origTypeIsDiffWitness, + autodiffContext->differentialAssocTypeWitnessStructKey, + diffTypeIsDiffWitness); + builder.createWitnessTableEntry( + origTypeIsDiffWitness, + autodiffContext->zeroMethodStructKey, + zeroMethod); + builder.createWitnessTableEntry( + origTypeIsDiffWitness, + autodiffContext->addMethodStructKey, + addMethod); + } + + // Add WitnessTableEntry only once + if (!diffTypeIsDiffWitness->hasDecorationOrChild()) + { + builder.createWitnessTableEntry( + diffTypeIsDiffWitness, + autodiffContext->differentialAssocTypeStructKey, + diffType); + builder.createWitnessTableEntry( + diffTypeIsDiffWitness, + autodiffContext->differentialAssocTypeWitnessStructKey, + diffTypeIsDiffWitness); + builder.createWitnessTableEntry( + diffTypeIsDiffWitness, + autodiffContext->zeroMethodStructKey, + zeroMethod); + builder.createWitnessTableEntry( + diffTypeIsDiffWitness, + autodiffContext->addMethodStructKey, + addMethod); + } return result; } @@ -3177,14 +3209,34 @@ struct AutoDiffPass : public InstPassBase List<IRInst*> args; for (auto param : genType->getParams()) args.add(param); - as<IRWitnessTable>(innerResult.diffWitness) - ->setConcreteType((IRType*)builder.emitSpecializeInst( - builder.getTypeKind(), - originalType, - (UInt)args.getCount(), - args.getBuffer())); + + // Create a new WitnessTable with a different concreteType. + auto concreteType = as<IRType>(builder.emitSpecializeInst( + builder.getTypeKind(), + originalType, + (UInt)args.getCount(), + args.getBuffer())); + + auto witnessTableType = + cast<IRWitnessTableType>(innerResult.diffWitness->getFullType()); + auto conformanceType = cast<IRType>(witnessTableType->getConformanceType()); + auto newWitnessTable = builder.createWitnessTable(conformanceType, concreteType); + + // Add WitnessTableEntry only once + if (!newWitnessTable->hasDecorationOrChild()) + { + builder.setInsertInto(newWitnessTable); + for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries()) + { + builder.createWitnessTableEntry( + newWitnessTable, + entry->getRequirementKey(), + entry->getSatisfyingVal()); + } + } + result.diffWitness = - hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true); + hoistValueFromGeneric(builder, newWitnessTable, specInst, true); } return result; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 574f45243..8019fdd08 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -294,7 +294,7 @@ INST(GlobalConstant, globalConstant, 0, GLOBAL) INST(StructKey, key, 0, GLOBAL) INST(GlobalGenericParam, global_generic_param, 0, GLOBAL) -INST(WitnessTable, witness_table, 0, 0) +INST(WitnessTable, witness_table, 0, HOISTABLE) INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fc8788b4d..86596f316 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2980,8 +2980,6 @@ struct IRWitnessTable : IRInst IRType* getConcreteType() { return (IRType*)getOperand(0); } - void setConcreteType(IRType* t) { return setOperand(0, t); } - IR_LEAF_ISA(WitnessTable) }; diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 125ab71d0..a84390f14 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -732,6 +732,8 @@ IRWitnessTable* cloneWitnessTableImpl( clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType); + if (clonedTable->hasDecorationOrChild()) + return clonedTable; } else { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6d1d76afc..a74ac58a4 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1692,7 +1692,9 @@ void addHoistableInst(IRBuilder* builder, IRInst* inst) // any parameters of the parent. // while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param) + { insertBeforeInst = insertBeforeInst->getNextInst(); + } // For instructions that will be placed at module scope, // we don't care about relative ordering, but for everything @@ -2490,6 +2492,72 @@ static void canonicalizeInstOperands(IRBuilder& builder, IROp op, ArrayView<IRIn } } +static void addGlobalValue(IRBuilder* builder, IRInst* value) +{ + // If the value is already in the parent, keep it as-is. + // Because when the inst is Hoistable, the parent can have + // only one instance of the inst. The order among + // siblings should remain because the later siblings may + // have dependency to the earlier siblings. + // + if (value->parent) + { + SLANG_ASSERT(getIROpInfo(value->getOp()).isHoistable()); + return; + } + + // Try to find a suitable parent for the + // global value we are emitting. + // + // We will start out search at the current + // parent instruction for the builder, and + // possibly work our way up. + // + auto defaultInsertLoc = builder->getInsertLoc(); + auto defaultParent = defaultInsertLoc.getParent(); + auto parent = defaultParent; + while (parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as<IRModuleInst>(parent)) + break; + + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as<IRBlock>(parent)) + { + if (as<IRGeneric>(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + // If we somehow ran out of parents (possibly + // because an instruction wasn't linked into + // the full hierarchy yet), then we will + // fall back to inserting into the overall module. + if (!parent) + { + parent = builder->getModule()->getModuleInst(); + } + + // If it turns out that we are inserting into the + // current "insert into" parent for the builder, then + // we need to respect its "insert before" setting + // as well. + if (parent == defaultParent) + { + value->insertAt(defaultInsertLoc); + } + else + { + value->insertAtEnd(parent); + } +} + IRInst* IRBuilder::_findOrEmitHoistableInst( IRType* type, IROp op, @@ -2613,7 +2681,16 @@ IRInst* IRBuilder::_findOrEmitHoistableInst( } } - addHoistableInst(this, inst); + // When an hoistable inst is already a child, skip adding it. + if (inst->parent == nullptr) + { + // In order to de-duplicate them, Witness-table is marked as Hoistable. + // But it is not exactly a hoistable type and it should be added as a global value. + if (inst->getOp() == kIROp_WitnessTable) + addGlobalValue(this, inst); + else + addHoistableInst(this, inst); + } return inst; } @@ -4581,60 +4658,6 @@ IRDominatorTree* IRModule::findOrCreateDominatorTree(IRGlobalValueWithCode* func return analysis->getDominatorTree(); } -void addGlobalValue(IRBuilder* builder, IRInst* value) -{ - // Try to find a suitable parent for the - // global value we are emitting. - // - // We will start out search at the current - // parent instruction for the builder, and - // possibly work our way up. - // - auto defaultInsertLoc = builder->getInsertLoc(); - auto defaultParent = defaultInsertLoc.getParent(); - auto parent = defaultParent; - while (parent) - { - // Inserting into the top level of a module? - // That is fine, and we can stop searching. - if (as<IRModuleInst>(parent)) - break; - - // Inserting into a basic block inside of - // a generic? That is okay too. - if (auto block = as<IRBlock>(parent)) - { - if (as<IRGeneric>(block->parent)) - break; - } - - // Otherwise, move up the chain. - parent = parent->parent; - } - - // If we somehow ran out of parents (possibly - // because an instruction wasn't linked into - // the full hierarchy yet), then we will - // fall back to inserting into the overall module. - if (!parent) - { - parent = builder->getModule()->getModuleInst(); - } - - // If it turns out that we are inserting into the - // current "insert into" parent for the builder, then - // we need to respect its "insert before" setting - // as well. - if (parent == defaultParent) - { - value->insertAt(defaultInsertLoc); - } - else - { - value->insertAtEnd(parent); - } -} - IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target) { return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration); @@ -7985,7 +8008,11 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) auto user = uu->getUser(); bool userIsHoistable = getIROpInfo(user->getOp()).isHoistable(); - if (userIsHoistable) + + // We want to de-duplicate WitnessTable but we don't really want to hoist them. + bool userNeedToBeHoisted = userIsHoistable && (user->getOp() != kIROp_WitnessTable); + + if (userNeedToBeHoisted) { if (!dedupContext) { @@ -8002,7 +8029,7 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) // to a point before `user`, if it is not already so. _maybeHoistOperand(uu); - if (userIsHoistable) + if (userNeedToBeHoisted) { // Is the updated inst already exists in the global numbering map? // If so, we need to continue work on replacing the updated inst with the existing diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 64125be9a..dbc66c6a3 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -690,6 +690,7 @@ struct IRInst m_decorationsAndChildren.last); } void removeAndDeallocateAllDecorationsAndChildren(); + bool hasDecorationOrChild() { return m_decorationsAndChildren.first != nullptr; } #ifdef SLANG_ENABLE_IR_BREAK_ALLOC // Unique allocation ID for this instruction since start of current process. diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 3395224ec..e3c4ddf05 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8042,30 +8042,40 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Need to construct a sub-witness-table auto irWitnessTableBaseType = lowerType(subContext, astReqWitnessTable->baseType); - irSatisfyingWitnessTable = subBuilder->createWitnessTable( - irWitnessTableBaseType, - irWitnessTable->getConcreteType()); - auto mangledName = getMangledNameForConformanceWitness( - subContext->astBuilder, - astReqWitnessTable->witnessedType, - astReqWitnessTable->baseType); - subBuilder->addExportDecoration( - irSatisfyingWitnessTable, - mangledName.getUnownedSlice()); - if (isExportedType(astReqWitnessTable->witnessedType)) + + auto concreteType = irWitnessTable->getConcreteType(); + + irSatisfyingWitnessTable = + subBuilder->createWitnessTable(irWitnessTableBaseType, concreteType); + + // Avoid adding same decorations and child more than once. + if (!irSatisfyingWitnessTable->hasDecorationOrChild()) { - subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable); - subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable); - } + auto mangledName = getMangledNameForConformanceWitness( + subContext->astBuilder, + astReqWitnessTable->witnessedType, + astReqWitnessTable->baseType, + concreteType->getOp()); - // Recursively lower the sub-table. - lowerWitnessTable( - subContext, - astReqWitnessTable, - irSatisfyingWitnessTable, - mapASTToIRWitnessTable); + subBuilder->addExportDecoration( + irSatisfyingWitnessTable, + mangledName.getUnownedSlice()); - irSatisfyingWitnessTable->moveToEnd(); + if (isExportedType(astReqWitnessTable->witnessedType)) + { + subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable); + subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable); + } + + // Recursively lower the sub-table. + lowerWitnessTable( + subContext, + astReqWitnessTable, + irSatisfyingWitnessTable, + mapASTToIRWitnessTable); + + irSatisfyingWitnessTable->moveToEnd(); + } } irSatisfyingVal = irSatisfyingWitnessTable; } @@ -8148,14 +8158,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - // Construct the mangled name for the witness table, which depends - // on the type that is conforming, and the type that it conforms to. - // - // TODO: This approach doesn't really make sense for generic `extension` conformances. - auto mangledName = - getMangledNameForConformanceWitness(context->astBuilder, subType, superType); - - // A witness table may need to be generic, if the outer // declaration (either a type declaration or an `extension`) // is generic. @@ -8174,59 +8176,81 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // auto irWitnessTableBaseType = lowerType(subContext, superType); - // Create the IR-level witness table - auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr); - - // Register the value now, rather than later, to avoid any possible infinite recursion. + // Register a dummy value to avoid infinite recursions. + // Without this, the call to lowerType() can get into an infinite recursion. + // context->setGlobalValue( inheritanceDecl, - LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); + LoweredValInfo::simple(findOuterMostGeneric(subBuilder->getInsertLoc().getParent()))); auto irSubType = lowerType(subContext, subType); - irWitnessTable->setConcreteType(irSubType); - // TODO(JS): - // Should the mangled name take part in obfuscation if enabled? + // Create the IR-level witness table + auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType); - addLinkageDecoration( - context, - irWitnessTable, + // Override with the correct witness-table + context->setGlobalValue( inheritanceDecl, - mangledName.getUnownedSlice()); + LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); - // If the witness table is for a COM interface, always keep it alive. - if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>()) + // Avoid adding same decorations and child more than once. + if (!irWitnessTable->hasDecorationOrChild()) { - subBuilder->addHLSLExportDecoration(irWitnessTable); - } + // Construct the mangled name for the witness table, which depends + // on the type that is conforming, and the type that it conforms to. + // + // TODO: This approach doesn't really make sense for generic `extension` + // conformances. + auto mangledName = getMangledNameForConformanceWitness( + context->astBuilder, + subType, + superType, + irSubType->getOp()); - for (auto mod : parentDecl->modifiers) - { - if (as<HLSLExportModifier>(mod)) + // TODO(JS): + // Should the mangled name take part in obfuscation if enabled? + + addLinkageDecoration( + context, + irWitnessTable, + inheritanceDecl, + mangledName.getUnownedSlice()); + + // If the witness table is for a COM interface, always keep it alive. + if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>()) { subBuilder->addHLSLExportDecoration(irWitnessTable); - subBuilder->addKeepAliveDecoration(irWitnessTable); } - else if (as<AutoDiffBuiltinAttribute>(mod)) + + for (auto mod : parentDecl->modifiers) { - subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable); + if (as<HLSLExportModifier>(mod)) + { + subBuilder->addHLSLExportDecoration(irWitnessTable); + subBuilder->addKeepAliveDecoration(irWitnessTable); + } + else if (as<AutoDiffBuiltinAttribute>(mod)) + { + subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable); + } } - } - // Make sure that all the entries in the witness table have been filled in, - // including any cases where there are sub-witness-tables for conformances - bool isExplicitExtern = false; - bool isImported = isImportedDecl(context, parentDecl, isExplicitExtern); - if (!isImported || isExplicitExtern) - { - Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable; - lowerWitnessTable( - subContext, - inheritanceDecl->witnessTable, - irWitnessTable, - mapASTToIRWitnessTable); + // Make sure that all the entries in the witness table have been filled in, + // including any cases where there are sub-witness-tables for conformances + bool isExplicitExtern = false; + bool isImported = isImportedDecl(context, parentDecl, isExplicitExtern); + if (!isImported || isExplicitExtern) + { + Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable; + lowerWitnessTable( + subContext, + inheritanceDecl->witnessTable, + irWitnessTable, + mapASTToIRWitnessTable); + } + + irWitnessTable->moveToEnd(); } - irWitnessTable->moveToEnd(); return LoweredValInfo::simple( finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric)); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index d51fafb6b..12e185c8b 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -824,6 +824,37 @@ String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Ty return context.sb.produceString(); } +// This function takes an additional parameter to get a simplified +// mangled name when the witness-table is for enum-type. +// +// In order to deduplicate the witness-tables, we need to apply a little different +// rule for the mangled name when the `superType` is `enum` type. +// All witness-table for enum types whose underlying type is same should get the same +// manged name. +// +// TODO: We should remove this function and have a new IR for enum-type. The "option 2" +// described on the issue 6364 is more proper and ideal solution for the issue. +// +String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Type* sup, IROp subOp) +{ + SLANG_AST_BUILDER_RAII(astBuilder); + + ManglingContext context(astBuilder); + emitRaw(&context, "_SW"); + + if (as<EnumTypeType>(sup)) + { + emitRaw(&context, getIROpInfo(subOp).name); + } + else + { + emitType(&context, sub); + } + + emitType(&context, sup); + return context.sb.produceString(); +} + String getMangledTypeName(ASTBuilder* astBuilder, Type* type) { SLANG_AST_BUILDER_RAII(astBuilder); diff --git a/source/slang/slang-mangle.h b/source/slang/slang-mangle.h index cfbe4fb25..cfdbe461b 100644 --- a/source/slang/slang-mangle.h +++ b/source/slang/slang-mangle.h @@ -19,6 +19,11 @@ String getHashedName(const UnownedStringSlice& mangledName); String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, Type* sub, Type* sup); String getMangledNameForConformanceWitness( ASTBuilder* astBuilder, + Type* sub, + Type* sup, + IROp subOp); +String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, DeclRef<Decl> sub, DeclRef<Decl> sup); String getMangledNameForConformanceWitness(ASTBuilder* astBuilder, DeclRef<Decl> sub, Type* sup); diff --git a/tests/autodiff/deduplicate-witness-table.slang b/tests/autodiff/deduplicate-witness-table.slang new file mode 100644 index 000000000..ea4c4e730 --- /dev/null +++ b/tests/autodiff/deduplicate-witness-table.slang @@ -0,0 +1,33 @@ +//TEST:SIMPLE(filecheck=CHK):-stage compute -entry computeMain -target hlsl + +//CHK: struct DiffPair_1 +//CHK-NOT: struct DiffPair_2 + +RWTexture2D<float> gOutputColor; + +struct ShadingFrame : IDifferentiable +{ + float3 T; +} + +[Differentiable] +float computeRay() +{ + float3 dir = 1.f; + return dot(dir, dir); +} + +[Differentiable] +float paramRay() +{ + DifferentialPair<float> dpDir = fwd_diff(computeRay)(); + return dpDir.p; +} + +[Shader("compute")] +[NumThreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + DifferentialPair<float> dpColor = fwd_diff(paramRay)(); + gOutputColor[0] = dpColor.p; +} |
