summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-08-26 10:30:35 -0700
committerGitHub <noreply@github.com>2021-08-26 10:30:35 -0700
commitb2ad8e99a82884bb157e1be76b1ad7eb0e481457 (patch)
tree3f5357083b5972761d516b70cb51a4fa7ab72cd5 /source
parent33f7e1599cbecb32c23787b37b2bf3b34bdd5c84 (diff)
Add API to control interface specialization. (#1925)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-compiler.cpp86
-rwxr-xr-xsource/slang/slang-compiler.h126
-rw-r--r--source/slang/slang-ir-link.cpp4
-rw-r--r--source/slang/slang-ir-specialize.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp64
-rw-r--r--source/slang/slang-lower-to-ir.h11
-rw-r--r--source/slang/slang-parameter-binding.cpp18
-rw-r--r--source/slang/slang.cpp47
8 files changed, 358 insertions, 3 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 997f3fd51..5d0ae2f8d 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -296,6 +296,92 @@ namespace Slang
return empty;
}
+ TypeConformance::TypeConformance(
+ Linkage* linkage,
+ SubtypeWitness* witness,
+ Int confomrmanceIdOverride,
+ DiagnosticSink* sink)
+ : ComponentType(linkage)
+ , m_subtypeWitness(witness)
+ , m_conformanceIdOverride(confomrmanceIdOverride)
+ {
+ addDepedencyFromWitness(witness);
+ m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink);
+ }
+
+ void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness)
+ {
+ if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness))
+ {
+ auto declModule = getModule(declaredWitness->declRef.getDecl());
+ m_moduleDependency.addDependency(declModule);
+ m_pathDependency.addDependency(declModule);
+ if (m_requirementSet.Add(declModule))
+ {
+ m_requirements.add(declModule);
+ }
+ // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions.
+ }
+ else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness))
+ {
+ addDepedencyFromWitness(transitiveWitness->midToSup);
+ addDepedencyFromWitness(transitiveWitness->subToMid);
+ }
+ else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness))
+ {
+ auto left = as<SubtypeWitness>(conjunctionWitness->leftWitness);
+ if (left)
+ addDepedencyFromWitness(left);
+ auto right = as<SubtypeWitness>(conjunctionWitness->rightWitness);
+ if (right)
+ addDepedencyFromWitness(right);
+ }
+ }
+
+ ISlangUnknown* TypeConformance::getInterface(const Guid& guid)
+ {
+ if (guid == slang::ITypeConformance::getTypeGuid())
+ return static_cast<slang::ITypeConformance*>(this);
+
+ return Super::getInterface(guid);
+ }
+
+ List<Module*> const& TypeConformance::getModuleDependencies()
+ {
+ return m_moduleDependency.getModuleList();
+ }
+
+ List<String> const& TypeConformance::getFilePathDependencies()
+ {
+ return m_pathDependency.getFilePathList();
+ }
+
+ Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); }
+
+ RefPtr<ComponentType> TypeConformance::getRequirement(Index index)
+ {
+ return m_requirements[index];
+ }
+
+ void TypeConformance::acceptVisitor(
+ ComponentTypeVisitor* visitor,
+ ComponentType::SpecializationInfo* specializationInfo)
+ {
+ SLANG_UNUSED(specializationInfo);
+ visitor->visitTypeConformance(this);
+ }
+
+ RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
+ {
+ SLANG_UNUSED(args);
+ SLANG_UNUSED(argCount);
+ SLANG_UNUSED(sink);
+ return nullptr;
+ }
+
//
Profile Profile::lookUp(UnownedStringSlice const& name)
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 1724794cf..727d21c3a 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -850,6 +850,125 @@ namespace Slang
Profile m_profile;
};
+ class TypeConformance
+ : public ComponentType
+ , public slang::ITypeConformance
+ {
+ typedef ComponentType Super;
+
+ public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+
+ ISlangUnknown* getInterface(const Guid& guid);
+
+ TypeConformance(
+ Linkage* linkage,
+ SubtypeWitness* witness,
+ Int confomrmanceIdOverride,
+ DiagnosticSink* sink);
+
+ // Forward `IComponentType` methods
+
+ SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE
+ {
+ return Super::getSession();
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL
+ getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getLayout(targetIndex, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::specialize(
+ specializationArgs,
+ specializationArgCount,
+ outSpecializedComponentType,
+ outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::link(outLinkedComponentType, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointHostCallable(
+ entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ }
+
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE;
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE;
+
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
+
+ /// Get the existential type parameter at `index`.
+ SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE
+ {
+ static SpecializationParam emptyParam;
+ return emptyParam;
+ }
+
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
+ Index getEntryPointCount() SLANG_OVERRIDE { return 0; };
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(index);
+ return nullptr;
+ }
+ String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; }
+
+ Index getShaderParamCount() SLANG_OVERRIDE { return 0; }
+ ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(index);
+ return ShaderParamInfo();
+ }
+
+ SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; }
+ IRModule* getIRModule() { return m_irModule.Ptr(); }
+ protected:
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+ SLANG_OVERRIDE;
+
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE;
+ private:
+ SubtypeWitness* m_subtypeWitness;
+ ModuleDependencyList m_moduleDependency;
+ FilePathDependencyList m_pathDependency;
+ List<RefPtr<Module>> m_requirements;
+ HashSet<Module*> m_requirementSet;
+ RefPtr<IRModule> m_irModule;
+ Int m_conformanceIdOverride;
+ void addDepedencyFromWitness(SubtypeWitness* witness);
+ };
+
enum class PassThroughMode : SlangPassThroughIntegral
{
None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler
@@ -1319,6 +1438,12 @@ namespace Slang
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
uint32_t* outId) override;
+ SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ slang::ITypeConformance** outConformance,
+ SlangInt conformanceIdOverride,
+ ISlangBlob** outDiagnostics) override;
SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
SlangCompileRequest** outCompileRequest) override;
@@ -1756,6 +1881,7 @@ namespace Slang
virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0;
virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0;
virtual void visitSpecialized(SpecializedComponentType* specialized) = 0;
+ virtual void visitTypeConformance(TypeConformance* conformance) = 0;
protected:
// These helpers can be used to recurse into the logical children of a
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 480fe504c..36985efd6 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -274,7 +274,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
default:
{
- // In the deafult case, assume that we have some sort of "hoistable"
+ // In the default case, assume that we have some sort of "hoistable"
// instruction that requires us to create a clone of it.
UInt argCount = originalValue->getOperandCount();
@@ -439,6 +439,8 @@ static void cloneExtraDecorations(
case kIROp_BindExistentialSlotsDecoration:
case kIROp_LayoutDecoration:
+ case kIROp_PublicDecoration:
+ case kIROp_SequentialIDDecoration:
if(!clonedInst->findDecorationImpl(decoration->getOp()))
{
cloneInst(context, builder, decoration);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index e9300ae06..9522dbeca 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -2158,6 +2158,11 @@ IRInst* specializeGenericImpl(
if( auto returnValInst = as<IRReturnVal>(ii) )
{
auto specializedVal = findCloneForOperand(&env, returnValInst->getVal());
+
+ // Clone decorations on the orignal `specialize` inst over to the newly specialized
+ // value.
+ cloneInstDecorationsAndChildren(
+ &env, &sharedBuilderStorage, specializeInst, specializedVal);
return specializedVal;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index c7e32072e..8a40dbde9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8255,6 +8255,11 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor
{
visitChildren(specialized);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
RefPtr<IRModule> generateIRForSpecializedComponentType(
@@ -8265,6 +8270,65 @@ RefPtr<IRModule> generateIRForSpecializedComponentType(
return context.process(componentType, sink);
}
+ /// Context for generating IR code to represent a `TypeConformance`
+struct TypeConformanceIRGenContext
+{
+ DiagnosticSink* sink;
+ Linkage* linkage;
+ Session* session;
+ IRGenContext* context;
+ IRBuilder* builder;
+
+ RefPtr<IRModule> process(
+ TypeConformance* typeConformance,
+ Int conformanceIdOverride,
+ DiagnosticSink* inSink)
+ {
+ sink = inSink;
+
+ linkage = typeConformance->getLinkage();
+ session = linkage->getSessionImpl();
+
+ SharedIRGenContext sharedContextStorage(session, sink, linkage->m_obfuscateCode);
+ SharedIRGenContext* sharedContext = &sharedContextStorage;
+
+ IRGenContext contextStorage(sharedContext, linkage->getASTBuilder());
+ context = &contextStorage;
+
+ SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = session;
+
+ IRBuilder builderStorage;
+ builder = &builderStorage;
+ builder->sharedBuilder = sharedBuilder;
+
+ RefPtr<IRModule> module = builder->createModule();
+ sharedBuilder->module = module;
+
+ builder->setInsertInto(module->getModuleInst());
+
+ context->irBuilder = builder;
+
+ auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness());
+ builder->addPublicDecoration(witness);
+ if (conformanceIdOverride != -1)
+ {
+ builder->addSequentialIDDecoration(witness, conformanceIdOverride);
+ }
+ return module;
+ }
+};
+
+RefPtr<IRModule> generateIRForTypeConformance(
+ TypeConformance* typeConformance,
+ Int conformanceIdOverride,
+ DiagnosticSink* sink)
+{
+ TypeConformanceIRGenContext context;
+ return context.process(typeConformance, conformanceIdOverride, sink);
+}
RefPtr<IRModule> TargetProgram::getOrCreateIRModuleForLayout(DiagnosticSink* sink)
{
diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h
index dbf2e550a..ce7f9eaf0 100644
--- a/source/slang/slang-lower-to-ir.h
+++ b/source/slang/slang-lower-to-ir.h
@@ -43,5 +43,14 @@ namespace Slang
RefPtr<IRModule> generateIRForSpecializedComponentType(
SpecializedComponentType* componentType,
DiagnosticSink* sink);
-}
+
+ /// Generate an IR module to represent a user specified `TypeConformance` component type.
+ /// The generated IR will include an extern symbol representing the type conformance
+ /// (typically a `IRWitnessTable` or a `specialize(IRWitnessTable)` inst), with a `public`
+ /// decoration to keep the referenced witness table alive during linking.
+ RefPtr<IRModule> generateIRForTypeConformance(
+ TypeConformance* typeConformance,
+ Int conformanceIdOverride,
+ DiagnosticSink* sink);
+ }
#endif
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 9b080a11c..6f9f127f8 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2735,6 +2735,11 @@ struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor
SLANG_UNUSED(specializationInfo);
}
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
SLANG_UNUSED(module);
@@ -2920,6 +2925,10 @@ struct CollectParametersVisitor : ComponentTypeVisitor
}
}
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
/// Recursively collect the global shader parameters and entry points in `program`.
@@ -3146,6 +3155,11 @@ struct CompleteBindingsVisitor : ComponentTypeVisitor
auto base = specialized->getBaseComponentType();
_completeBindings(m_context, base, m_counters);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
/// A visitor used by `_completeBindings`.
@@ -3272,6 +3286,10 @@ struct FlushPendingDataVisitor : ComponentTypeVisitor
m_counters->entryPointCounter += specialized->getEntryPointCount();
}
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
static void _completeBindings(
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index ef6872e76..e62d2c9ae 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1043,6 +1043,33 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent
return SLANG_OK;
}
+SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ slang::ITypeConformance** outConformanceComponentType,
+ SlangInt conformanceIdOverride,
+ ISlangBlob** outDiagnostics)
+{
+ RefPtr<TypeConformance> result;
+ DiagnosticSink sink;
+ try
+ {
+ SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink);
+ SemanticsVisitor visitor(&sharedSemanticsContext);
+ auto witness =
+ visitor.tryGetSubtypeWitness((Slang::Type*)type, (Slang::Type*)interfaceType);
+ if (auto subtypeWitness = as<SubtypeWitness>(witness))
+ {
+ result = new TypeConformance(this, subtypeWitness, conformanceIdOverride, &sink);
+ }
+ }
+ catch (...)
+ {}
+ sink.getBlobIfNeeded(outDiagnostics);
+ *outConformanceComponentType = result.detach();
+ return result ? SLANG_OK : SLANG_FAIL;
+}
+
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest(
SlangCompileRequest** outCompileRequest)
{
@@ -3041,6 +3068,11 @@ struct EnumerateModulesVisitor : ComponentTypeVisitor
{
visitChildren(specialized);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
@@ -3079,6 +3111,11 @@ struct EnumerateIRModulesVisitor : ComponentTypeVisitor
m_callback(specialized->getIRModule(), m_userData);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ m_callback(conformance->getIRModule(), m_userData);
+ }
};
void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData)
@@ -3401,6 +3438,11 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
{
visitChildren(specialized);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
SpecializedComponentType::SpecializedComponentType(
@@ -3616,7 +3658,10 @@ SpecializedComponentType::SpecializedComponentType(
{ visitChildren(composite, specializationInfo); }
void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
{ visitChildren(specialized); }
-
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
EntryPointMangledNameCollector(ASTBuilder* astBuilder):
m_astBuilder(astBuilder)
{