diff options
| author | Yong He <yonghe@outlook.com> | 2024-09-30 12:50:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-30 12:50:30 -0700 |
| commit | 15d1c6c51c5f24663d2567d7e56da62a2bca1c22 (patch) | |
| tree | 78733e327dca1421f39ecff4073463d74500c14c /source | |
| parent | bc11579dd998224bcb429d88aeb07d49e2217a35 (diff) | |
Add COM API for querying metadata. (#5168)
* Add COM API for querying metadata.
* Fix tests.
* fix test.
Diffstat (limited to 'source')
| -rw-r--r-- | source/compiler-core/slang-artifact-associated-impl.cpp | 20 | ||||
| -rw-r--r-- | source/compiler-core/slang-artifact-associated-impl.h | 8 | ||||
| -rw-r--r-- | source/compiler-core/slang-artifact-associated.h | 2 | ||||
| -rw-r--r-- | source/slang-record-replay/record/slang-component-type.cpp | 19 | ||||
| -rw-r--r-- | source/slang-record-replay/record/slang-component-type.h | 9 | ||||
| -rw-r--r-- | source/slang-record-replay/record/slang-entrypoint.h | 17 | ||||
| -rw-r--r-- | source/slang-record-replay/record/slang-module.h | 17 | ||||
| -rw-r--r-- | source/slang-record-replay/record/slang-type-conformance.h | 17 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 65 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 97 |
11 files changed, 255 insertions, 27 deletions
diff --git a/source/compiler-core/slang-artifact-associated-impl.cpp b/source/compiler-core/slang-artifact-associated-impl.cpp index f29c0f596..07d69bf8e 100644 --- a/source/compiler-core/slang-artifact-associated-impl.cpp +++ b/source/compiler-core/slang-artifact-associated-impl.cpp @@ -313,4 +313,24 @@ Slice<String> ArtifactPostEmitMetadata::getExportedFunctionMangledNames() return Slice<String>(m_exportedFunctionMangledNames.getBuffer(), m_exportedFunctionMangledNames.getCount()); } +SlangResult ArtifactPostEmitMetadata::isParameterLocationUsed( + SlangParameterCategory category, + SlangUInt spaceIndex, + SlangUInt registerIndex, + bool& outUsed) +{ + for (const auto& range : getUsedBindingRanges()) + { + if (range.containsBinding((slang::ParameterCategory)category, spaceIndex, registerIndex)) + { + outUsed = true; + return SLANG_OK; + } + } + + outUsed = false; + return SLANG_OK; +} + + } // namespace Slang diff --git a/source/compiler-core/slang-artifact-associated-impl.h b/source/compiler-core/slang-artifact-associated-impl.h index a6e323b0a..d11498604 100644 --- a/source/compiler-core/slang-artifact-associated-impl.h +++ b/source/compiler-core/slang-artifact-associated-impl.h @@ -134,6 +134,7 @@ struct ShaderBindingRange case slang::ShaderResource: case slang::UnorderedAccess: case slang::SamplerState: + case slang::DescriptorTableSlot: return true; default: return false; @@ -157,6 +158,13 @@ public: SLANG_NO_THROW virtual Slice<ShaderBindingRange> SLANG_MCALL getUsedBindingRanges() SLANG_OVERRIDE; SLANG_NO_THROW virtual Slice<String> SLANG_MCALL getExportedFunctionMangledNames() SLANG_OVERRIDE; + // IMetadata + SLANG_NO_THROW virtual SlangResult SLANG_MCALL isParameterLocationUsed( + SlangParameterCategory category, // is this a `t` register? `s` register? + SlangUInt spaceIndex, // `space` for D3D12, `set` for Vulkan + SlangUInt registerIndex, // `register` for D3D12, `binding` for Vulkan + bool& outUsed) SLANG_OVERRIDE; + void* getInterface(const Guid& uuid); void* getObject(const Guid& uuid); diff --git a/source/compiler-core/slang-artifact-associated.h b/source/compiler-core/slang-artifact-associated.h index 766494271..91ae09aab 100644 --- a/source/compiler-core/slang-artifact-associated.h +++ b/source/compiler-core/slang-artifact-associated.h @@ -117,7 +117,7 @@ public: struct ShaderBindingRange; -class IArtifactPostEmitMetadata : public ICastable +class IArtifactPostEmitMetadata : public slang::IMetadata { public: SLANG_COM_INTERFACE(0x5d03bce9, 0xafb1, 0x4fc8, { 0xa4, 0x6f, 0x3c, 0xe0, 0x7b, 0x6, 0x1b, 0x1b }); diff --git a/source/slang-record-replay/record/slang-component-type.cpp b/source/slang-record-replay/record/slang-component-type.cpp index 5ebf73574..bc38acfae 100644 --- a/source/slang-record-replay/record/slang-component-type.cpp +++ b/source/slang-record-replay/record/slang-component-type.cpp @@ -127,6 +127,25 @@ namespace SlangRecord return res; } + SLANG_NO_THROW SlangResult SLANG_MCALL IComponentTypeRecorder::getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) + { + // No need to record this call. + return m_actualComponentType->getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL IComponentTypeRecorder::getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) + { + // No need to record this call. + return m_actualComponentType->getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + SLANG_NO_THROW SlangResult IComponentTypeRecorder::getResultAsFileSystem( SlangInt entryPointIndex, SlangInt targetIndex, diff --git a/source/slang-record-replay/record/slang-component-type.h b/source/slang-record-replay/record/slang-component-type.h index 110733044..52e07d9c0 100644 --- a/source/slang-record-replay/record/slang-component-type.h +++ b/source/slang-record-replay/record/slang-component-type.h @@ -61,6 +61,15 @@ namespace SlangRecord SlangInt targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics = nullptr) override; protected: virtual ApiClassId getClassId() = 0; virtual SessionRecorder* getSessionRecorder() = 0; diff --git a/source/slang-record-replay/record/slang-entrypoint.h b/source/slang-record-replay/record/slang-entrypoint.h index abe2dc9c0..29a3329fe 100644 --- a/source/slang-record-replay/record/slang-entrypoint.h +++ b/source/slang-record-replay/record/slang-entrypoint.h @@ -69,6 +69,23 @@ namespace SlangRecord return Super::getTargetCode(targetIndex, outCode, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( SlangInt entryPointIndex, SlangInt targetIndex, diff --git a/source/slang-record-replay/record/slang-module.h b/source/slang-record-replay/record/slang-module.h index e793bea98..ce009eac8 100644 --- a/source/slang-record-replay/record/slang-module.h +++ b/source/slang-record-replay/record/slang-module.h @@ -85,6 +85,23 @@ namespace SlangRecord return Super::getTargetCode(targetIndex, outCode, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( SlangInt entryPointIndex, SlangInt targetIndex, diff --git a/source/slang-record-replay/record/slang-type-conformance.h b/source/slang-record-replay/record/slang-type-conformance.h index d1dfb3238..6c6889d02 100644 --- a/source/slang-record-replay/record/slang-type-conformance.h +++ b/source/slang-record-replay/record/slang-type-conformance.h @@ -65,6 +65,23 @@ namespace SlangRecord return Super::getTargetCode(targetIndex, outCode, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( SlangInt entryPointIndex, SlangInt targetIndex, diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 820cab03f..17501163f 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -317,10 +317,22 @@ namespace Slang SlangInt targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + IArtifact* getTargetArtifact(SlangInt targetIndex, slang::IBlob** outDiagnostics); + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( SlangInt targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( SlangInt entryPointIndex, @@ -580,6 +592,8 @@ namespace Slang Scope* m_lookupScope = nullptr; std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal; + + Dictionary<Int, ComPtr<IArtifact>> m_targetArtifacts; }; /// A component type built up from other component types. @@ -914,6 +928,23 @@ namespace Slang return Super::getTargetCode(targetIndex, outCode, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( SlangInt entryPointIndex, SlangInt targetIndex, @@ -1159,6 +1190,23 @@ namespace Slang return Super::getTargetCode(targetIndex, outCode, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( SlangInt entryPointIndex, SlangInt targetIndex, @@ -1460,6 +1508,23 @@ namespace Slang return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + /// Get a serialized representation of the checked module. virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) override; diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index ea22cbcba..5d3331c6f 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -1225,10 +1225,6 @@ static void addExplicitParameterBindings_GLSL( } } - // We use the HLSL binding directly (even though this notionally for GLSL/Vulkan) - // We'll do the shifting at later later point in _maybeApplyHLSLToVulkanShifts - info[kResInfo].resInfo = typeLayout->findOrAddResourceInfo(hlslInfo.kind); - if (warnedMissingVulkanLayoutModifier) { // If we warn due to invalid bindings and user did not set how to interpret 'hlsl style bindings', we should map @@ -1236,7 +1232,7 @@ static void addExplicitParameterBindings_GLSL( if(!hlslToVulkanLayoutOptions || hlslToVulkanLayoutOptions->getKindShiftEnabledFlags() == HLSLToVulkanLayoutOptions::KindFlag::None) { - info[kResInfo].resInfo->kind = LayoutResourceKind::DescriptorTableSlot; + info[kResInfo].resInfo = typeLayout->findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot); info[kResInfo].resInfo->count = 1; } else @@ -1245,6 +1241,11 @@ static void addExplicitParameterBindings_GLSL( } } + // We use the HLSL binding directly (even though this notionally for GLSL/Vulkan) + // We'll do the shifting at later later point in _maybeApplyHLSLToVulkanShifts + if (!info[kResInfo].resInfo) + info[kResInfo].resInfo = typeLayout->findOrAddResourceInfo(hlslInfo.kind); + info[kResInfo].semanticInfo.kind = info[kResInfo].resInfo->kind; info[kResInfo].semanticInfo.index = UInt(hlslInfo.index); info[kResInfo].semanticInfo.space = UInt(hlslInfo.space); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index dc5f9a755..c9717272b 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -4704,6 +4704,38 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointHostCallable( return artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary); } +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointMetadata( + SlangInt entryPointIndex, + Int targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return SLANG_E_INVALID_ARG; + auto target = linkage->targets[targetIndex]; + + auto targetProgram = getTargetProgram(target); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); + sink.getBlobIfNeeded(outDiagnostics); + + if (artifact == nullptr) + return SLANG_E_NOT_AVAILABLE; + + auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); + if (!metadata) + return SLANG_E_NOT_AVAILABLE; + + *outMetadata = static_cast<slang::IMetadata*>(metadata); + (*outMetadata)->addRef(); + return SLANG_OK; +} + RefPtr<ComponentType> ComponentType::specialize( SpecializationArg const* inSpecializationArgs, SlangInt specializationArgCount, @@ -4933,14 +4965,16 @@ void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void acceptVisitor(&visitor, nullptr); } -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCode( - Int targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) +IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outDiagnostics) { auto linkage = getLinkage(); if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return SLANG_E_INVALID_ARG; + return nullptr; + ComPtr<IArtifact> artifact; + if (m_targetArtifacts.tryGetValue(targetIndex, artifact)) + { + return artifact.get(); + } // If the user hasn't specified any entry points, then we should // discover all entrypoints that are defined in linked modules, and @@ -4964,8 +4998,13 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCode( } RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components); ComPtr<IComponentType> linkedComponentType; - SLANG_RETURN_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics)); - return linkedComponentType->getTargetCode(targetIndex, outCode, outDiagnostics); + SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics)); + auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get())->getTargetArtifact(targetIndex, outDiagnostics); + if (targetArtifact) + { + m_targetArtifacts[targetIndex] = targetArtifact; + } + return targetArtifact; } auto target = linkage->targets[targetIndex]; @@ -4975,8 +5014,18 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCode( applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - IArtifact* artifact = targetProgram->getOrCreateWholeProgramResult(&sink); + IArtifact* targetArtifact = targetProgram->getOrCreateWholeProgramResult(&sink); sink.getBlobIfNeeded(outDiagnostics); + m_targetArtifacts[targetIndex] = ComPtr<IArtifact>(targetArtifact); + return targetArtifact; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCode( + Int targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) +{ + IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); if (artifact == nullptr) return SLANG_FAIL; @@ -4984,6 +5033,24 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCode( return artifact->loadBlob(ArtifactKeep::Yes, outCode); } +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata( + Int targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) +{ + IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); + + if (artifact == nullptr) + return SLANG_FAIL; + + auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); + if (!metadata) + return SLANG_E_NOT_AVAILABLE; + *outMetadata = static_cast<slang::IMetadata*>(metadata); + (*outMetadata)->addRef(); + return SLANG_OK; +} + // // CompositeComponentType // @@ -6888,19 +6955,7 @@ SlangResult EndToEndCompileRequest::isParameterLocationUsed(Int entryPointIndex, if (!metadata) return SLANG_E_NOT_AVAILABLE; - - // TODO: optimize this with a binary search through a sorted list - for (const auto& range : metadata->getUsedBindingRanges()) - { - if (range.containsBinding((slang::ParameterCategory)category, spaceIndex, registerIndex)) - { - outUsed = true; - return SLANG_OK; - } - } - - outUsed = false; - return SLANG_OK; + return metadata->isParameterLocationUsed(category, spaceIndex, registerIndex, outUsed); } } // namespace Slang |
