summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-08 18:08:24 -0800
committerGitHub <noreply@github.com>2024-03-08 18:08:24 -0800
commit0629b22bf09ae6b3c3689c5f98492df7577bf0d2 (patch)
tree286eaf6268986b1ecb3cc19e8f3b72495e881d78
parent21502874666c282a3c5fa1f802deff27fab4e93b (diff)
Enhance link-time type test. (#3724)
* Enhance link-time type test. * Fix. * Fix.
-rw-r--r--source/slang/slang-check-decl.cpp24
-rw-r--r--source/slang/slang-emit.cpp5
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h10
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-ir-strip-witness-tables.cpp17
-rw-r--r--source/slang/slang-ir-strip-witness-tables.h5
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
-rw-r--r--tools/gfx-unit-test/link-time-type.cpp13
9 files changed, 69 insertions, 9 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7784100a6..8dee7b0c5 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -311,6 +311,8 @@ namespace Slang
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl);
+ SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl);
+
};
template<typename VisitorType>
@@ -3660,9 +3662,12 @@ namespace Slang
// the work of constructing our synthesized method.
//
+ bool isInWrapperType = isWrapperTypeDecl(context->parentDecl);
+
// First, we check that the differentiabliity of the method matches the requirement,
// and we don't attempt to synthesize a method if they don't match.
- if (getShared()->getFuncDifferentiableLevel(
+ if (!isInWrapperType &&
+ getShared()->getFuncDifferentiableLevel(
as<FunctionDeclBase>(lookupResult.item.declRef.getDecl()))
< getShared()->getFuncDifferentiableLevel(
as<FunctionDeclBase>(requiredMemberDeclRef.getDecl())))
@@ -3689,7 +3694,7 @@ namespace Slang
auto synBase = m_astBuilder->create<OverloadedExpr>();
synBase->name = requiredMemberDeclRef.getDecl()->getName();
- if (isWrapperTypeDecl(context->parentDecl))
+ if (isInWrapperType)
{
auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl);
synBase->lookupResult2 = lookUpMember(
@@ -3701,6 +3706,10 @@ namespace Slang
LookupMask::Default,
LookupOptions::IgnoreBaseInterfaces);
addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>());
+
+ synFuncDecl->parentDecl = aggTypeDecl;
+ SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
+ bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl);
}
else
{
@@ -3714,7 +3723,7 @@ namespace Slang
//
if (synThis)
{
- if (isWrapperTypeDecl(context->parentDecl))
+ if (isInWrapperType)
{
// If this is a wrapper type, then use the inner
// object as the actual this parameter for the redirected
@@ -3723,6 +3732,8 @@ namespace Slang
innerExpr->scope = synThis->scope;
innerExpr->name = getName("inner");
synBase->base = CheckExpr(innerExpr);
+ SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
+ bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type);
}
else
{
@@ -6066,7 +6077,7 @@ namespace Slang
checkVisibility(decl);
}
- void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
+ SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
if (newContext.getParentDifferentiableAttribute())
@@ -6086,7 +6097,12 @@ namespace Slang
}
m_parentDifferentiableAttr = oldAttr;
}
+ return newContext;
+ }
+ void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ auto newContext = registerDifferentiableTypesForFunc(decl);
if (const auto body = decl->body)
{
checkStmt(decl->body, newContext);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1d72fd233..5ec6fa62a 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -491,6 +491,11 @@ Result linkAndOptimizeIR(
validateIRModuleIfEnabled(codeGenContext, irModule);
+ // If we have any witness tables that are marked as `KeepAlive`,
+ // but are not used for dynamic dispatch, unpin them so we don't
+ // do unnecessary work to lower them.
+ unpinWitnessTables(irModule);
+
simplifyIR(targetProgram, irModule, IRSimplificationOptions::getFast(), sink);
if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc))
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 4865aa0b5..cd79f05a6 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -815,6 +815,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(AnyValueSizeDecoration, AnyValueSize, 1, 0)
INST(SpecializeDecoration, SpecializeDecoration, 0, 0)
INST(SequentialIDDecoration, SequentialIDDecoration, 1, 0)
+ INST(DynamicDispatchWitnessDecoration, DynamicDispatchWitnessDecoration, 0, 0)
INST(StaticRequirementDecoration, StaticRequirementDecoration, 0, 0)
INST(DispatchFuncDecoration, DispatchFuncDecoration, 1, 0)
INST(TypeConstraintDecoration, TypeConstraintDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 6fc94c657..6fbccab5c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -708,6 +708,11 @@ struct IRSequentialIDDecoration : IRDecoration
IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); }
};
+struct IRDynamicDispatchWitnessDecoration : IRDecoration
+{
+ IR_LEAF_ISA(DynamicDispatchWitnessDecoration)
+};
+
struct IRAutoDiffOriginalValueDecoration : IRDecoration
{
enum
@@ -4692,6 +4697,11 @@ public:
addDecoration(inst, kIROp_SequentialIDDecoration, getIntValue(getUIntType(), id));
}
+ void addDynamicDispatchWitnessDecoration(IRInst* inst)
+ {
+ addDecoration(inst, kIROp_DynamicDispatchWitnessDecoration);
+ }
+
void addVulkanRayPayloadDecoration(IRInst* inst, int location)
{
addDecoration(inst, kIROp_VulkanRayPayloadDecoration, getIntValue(getIntType(), location));
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index e81eddab7..18cb850c0 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -447,6 +447,7 @@ static void cloneExtraDecorationsFromInst(
case kIROp_PrimalSubstituteDecoration:
case kIROp_IntrinsicOpDecoration:
case kIROp_NonCopyableTypeDecoration:
+ case kIROp_DynamicDispatchWitnessDecoration:
if (!clonedInst->findDecorationImpl(decoration->getOp()))
{
cloneInst(context, builder, decoration);
diff --git a/source/slang/slang-ir-strip-witness-tables.cpp b/source/slang/slang-ir-strip-witness-tables.cpp
index 4c8901c52..b80bd7c23 100644
--- a/source/slang/slang-ir-strip-witness-tables.cpp
+++ b/source/slang/slang-ir-strip-witness-tables.cpp
@@ -33,4 +33,21 @@ void stripWitnessTables(IRModule* module)
}
}
+void unpinWitnessTables(IRModule* module)
+{
+ for (auto inst : module->getGlobalInsts())
+ {
+ auto witnessTable = as<IRWitnessTable>(inst);
+ if (!witnessTable)
+ continue;
+
+ // If a witness table is not used for dynamic dispatch, unpin it.
+ if (!witnessTable->findDecoration<IRDynamicDispatchWitnessDecoration>())
+ {
+ while (auto decor = witnessTable->findDecoration<IRKeepAliveDecoration>())
+ decor->removeAndDeallocate();
+ }
+ }
+}
+
}
diff --git a/source/slang/slang-ir-strip-witness-tables.h b/source/slang/slang-ir-strip-witness-tables.h
index 43bd0127d..4e3106418 100644
--- a/source/slang/slang-ir-strip-witness-tables.h
+++ b/source/slang/slang-ir-strip-witness-tables.h
@@ -7,4 +7,7 @@ struct IRModule;
/// Strip the contents of all witness table instructions from the given IR `module`
void stripWitnessTables(IRModule* module);
-} \ No newline at end of file
+
+ /// Remove [KeepAlive] decorations from witness tables.
+void unpinWitnessTables(IRModule* module);
+}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index c5a4da1f6..566e5a878 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -10872,7 +10872,7 @@ struct TypeConformanceIRGenContext
auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness());
builder->addKeepAliveDecoration(witness);
builder->addHLSLExportDecoration(witness);
-
+ builder->addDynamicDispatchWitnessDecoration(witness);
if (conformanceIdOverride != -1)
{
builder->addSequentialIDDecoration(witness, conformanceIdOverride);
diff --git a/tools/gfx-unit-test/link-time-type.cpp b/tools/gfx-unit-test/link-time-type.cpp
index a522b6903..32a6b6775 100644
--- a/tools/gfx-unit-test/link-time-type.cpp
+++ b/tools/gfx-unit-test/link-time-type.cpp
@@ -16,7 +16,12 @@ namespace gfx_test
slang::ProgramLayout*& slangReflection)
{
const char* moduleInterfaceSrc = R"(
- interface IFoo
+ interface IBase : IDifferentiable
+ {
+ [Differentiable]
+ float getBaseValue();
+ }
+ interface IFoo : IBase
{
static const int offset;
[mutating] void setValue(float v);
@@ -29,6 +34,8 @@ namespace gfx_test
static const int offset = -1;
[mutating] void setValue(float v) { val = v; }
float getValue() { return val + 1.0; }
+ [Differentiable]
+ float getBaseValue() { return val; }
property float val2 {
get { return val + 2.0; }
set { val = newValue; }
@@ -44,7 +51,7 @@ namespace gfx_test
{
Foo foo;
foo.setValue(3.0);
- buffer[0] = foo.getValue() + foo.val2 + Foo.offset;
+ buffer[0] = foo.getValue() + foo.val2 + Foo.offset + foo.getBaseValue();
}
)";
const char* module1Src = R"(
@@ -169,7 +176,7 @@ namespace gfx_test
compareComputeResult(
device,
numbersBuffer,
- Slang::makeArray<float>(8.0));
+ Slang::makeArray<float>(11.0));
}
SLANG_UNIT_TEST(linkTimeTypeD3D12)