diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-08 18:08:24 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-08 18:08:24 -0800 |
| commit | 0629b22bf09ae6b3c3689c5f98492df7577bf0d2 (patch) | |
| tree | 286eaf6268986b1ecb3cc19e8f3b72495e881d78 | |
| parent | 21502874666c282a3c5fa1f802deff27fab4e93b (diff) | |
Enhance link-time type test. (#3724)
* Enhance link-time type test.
* Fix.
* Fix.
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-strip-witness-tables.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-strip-witness-tables.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 2 | ||||
| -rw-r--r-- | tools/gfx-unit-test/link-time-type.cpp | 13 |
9 files changed, 69 insertions, 9 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7784100a6..8dee7b0c5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -311,6 +311,8 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); + }; template<typename VisitorType> @@ -3660,9 +3662,12 @@ namespace Slang // the work of constructing our synthesized method. // + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); + // First, we check that the differentiabliity of the method matches the requirement, // and we don't attempt to synthesize a method if they don't match. - if (getShared()->getFuncDifferentiableLevel( + if (!isInWrapperType && + getShared()->getFuncDifferentiableLevel( as<FunctionDeclBase>(lookupResult.item.declRef.getDecl())) < getShared()->getFuncDifferentiableLevel( as<FunctionDeclBase>(requiredMemberDeclRef.getDecl()))) @@ -3689,7 +3694,7 @@ namespace Slang auto synBase = m_astBuilder->create<OverloadedExpr>(); synBase->name = requiredMemberDeclRef.getDecl()->getName(); - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); synBase->lookupResult2 = lookUpMember( @@ -3701,6 +3706,10 @@ namespace Slang LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>()); + + synFuncDecl->parentDecl = aggTypeDecl; + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); } else { @@ -3714,7 +3723,7 @@ namespace Slang // if (synThis) { - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { // If this is a wrapper type, then use the inner // object as the actual this parameter for the redirected @@ -3723,6 +3732,8 @@ namespace Slang innerExpr->scope = synThis->scope; innerExpr->name = getName("inner"); synBase->base = CheckExpr(innerExpr); + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type); } else { @@ -6066,7 +6077,7 @@ namespace Slang checkVisibility(decl); } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); if (newContext.getParentDifferentiableAttribute()) @@ -6086,7 +6097,12 @@ namespace Slang } m_parentDifferentiableAttr = oldAttr; } + return newContext; + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + auto newContext = registerDifferentiableTypesForFunc(decl); if (const auto body = decl->body) { checkStmt(decl->body, newContext); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1d72fd233..5ec6fa62a 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -491,6 +491,11 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); + // If we have any witness tables that are marked as `KeepAlive`, + // but are not used for dynamic dispatch, unpin them so we don't + // do unnecessary work to lower them. + unpinWitnessTables(irModule); + simplifyIR(targetProgram, irModule, IRSimplificationOptions::getFast(), sink); if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc)) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4865aa0b5..cd79f05a6 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -815,6 +815,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(AnyValueSizeDecoration, AnyValueSize, 1, 0) INST(SpecializeDecoration, SpecializeDecoration, 0, 0) INST(SequentialIDDecoration, SequentialIDDecoration, 1, 0) + INST(DynamicDispatchWitnessDecoration, DynamicDispatchWitnessDecoration, 0, 0) INST(StaticRequirementDecoration, StaticRequirementDecoration, 0, 0) INST(DispatchFuncDecoration, DispatchFuncDecoration, 1, 0) INST(TypeConstraintDecoration, TypeConstraintDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 6fc94c657..6fbccab5c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -708,6 +708,11 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; +struct IRDynamicDispatchWitnessDecoration : IRDecoration +{ + IR_LEAF_ISA(DynamicDispatchWitnessDecoration) +}; + struct IRAutoDiffOriginalValueDecoration : IRDecoration { enum @@ -4692,6 +4697,11 @@ public: addDecoration(inst, kIROp_SequentialIDDecoration, getIntValue(getUIntType(), id)); } + void addDynamicDispatchWitnessDecoration(IRInst* inst) + { + addDecoration(inst, kIROp_DynamicDispatchWitnessDecoration); + } + void addVulkanRayPayloadDecoration(IRInst* inst, int location) { addDecoration(inst, kIROp_VulkanRayPayloadDecoration, getIntValue(getIntType(), location)); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index e81eddab7..18cb850c0 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -447,6 +447,7 @@ static void cloneExtraDecorationsFromInst( case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: case kIROp_NonCopyableTypeDecoration: + case kIROp_DynamicDispatchWitnessDecoration: if (!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-ir-strip-witness-tables.cpp b/source/slang/slang-ir-strip-witness-tables.cpp index 4c8901c52..b80bd7c23 100644 --- a/source/slang/slang-ir-strip-witness-tables.cpp +++ b/source/slang/slang-ir-strip-witness-tables.cpp @@ -33,4 +33,21 @@ void stripWitnessTables(IRModule* module) } } +void unpinWitnessTables(IRModule* module) +{ + for (auto inst : module->getGlobalInsts()) + { + auto witnessTable = as<IRWitnessTable>(inst); + if (!witnessTable) + continue; + + // If a witness table is not used for dynamic dispatch, unpin it. + if (!witnessTable->findDecoration<IRDynamicDispatchWitnessDecoration>()) + { + while (auto decor = witnessTable->findDecoration<IRKeepAliveDecoration>()) + decor->removeAndDeallocate(); + } + } +} + } diff --git a/source/slang/slang-ir-strip-witness-tables.h b/source/slang/slang-ir-strip-witness-tables.h index 43bd0127d..4e3106418 100644 --- a/source/slang/slang-ir-strip-witness-tables.h +++ b/source/slang/slang-ir-strip-witness-tables.h @@ -7,4 +7,7 @@ struct IRModule; /// Strip the contents of all witness table instructions from the given IR `module` void stripWitnessTables(IRModule* module); -}
\ No newline at end of file + + /// Remove [KeepAlive] decorations from witness tables. +void unpinWitnessTables(IRModule* module); +} diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index c5a4da1f6..566e5a878 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10872,7 +10872,7 @@ struct TypeConformanceIRGenContext auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness()); builder->addKeepAliveDecoration(witness); builder->addHLSLExportDecoration(witness); - + builder->addDynamicDispatchWitnessDecoration(witness); if (conformanceIdOverride != -1) { builder->addSequentialIDDecoration(witness, conformanceIdOverride); diff --git a/tools/gfx-unit-test/link-time-type.cpp b/tools/gfx-unit-test/link-time-type.cpp index a522b6903..32a6b6775 100644 --- a/tools/gfx-unit-test/link-time-type.cpp +++ b/tools/gfx-unit-test/link-time-type.cpp @@ -16,7 +16,12 @@ namespace gfx_test slang::ProgramLayout*& slangReflection) { const char* moduleInterfaceSrc = R"( - interface IFoo + interface IBase : IDifferentiable + { + [Differentiable] + float getBaseValue(); + } + interface IFoo : IBase { static const int offset; [mutating] void setValue(float v); @@ -29,6 +34,8 @@ namespace gfx_test static const int offset = -1; [mutating] void setValue(float v) { val = v; } float getValue() { return val + 1.0; } + [Differentiable] + float getBaseValue() { return val; } property float val2 { get { return val + 2.0; } set { val = newValue; } @@ -44,7 +51,7 @@ namespace gfx_test { Foo foo; foo.setValue(3.0); - buffer[0] = foo.getValue() + foo.val2 + Foo.offset; + buffer[0] = foo.getValue() + foo.val2 + Foo.offset + foo.getBaseValue(); } )"; const char* module1Src = R"( @@ -169,7 +176,7 @@ namespace gfx_test compareComputeResult( device, numbersBuffer, - Slang::makeArray<float>(8.0)); + Slang::makeArray<float>(11.0)); } SLANG_UNIT_TEST(linkTimeTypeD3D12) |
