summaryrefslogtreecommitdiff
path: root/tools/gfx/renderer-shared.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/renderer-shared.cpp')
-rw-r--r--tools/gfx/renderer-shared.cpp441
1 files changed, 441 insertions, 0 deletions
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 2072d52ef..a2e1751fd 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -1,11 +1,60 @@
#include "renderer-shared.h"
#include "render-graphics-common.h"
+#include "core/slang-io.h"
+#include "core/slang-token-reader.h"
using namespace Slang;
namespace gfx
{
+const Slang::Guid GfxGUID::IID_ISlangUnknown = SLANG_UUID_ISlangUnknown;
+const Slang::Guid GfxGUID::IID_IDescriptorSetLayout = SLANG_UUID_IDescriptorSetLayout;
+const Slang::Guid GfxGUID::IID_IDescriptorSet = SLANG_UUID_IDescriptorSet;
+const Slang::Guid GfxGUID::IID_IShaderProgram = SLANG_UUID_IShaderProgram;
+const Slang::Guid GfxGUID::IID_IPipelineLayout = SLANG_UUID_IPipelineLayout;
+const Slang::Guid GfxGUID::IID_IInputLayout = SLANG_UUID_IInputLayout;
+const Slang::Guid GfxGUID::IID_IPipelineState = SLANG_UUID_IPipelineState;
+const Slang::Guid GfxGUID::IID_IResourceView = SLANG_UUID_IResourceView;
+const Slang::Guid GfxGUID::IID_ISamplerState = SLANG_UUID_ISamplerState;
+const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource;
+const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource;
+const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource;
+const Slang::Guid GfxGUID::IID_IRenderer = SLANG_UUID_IRenderer;
+const Slang::Guid GfxGUID::IID_IShaderObjectLayout = SLANG_UUID_IShaderObjectLayout;
+const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject;
+
+gfx::StageType translateStage(SlangStage slangStage)
+{
+ switch (slangStage)
+ {
+ default:
+ SLANG_ASSERT(!"unhandled case");
+ return gfx::StageType::Unknown;
+
+#define CASE(FROM, TO) \
+case SLANG_STAGE_##FROM: \
+return gfx::StageType::TO
+
+ CASE(VERTEX, Vertex);
+ CASE(HULL, Hull);
+ CASE(DOMAIN, Domain);
+ CASE(GEOMETRY, Geometry);
+ CASE(FRAGMENT, Fragment);
+
+ CASE(COMPUTE, Compute);
+
+ CASE(RAY_GENERATION, RayGeneration);
+ CASE(INTERSECTION, Intersection);
+ CASE(ANY_HIT, AnyHit);
+ CASE(CLOSEST_HIT, ClosestHit);
+ CASE(MISS, Miss);
+ CASE(CALLABLE, Callable);
+
+#undef CASE
+ }
+}
+
IResource* BufferResource::getInterface(const Slang::Guid& guid)
{
if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResource ||
@@ -96,4 +145,396 @@ Result createProgramFromSlang(IRenderer* renderer, IShaderProgram::Desc const& o
return renderer->createProgram(programDesc, outProgram);
}
+IShaderObject* gfx::ShaderObjectBase::getInterface(const Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject)
+ return static_cast<IShaderObject*>(this);
+ return nullptr;
+}
+
+IShaderProgram* gfx::ShaderProgramBase::getInterface(const Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram)
+ return static_cast<IShaderProgram*>(this);
+ return nullptr;
+}
+
+IPipelineState* gfx::PipelineStateBase::getInterface(const Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
+ return static_cast<IPipelineState*>(this);
+ return nullptr;
+}
+
+void PipelineStateBase::initializeBase(const PipelineStateDesc& inDesc)
+{
+ desc = inDesc;
+
+ auto program = desc.getProgram();
+ isSpecializable = (program->slangProgram && program->slangProgram->getSpecializationParamCount() != 0);
+}
+
+IRenderer* gfx::RendererBase::getInterface(const Guid& guid)
+{
+ return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer)
+ ? static_cast<IRenderer*>(this)
+ : nullptr;
+}
+
+SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc, void* inWindowHandle)
+{
+ SLANG_UNUSED(inWindowHandle);
+
+ shaderCache.init(desc.shaderCacheFileSystem);
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW Result SLANG_MCALL RendererBase::getFeatures(
+ const char** outFeatures, UInt bufferSize, UInt* outFeatureCount)
+{
+ if (bufferSize >= (UInt)m_features.getCount())
+ {
+ for (Index i = 0; i < m_features.getCount(); i++)
+ {
+ outFeatures[i] = m_features[i].getUnownedSlice().begin();
+ }
+ }
+ if (outFeatureCount)
+ *outFeatureCount = (UInt)m_features.getCount();
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW bool SLANG_MCALL RendererBase::hasFeature(const char* featureName)
+{
+ return m_features.findFirstIndex([&](Slang::String x) { return x == featureName; }) != -1;
+}
+
+SLANG_NO_THROW Result SLANG_MCALL RendererBase::getSlangSession(slang::ISession** outSlangSession)
+{
+ *outSlangSession = slangContext.session.get();
+ slangContext.session->addRef();
+ return SLANG_OK;
+}
+
+ShaderComponentID ShaderCache::getComponentId(slang::TypeReflection* type)
+{
+ ComponentKey key;
+ key.typeName = UnownedStringSlice(type->getName());
+ switch (type->getKind())
+ {
+ case slang::TypeReflection::Kind::Specialized:
+ // TODO: collect specialization arguments and append them to `key`.
+ SLANG_UNIMPLEMENTED_X("specialized type");
+ default:
+ break;
+ }
+ key.updateHash();
+ return getComponentId(key);
+}
+
+ShaderComponentID ShaderCache::getComponentId(UnownedStringSlice name)
+{
+ ComponentKey key;
+ key.typeName = name;
+ key.updateHash();
+ return getComponentId(key);
+}
+
+ShaderComponentID ShaderCache::getComponentId(ComponentKey key)
+{
+ ShaderComponentID componentId = 0;
+ if (componentIds.TryGetValue(key, componentId))
+ return componentId;
+ OwningComponentKey owningTypeKey;
+ owningTypeKey.hash = key.hash;
+ owningTypeKey.typeName = key.typeName;
+ owningTypeKey.specializationArgs.addRange(key.specializationArgs);
+ ShaderComponentID resultId = static_cast<ShaderComponentID>(componentIds.Count());
+ componentIds[owningTypeKey] = resultId;
+ return resultId;
+}
+
+void ShaderCache::init(ISlangFileSystem* cacheFileSystem)
+{
+ fileSystem = cacheFileSystem;
+
+ ComPtr<ISlangBlob> indexFileBlob;
+ if (fileSystem && fileSystem->loadFile("index", indexFileBlob.writeRef()) == SLANG_OK)
+ {
+ UnownedStringSlice indexText = UnownedStringSlice(static_cast<const char*>(indexFileBlob->getBufferPointer()));
+ TokenReader reader = TokenReader(indexText);
+ auto componentCountInFileSystem = reader.ReadUInt();
+ for (uint32_t i = 0; i < componentCountInFileSystem; i++)
+ {
+ OwningComponentKey key;
+ auto componentId = reader.ReadUInt();
+ key.typeName = reader.ReadWord();
+ key.specializationArgs.setCount(reader.ReadUInt());
+ for (auto& arg : key.specializationArgs)
+ arg = reader.ReadUInt();
+ componentIds[key] = componentId;
+ }
+ }
+}
+
+void ShaderCache::writeToFileSystem(ISlangMutableFileSystem* outputFileSystem)
+{
+ StringBuilder indexBuilder;
+ indexBuilder << componentIds.Count() << Slang::EndLine;
+ for (auto id : componentIds)
+ {
+ indexBuilder << id.Value << " ";
+ indexBuilder << id.Key.typeName << " " << id.Key.specializationArgs.getCount();
+ for (auto arg : id.Key.specializationArgs)
+ indexBuilder << " " << arg;
+ indexBuilder << Slang::EndLine;
+ }
+ outputFileSystem->saveFile("index", indexBuilder.getBuffer(), indexBuilder.getLength());
+ for (auto& binary : shaderBinaries)
+ {
+ ComPtr<ISlangBlob> blob;
+ binary.Value->writeToBlob(blob.writeRef());
+ outputFileSystem->saveFile(String(binary.Key).getBuffer(), blob->getBufferPointer(), blob->getBufferSize());
+ }
+}
+
+Slang::RefPtr<ShaderBinary> ShaderCache::tryLoadShaderBinary(ShaderComponentID componentId)
+{
+ Slang::ComPtr<ISlangBlob> entryBlob;
+ Slang::RefPtr<ShaderBinary> binary;
+ if (shaderBinaries.TryGetValue(componentId, binary))
+ return binary;
+
+ if (fileSystem && fileSystem->loadFile(String(componentId).getBuffer(), entryBlob.writeRef()) == SLANG_OK)
+ {
+ binary = new ShaderBinary();
+ binary->loadFromBlob(entryBlob.get());
+ return binary;
+ }
+ return nullptr;
+}
+
+void ShaderCache::addShaderBinary(ShaderComponentID componentId, ShaderBinary* binary)
+{
+ shaderBinaries[componentId] = binary;
+}
+
+void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline)
+{
+ specializedPipelines[key] = specializedPipeline;
+}
+
+struct ShaderBinaryEntryHeader
+{
+ StageType stage;
+ uint32_t nameLength;
+ uint32_t codeLength;
+};
+
+Result ShaderBinary::loadFromBlob(ISlangBlob* blob)
+{
+ MemoryStreamBase memStream(Slang::FileAccess::Read, blob->getBufferPointer(), blob->getBufferSize());
+ uint32_t nameLength = 0;
+ ShaderBinaryEntryHeader header;
+ if (memStream.read(&header, sizeof(header)) != sizeof(header))
+ return SLANG_FAIL;
+ const uint8_t* name = memStream.getContents().getBuffer() + memStream.getPosition();
+ const uint8_t* code = name + header.nameLength;
+ entryPointName = reinterpret_cast<const char*>(name);
+ stage = header.stage;
+ source.addRange(code, header.codeLength);
+ return SLANG_OK;
+}
+
+Result ShaderBinary::writeToBlob(ISlangBlob** outBlob)
+{
+ OwnedMemoryStream outStream(FileAccess::Write);
+ ShaderBinaryEntryHeader header;
+ header.stage = stage;
+ header.nameLength = static_cast<uint32_t>(entryPointName.getLength() + 1);
+ header.codeLength = static_cast<uint32_t>(source.getCount());
+ outStream.write(&header, sizeof(header));
+ outStream.write(entryPointName.getBuffer(), header.nameLength - 1);
+ uint8_t zeroTerminator = 0;
+ outStream.write(&zeroTerminator, 1);
+ outStream.write(source.getBuffer(), header.codeLength);
+ RefPtr<RawBlob> blob = new RawBlob(outStream.getContents().getBuffer(), outStream.getContents().getCount());
+ *outBlob = blob.detach();
+ return SLANG_OK;
+}
+
+IShaderObjectLayout* ShaderObjectLayoutBase::getInterface(const Slang::Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout)
+ return static_cast<IShaderObjectLayout*>(this);
+ return nullptr;
+}
+
+void ShaderObjectLayoutBase::initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout)
+{
+ m_renderer = renderer;
+ m_elementTypeLayout = elementTypeLayout;
+ m_componentID = m_renderer->shaderCache.getComponentId(m_elementTypeLayout->getType());
+}
+
+SLANG_NO_THROW Result SLANG_MCALL ShaderObjectBase::finalizeBindings()
+{
+ m_bindingFinalized = true;
+
+ // With all binding fixed, the shader object's type can be determined by specializing the
+ // shader object's type with the type of bound sub objects.
+ // Now obtain a componentID for the specialized shader object type from the shader cache.
+ SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&shaderObjectType));
+ return SLANG_OK;
+}
+
+
+// Get the final type this shader object represents. If the shader object's type has existential fields,
+// this function will return a specialized type using the bound sub-objects' type as specialization argument.
+
+Result ShaderObjectBase::getSpecializedShaderObjectType(ExtendedShaderObjectType* outType)
+{
+ if (shaderObjectType.slangType)
+ *outType = shaderObjectType;
+ ExtendedShaderObjectTypeList specializationArgs;
+ SLANG_RETURN_ON_FAIL(collectSpecializationArgs(specializationArgs));
+ if (specializationArgs.getCount() == 0)
+ {
+ shaderObjectType.componentID = getLayout()->getComponentID();
+ shaderObjectType.slangType = getLayout()->getElementTypeLayout()->getType();
+ }
+ else
+ {
+ shaderObjectType.slangType = getRenderer()->slangContext.session->specializeType(
+ getElementTypeLayout()->getType(),
+ specializationArgs.components.getArrayView().getBuffer(), specializationArgs.getCount());
+ shaderObjectType.componentID = getRenderer()->shaderCache.getComponentId(shaderObjectType.slangType);
+ }
+ *outType = shaderObjectType;
+ return SLANG_OK;
+}
+
+Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject)
+{
+ auto currentPipeline = getCurrentPipeline();
+ auto pipelineType = currentPipeline->desc.type;
+ if (currentPipeline->unspecializedPipelineState)
+ currentPipeline = currentPipeline->unspecializedPipelineState;
+ // If the currently bound pipeline is specializable, we need to specialize it based on bound shader objects.
+ if (currentPipeline->isSpecializable)
+ {
+ specializationArgs.clear();
+ SLANG_RETURN_ON_FAIL(rootObject->collectSpecializationArgs(specializationArgs));
+
+ // Construct a shader cache key that represents the specialized shader kernels.
+ PipelineKey pipelineKey;
+ pipelineKey.pipeline = currentPipeline;
+ pipelineKey.specializationArgs.addRange(specializationArgs.componentIDs);
+ pipelineKey.updateHash();
+
+ auto specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey);
+ // Try to find specialized pipeline from shader cache.
+ if (!specializedPipelineState)
+ {
+ auto unspecializedProgram = static_cast<ShaderProgramBase*>(pipelineType == PipelineType::Compute
+ ? currentPipeline->desc.compute.program
+ : currentPipeline->desc.graphics.program);
+ List<RefPtr<ShaderBinary>> entryPointBinaries;
+ auto unspecializedProgramLayout = unspecializedProgram->slangProgram->getLayout();
+ for (SlangUInt i = 0; i < unspecializedProgramLayout->getEntryPointCount(); i++)
+ {
+ auto unspecializedEntryPoint = unspecializedProgramLayout->getEntryPointByIndex(i);
+ UnownedStringSlice entryPointName = UnownedStringSlice(unspecializedEntryPoint->getName());
+ ComponentKey specializedKernelKey;
+ specializedKernelKey.typeName = entryPointName;
+ specializedKernelKey.specializationArgs.addRange(specializationArgs.componentIDs);
+ specializedKernelKey.updateHash();
+ // If the pipeline is not created, check if the kernel binaries has been compiled.
+ auto specializedKernelComponentID = shaderCache.getComponentId(specializedKernelKey);
+ RefPtr<ShaderBinary> binary = shaderCache.tryLoadShaderBinary(specializedKernelComponentID);
+ if (!binary)
+ {
+ // If the specialized shader binary does not exist in cache, use slang to generate it.
+ entryPointBinaries.clear();
+ ComPtr<slang::IComponentType> specializedComponentType;
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto result = unspecializedProgram->slangProgram->specialize(specializationArgs.components.getArrayView().getBuffer(),
+ specializationArgs.getCount(), specializedComponentType.writeRef(), diagnosticBlob.writeRef());
+
+ // TODO: print diagnostic message via debug output interface.
+
+ if (result != SLANG_OK)
+ return result;
+
+ // Cache specialized binaries.
+ auto programLayout = specializedComponentType->getLayout();
+ for (SlangUInt j = 0; j < programLayout->getEntryPointCount(); j++)
+ {
+ auto entryPointLayout = programLayout->getEntryPointByIndex(j);
+ ComPtr<slang::IBlob> entryPointCode;
+ SLANG_RETURN_ON_FAIL(specializedComponentType->getEntryPointCode(j, 0, entryPointCode.writeRef(), diagnosticBlob.writeRef()));
+ binary = new ShaderBinary();
+ binary->stage = gfx::translateStage(entryPointLayout->getStage());
+ binary->entryPointName = entryPointLayout->getName();
+ binary->source.addRange((uint8_t*)entryPointCode->getBufferPointer(), entryPointCode->getBufferSize());
+ entryPointBinaries.add(binary);
+ shaderCache.addShaderBinary(specializedKernelComponentID, binary);
+ }
+
+ // We have already obtained all kernel binaries from this program, so break out of the outer loop since we no longer
+ // need to examine the rest of the kernels.
+ break;
+ }
+ entryPointBinaries.add(binary);
+ }
+
+ // Now create specialized shader program using compiled binaries.
+ ComPtr<IShaderProgram> specializedProgram;
+ IShaderProgram::Desc specializedProgramDesc = {};
+ specializedProgramDesc.kernelCount = unspecializedProgramLayout->getEntryPointCount();
+ ShortList<IShaderProgram::KernelDesc> kernelDescs;
+ for (Slang::Index i = 0; i < entryPointBinaries.getCount(); i++)
+ {
+ auto entryPoint = unspecializedProgramLayout->getEntryPointByIndex(i);;
+ auto& kernelDesc = kernelDescs[i];
+ kernelDesc.stage = entryPointBinaries[i]->stage;
+ kernelDesc.entryPointName = entryPointBinaries[i]->entryPointName.getBuffer();
+ kernelDesc.codeBegin = entryPointBinaries[i]->source.begin();
+ kernelDesc.codeEnd = entryPointBinaries[i]->source.end();
+ }
+ specializedProgramDesc.kernels = kernelDescs.getArrayView().getBuffer();
+ specializedProgramDesc.pipelineType = pipelineType;
+ SLANG_RETURN_ON_FAIL(createProgram(specializedProgramDesc, specializedProgram.writeRef()));
+
+ // Create specialized pipeline state.
+ ComPtr<IPipelineState> specializedPipelineState;
+ switch (pipelineType)
+ {
+ case PipelineType::Compute:
+ {
+ auto pipelineDesc = currentPipeline->desc.compute;
+ pipelineDesc.program = specializedProgram;
+ SLANG_RETURN_ON_FAIL(createComputePipelineState(pipelineDesc, specializedPipelineState.writeRef()));
+ break;
+ }
+ case PipelineType::Graphics:
+ {
+ auto pipelineDesc = currentPipeline->desc.graphics;
+ pipelineDesc.program = specializedProgram;
+ SLANG_RETURN_ON_FAIL(createGraphicsPipelineState(pipelineDesc, specializedPipelineState.writeRef()));
+ break;
+ }
+ default:
+ break;
+ }
+ auto specializedPipelineStateBase = static_cast<PipelineStateBase*>(specializedPipelineState.get());
+ specializedPipelineStateBase->unspecializedPipelineState = currentPipeline;
+ shaderCache.addSpecializedPipeline(pipelineKey, specializedPipelineStateBase);
+ }
+ setPipelineState(specializedPipelineState);
+ }
+ return SLANG_OK;
+}
+
+
} // namespace gfx