diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-compiler.cpp | 86 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 126 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 64 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 47 |
8 files changed, 358 insertions, 3 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 997f3fd51..5d0ae2f8d 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -296,6 +296,92 @@ namespace Slang return empty; } + TypeConformance::TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_subtypeWitness(witness) + , m_conformanceIdOverride(confomrmanceIdOverride) + { + addDepedencyFromWitness(witness); + m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); + } + + void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) + { + if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) + { + auto declModule = getModule(declaredWitness->declRef.getDecl()); + m_moduleDependency.addDependency(declModule); + m_pathDependency.addDependency(declModule); + if (m_requirementSet.Add(declModule)) + { + m_requirements.add(declModule); + } + // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. + } + else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) + { + addDepedencyFromWitness(transitiveWitness->midToSup); + addDepedencyFromWitness(transitiveWitness->subToMid); + } + else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) + { + auto left = as<SubtypeWitness>(conjunctionWitness->leftWitness); + if (left) + addDepedencyFromWitness(left); + auto right = as<SubtypeWitness>(conjunctionWitness->rightWitness); + if (right) + addDepedencyFromWitness(right); + } + } + + ISlangUnknown* TypeConformance::getInterface(const Guid& guid) + { + if (guid == slang::ITypeConformance::getTypeGuid()) + return static_cast<slang::ITypeConformance*>(this); + + return Super::getInterface(guid); + } + + List<Module*> const& TypeConformance::getModuleDependencies() + { + return m_moduleDependency.getModuleList(); + } + + List<String> const& TypeConformance::getFilePathDependencies() + { + return m_pathDependency.getFilePathList(); + } + + Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); } + + RefPtr<ComponentType> TypeConformance::getRequirement(Index index) + { + return m_requirements[index]; + } + + void TypeConformance::acceptVisitor( + ComponentTypeVisitor* visitor, + ComponentType::SpecializationInfo* specializationInfo) + { + SLANG_UNUSED(specializationInfo); + visitor->visitTypeConformance(this); + } + + RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) + { + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; + } + // Profile Profile::lookUp(UnownedStringSlice const& name) diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 1724794cf..727d21c3a 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -850,6 +850,125 @@ namespace Slang Profile m_profile; }; + class TypeConformance + : public ComponentType + , public slang::ITypeConformance + { + typedef ComponentType Super; + + public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + ISlangUnknown* getInterface(const Guid& guid); + + TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink); + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + } + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; + List<String> const& getFilePathDependencies() SLANG_OVERRIDE; + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + + /// Get the existential type parameter at `index`. + SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE + { + static SpecializationParam emptyParam; + return emptyParam; + } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + Index getEntryPointCount() SLANG_OVERRIDE { return 0; }; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return nullptr; + } + String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } + + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return ShaderParamInfo(); + } + + SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; } + IRModule* getIRModule() { return m_irModule.Ptr(); } + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + private: + SubtypeWitness* m_subtypeWitness; + ModuleDependencyList m_moduleDependency; + FilePathDependencyList m_pathDependency; + List<RefPtr<Module>> m_requirements; + HashSet<Module*> m_requirementSet; + RefPtr<IRModule> m_irModule; + Int m_conformanceIdOverride; + void addDepedencyFromWitness(SubtypeWitness* witness); + }; + enum class PassThroughMode : SlangPassThroughIntegral { None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler @@ -1319,6 +1438,12 @@ namespace Slang slang::TypeReflection* type, slang::TypeReflection* interfaceType, uint32_t* outId) override; + SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformance, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) override; SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest( SlangCompileRequest** outCompileRequest) override; @@ -1756,6 +1881,7 @@ namespace Slang virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0; virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; + virtual void visitTypeConformance(TypeConformance* conformance) = 0; protected: // These helpers can be used to recurse into the logical children of a diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 480fe504c..36985efd6 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -274,7 +274,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) default: { - // In the deafult case, assume that we have some sort of "hoistable" + // In the default case, assume that we have some sort of "hoistable" // instruction that requires us to create a clone of it. UInt argCount = originalValue->getOperandCount(); @@ -439,6 +439,8 @@ static void cloneExtraDecorations( case kIROp_BindExistentialSlotsDecoration: case kIROp_LayoutDecoration: + case kIROp_PublicDecoration: + case kIROp_SequentialIDDecoration: if(!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index e9300ae06..9522dbeca 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -2158,6 +2158,11 @@ IRInst* specializeGenericImpl( if( auto returnValInst = as<IRReturnVal>(ii) ) { auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); + + // Clone decorations on the orignal `specialize` inst over to the newly specialized + // value. + cloneInstDecorationsAndChildren( + &env, &sharedBuilderStorage, specializeInst, specializedVal); return specializedVal; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index c7e32072e..8a40dbde9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8255,6 +8255,11 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor { visitChildren(specialized); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; RefPtr<IRModule> generateIRForSpecializedComponentType( @@ -8265,6 +8270,65 @@ RefPtr<IRModule> generateIRForSpecializedComponentType( return context.process(componentType, sink); } + /// Context for generating IR code to represent a `TypeConformance` +struct TypeConformanceIRGenContext +{ + DiagnosticSink* sink; + Linkage* linkage; + Session* session; + IRGenContext* context; + IRBuilder* builder; + + RefPtr<IRModule> process( + TypeConformance* typeConformance, + Int conformanceIdOverride, + DiagnosticSink* inSink) + { + sink = inSink; + + linkage = typeConformance->getLinkage(); + session = linkage->getSessionImpl(); + + SharedIRGenContext sharedContextStorage(session, sink, linkage->m_obfuscateCode); + SharedIRGenContext* sharedContext = &sharedContextStorage; + + IRGenContext contextStorage(sharedContext, linkage->getASTBuilder()); + context = &contextStorage; + + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder builderStorage; + builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + RefPtr<IRModule> module = builder->createModule(); + sharedBuilder->module = module; + + builder->setInsertInto(module->getModuleInst()); + + context->irBuilder = builder; + + auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness()); + builder->addPublicDecoration(witness); + if (conformanceIdOverride != -1) + { + builder->addSequentialIDDecoration(witness, conformanceIdOverride); + } + return module; + } +}; + +RefPtr<IRModule> generateIRForTypeConformance( + TypeConformance* typeConformance, + Int conformanceIdOverride, + DiagnosticSink* sink) +{ + TypeConformanceIRGenContext context; + return context.process(typeConformance, conformanceIdOverride, sink); +} RefPtr<IRModule> TargetProgram::getOrCreateIRModuleForLayout(DiagnosticSink* sink) { diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h index dbf2e550a..ce7f9eaf0 100644 --- a/source/slang/slang-lower-to-ir.h +++ b/source/slang/slang-lower-to-ir.h @@ -43,5 +43,14 @@ namespace Slang RefPtr<IRModule> generateIRForSpecializedComponentType( SpecializedComponentType* componentType, DiagnosticSink* sink); -} + + /// Generate an IR module to represent a user specified `TypeConformance` component type. + /// The generated IR will include an extern symbol representing the type conformance + /// (typically a `IRWitnessTable` or a `specialize(IRWitnessTable)` inst), with a `public` + /// decoration to keep the referenced witness table alive during linking. + RefPtr<IRModule> generateIRForTypeConformance( + TypeConformance* typeConformance, + Int conformanceIdOverride, + DiagnosticSink* sink); + } #endif diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 9b080a11c..6f9f127f8 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2735,6 +2735,11 @@ struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor SLANG_UNUSED(specializationInfo); } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { SLANG_UNUSED(module); @@ -2920,6 +2925,10 @@ struct CollectParametersVisitor : ComponentTypeVisitor } } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; /// Recursively collect the global shader parameters and entry points in `program`. @@ -3146,6 +3155,11 @@ struct CompleteBindingsVisitor : ComponentTypeVisitor auto base = specialized->getBaseComponentType(); _completeBindings(m_context, base, m_counters); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; /// A visitor used by `_completeBindings`. @@ -3272,6 +3286,10 @@ struct FlushPendingDataVisitor : ComponentTypeVisitor m_counters->entryPointCounter += specialized->getEntryPointCount(); } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; static void _completeBindings( diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index ef6872e76..e62d2c9ae 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1043,6 +1043,33 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent return SLANG_OK; } +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformanceComponentType, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) +{ + RefPtr<TypeConformance> result; + DiagnosticSink sink; + try + { + SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + auto witness = + visitor.tryGetSubtypeWitness((Slang::Type*)type, (Slang::Type*)interfaceType); + if (auto subtypeWitness = as<SubtypeWitness>(witness)) + { + result = new TypeConformance(this, subtypeWitness, conformanceIdOverride, &sink); + } + } + catch (...) + {} + sink.getBlobIfNeeded(outDiagnostics); + *outConformanceComponentType = result.detach(); + return result ? SLANG_OK : SLANG_FAIL; +} + SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest( SlangCompileRequest** outCompileRequest) { @@ -3041,6 +3068,11 @@ struct EnumerateModulesVisitor : ComponentTypeVisitor { visitChildren(specialized); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; @@ -3079,6 +3111,11 @@ struct EnumerateIRModulesVisitor : ComponentTypeVisitor m_callback(specialized->getIRModule(), m_userData); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + m_callback(conformance->getIRModule(), m_userData); + } }; void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) @@ -3401,6 +3438,11 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor { visitChildren(specialized); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; SpecializedComponentType::SpecializedComponentType( @@ -3616,7 +3658,10 @@ SpecializedComponentType::SpecializedComponentType( { visitChildren(composite, specializationInfo); } void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE { visitChildren(specialized); } - + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } EntryPointMangledNameCollector(ASTBuilder* astBuilder): m_astBuilder(astBuilder) { |
