diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-21 15:44:21 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-21 15:44:21 -0700 |
| commit | 96caba75e8dfbb879eff12cbe1a4c148a259f684 (patch) | |
| tree | 1c7b2f25484ac22c738e006334d4df559bb733a5 /examples/autodiff-texture | |
| parent | 7f11f883d0781952f002b3aa3222a3aa0040f18a (diff) | |
Add texture tri-linear autodiff example. (#2715)
* Add quad texture example.
* delete output image
* remove irrelevant files
* update project files
* fix
* Update example.
* Fix.
* remove out-texture
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'examples/autodiff-texture')
| -rw-r--r-- | examples/autodiff-texture/README.md | 4 | ||||
| -rw-r--r-- | examples/autodiff-texture/buildmip.slang | 25 | ||||
| -rw-r--r-- | examples/autodiff-texture/checkerboard.jpg | bin | 0 -> 51262 bytes | |||
| -rw-r--r-- | examples/autodiff-texture/convert.slang | 25 | ||||
| -rw-r--r-- | examples/autodiff-texture/draw-quad.slang | 52 | ||||
| -rw-r--r-- | examples/autodiff-texture/learnmip.slang | 22 | ||||
| -rw-r--r-- | examples/autodiff-texture/main.cpp | 662 | ||||
| -rw-r--r-- | examples/autodiff-texture/reconstruct.slang | 48 | ||||
| -rw-r--r-- | examples/autodiff-texture/train.slang | 216 |
9 files changed, 1054 insertions, 0 deletions
diff --git a/examples/autodiff-texture/README.md b/examples/autodiff-texture/README.md new file mode 100644 index 000000000..3c7328fa7 --- /dev/null +++ b/examples/autodiff-texture/README.md @@ -0,0 +1,4 @@ +Autodiff Texture Example +=========================== + +The goal of this example is to demonstrate how to use custom backward derivative functions to propagate derivatives backwards to a texture.
\ No newline at end of file diff --git a/examples/autodiff-texture/buildmip.slang b/examples/autodiff-texture/buildmip.slang new file mode 100644 index 000000000..ebec7d754 --- /dev/null +++ b/examples/autodiff-texture/buildmip.slang @@ -0,0 +1,25 @@ +// A compute shader to build a mip-map layer using box filtering. + +cbuffer Uniforms +{ + uint dstWidth; + uint dstHeight; + RWTexture2D dstTexture; + RWTexture2D srcTexture; +} + +[shader("compute")] +[numthreads(16, 16, 1)] +void computeMain(uint3 threadIdx : SV_DispatchThreadID) +{ + uint x = threadIdx.x; + uint y = threadIdx.y; + if (x >= dstWidth) return; + if (y >= dstHeight) return; + var val0 = srcTexture[uint2(x * 2, y * 2)]; + var val1 = srcTexture[uint2(x * 2 + 1, y * 2)]; + var val2 = srcTexture[uint2(x * 2, y * 2 + 1)]; + var val3 = srcTexture[uint2(x * 2 + 1, y * 2 + 1)]; + + dstTexture[uint2(x, y)] = (val0 + val1 + val2 + val3) / 4; +} diff --git a/examples/autodiff-texture/checkerboard.jpg b/examples/autodiff-texture/checkerboard.jpg Binary files differnew file mode 100644 index 000000000..5181d78cd --- /dev/null +++ b/examples/autodiff-texture/checkerboard.jpg diff --git a/examples/autodiff-texture/convert.slang b/examples/autodiff-texture/convert.slang new file mode 100644 index 000000000..364cc286e --- /dev/null +++ b/examples/autodiff-texture/convert.slang @@ -0,0 +1,25 @@ +// A compute shader to convert from a buffer into a mip-map texture. 
+cbuffer Uniforms +{ + uint4 mipOffset[16]; + uint dstLayer; + uint width; + uint height; + RWStructuredBuffer<float> srcBuffer; + RWTexture2D dstTexture; +} + +[shader("compute")] +[numthreads(16, 16, 1)] +void computeMain(uint3 threadIdx : SV_DispatchThreadID) +{ + uint x = threadIdx.x; + uint y = threadIdx.y; + uint dstW = width >> dstLayer; + uint dstH = height >> dstLayer; + if (x >= dstW) return; + if (y >= dstH) return; + uint dstOffset = mipOffset[dstLayer / 4][dstLayer % 4] + (y * dstW + x) * 4; + var dstVal = float4(srcBuffer[dstOffset], srcBuffer[dstOffset + 1], srcBuffer[dstOffset + 2], srcBuffer[dstOffset + 3]); + dstTexture[uint2(x, y)] = dstVal; +} diff --git a/examples/autodiff-texture/draw-quad.slang b/examples/autodiff-texture/draw-quad.slang new file mode 100644 index 000000000..55a33b46b --- /dev/null +++ b/examples/autodiff-texture/draw-quad.slang @@ -0,0 +1,52 @@ +// Vertex and fragment shader to draw a textured quad on screen. + +cbuffer Uniforms +{ + int x; + int y; + int width; + int height; + int viewWidth; + int viewHeight; + Texture2D texture; + SamplerState sampler; +} + +struct AssembledVertex +{ + float3 position : POSITION; +}; + +struct Fragment +{ + float4 color; +}; + +struct VertexStageOutput +{ + float2 uv : UV; + float4 sv_position : SV_Position; +}; + +[shader("vertex")] +VertexStageOutput vertexMain( + AssembledVertex assembledVertex) +{ + VertexStageOutput output; + + float3 position = assembledVertex.position; + + output.uv = position.xy; + output.sv_position.x = (x + position.x * width) / (float)viewWidth * 2.0f - 1.0f; + output.sv_position.y = -((y + position.y * height) / (float)viewHeight * 2.0f - 1.0f); + output.sv_position.z = 0.5; + output.sv_position.w = 1.0; + return output; +} + +[shader("fragment")] +float4 fragmentMain( + float2 uv : UV) : SV_Target +{ + return float4(texture.Sample(sampler, uv).xyz, 1.0); +} diff --git a/examples/autodiff-texture/learnmip.slang b/examples/autodiff-texture/learnmip.slang new 
file mode 100644 index 000000000..1434f1a66 --- /dev/null +++ b/examples/autodiff-texture/learnmip.slang @@ -0,0 +1,22 @@ +// A compute shader to add gradients to a mip-map texture. + +cbuffer Uniforms +{ + uint dstWidth; + uint dstHeight; + float learningRate; + RWTexture2D dstTexture; + RWTexture2D srcTexture; +} + +[shader("compute")] +[numthreads(16, 16, 1)] +void computeMain(uint3 threadIdx : SV_DispatchThreadID) +{ + uint x = threadIdx.x; + uint y = threadIdx.y; + if (x >= dstWidth) return; + if (y >= dstHeight) return; + var val = srcTexture[uint2(x, y)]; + dstTexture[uint2(x, y)] = float4((dstTexture[uint2(x, y)] - val * learningRate).xyz, 1.0); +} diff --git a/examples/autodiff-texture/main.cpp b/examples/autodiff-texture/main.cpp new file mode 100644 index 000000000..7d1f809ee --- /dev/null +++ b/examples/autodiff-texture/main.cpp @@ -0,0 +1,662 @@ +#include "examples/example-base/example-base.h" +#include "gfx-util/shader-cursor.h" +#include "slang-com-ptr.h" +#include "slang-gfx.h" +#include "source/core/slang-basic.h" +#include "tools/platform/vector-math.h" +#include "tools/platform/window.h" +#include <slang.h> + +using namespace gfx; +using namespace Slang; + +struct Vertex +{ + float position[3]; +}; + +static const int kVertexCount = 4; +static const Vertex kVertexData[kVertexCount] = { + {{0, 0, 0}}, + {{0, 1, 0}}, + {{1, 0, 0}}, + {{1, 1, 0}}, +}; + +struct AutoDiffTexture : public WindowedAppBase +{ + + List<uint32_t> mipMapOffset; + int textureWidth; + int textureHeight; + + void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob) + { + if (diagnosticsBlob != nullptr) + { + printf("%s", (const char*)diagnosticsBlob->getBufferPointer()); + } + } + + gfx::Result loadRenderProgram( + gfx::IDevice* device, const char* fileName, const char* fragmentShader, gfx::IShaderProgram** outProgram) + { + ComPtr<slang::ISession> slangSession; + slangSession = device->getSlangSession(); + + ComPtr<slang::IBlob> diagnosticsBlob; + slang::IModule* module = 
slangSession->loadModule(fileName, diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if (!module) + return SLANG_FAIL; + + ComPtr<slang::IEntryPoint> vertexEntryPoint; + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName("vertexMain", vertexEntryPoint.writeRef())); + ComPtr<slang::IEntryPoint> fragmentEntryPoint; + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName(fragmentShader, fragmentEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(module); + int entryPointCount = 0; + int vertexEntryPointIndex = entryPointCount++; + componentTypes.add(vertexEntryPoint); + + int fragmentEntryPointIndex = entryPointCount++; + componentTypes.add(fragmentEntryPoint); + + ComPtr<slang::IComponentType> linkedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + linkedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.slangGlobalScope = linkedProgram; + SLANG_RETURN_ON_FAIL(device->createProgram(programDesc, outProgram)); + + return SLANG_OK; + } + + gfx::Result loadComputeProgram( + gfx::IDevice* device, const char* fileName, gfx::IShaderProgram** outProgram) + { + ComPtr<slang::ISession> slangSession; + slangSession = device->getSlangSession(); + + ComPtr<slang::IBlob> diagnosticsBlob; + slang::IModule* module = slangSession->loadModule(fileName, diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if (!module) + return SLANG_FAIL; + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(module); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + componentTypes.add(computeEntryPoint); + + ComPtr<slang::IComponentType> linkedProgram; + SlangResult result = 
slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + linkedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.slangGlobalScope = linkedProgram; + SLANG_RETURN_ON_FAIL(device->createProgram(programDesc, outProgram)); + + return SLANG_OK; + } + + ComPtr<gfx::IPipelineState> gRefPipelineState; + ComPtr<gfx::IPipelineState> gIterPipelineState; + ComPtr<gfx::IPipelineState> gReconstructPipelineState; + ComPtr<gfx::IPipelineState> gConvertPipelineState; + ComPtr<gfx::IPipelineState> gBuildMipPipelineState; + ComPtr<gfx::IPipelineState> gLearnMipPipelineState; + ComPtr<gfx::IPipelineState> gDrawQuadPipelineState; + + ComPtr<gfx::ITextureResource> gLearningTexture; + ComPtr<gfx::IResourceView> gLearningTextureSRV; + List<ComPtr<gfx::IResourceView>> gLearningTextureUAVs; + + ComPtr<gfx::ITextureResource> gDiffTexture; + ComPtr<gfx::IResourceView> gDiffTextureSRV; + List<ComPtr<gfx::IResourceView>> gDiffTextureUAVs; + + ComPtr<gfx::IBufferResource> gVertexBuffer; + ComPtr<gfx::IResourceView> gTexView; + ComPtr<gfx::ISamplerState> gSampler; + ComPtr<gfx::IFramebuffer> gRefFrameBuffer; + ComPtr<gfx::IFramebuffer> gIterFrameBuffer; + + ComPtr<gfx::ITextureResource> gDepthTexture; + ComPtr<gfx::IResourceView> gDepthTextureView; + + ComPtr<gfx::IResourceView> gIterImageDSV; + + ComPtr<gfx::ITextureResource> gIterImage; + ComPtr<gfx::IResourceView> gIterImageSRV; + ComPtr<gfx::IResourceView> gIterImageRTV; + + ComPtr<gfx::ITextureResource> gRefImage; + ComPtr<gfx::IResourceView> gRefImageSRV; + ComPtr<gfx::IResourceView> gRefImageRTV; + + ComPtr<gfx::IBufferResource> gAccumulateBuffer; + ComPtr<gfx::IBufferResource> gReconstructBuffer; + ComPtr<gfx::IResourceView> gAccumulateBufferView; + ComPtr<gfx::IResourceView> gReconstructBufferView; + + ClearValue kClearValue; + bool resetLearntTexture = false; 
+ + ComPtr<gfx::ITextureResource> createRenderTargetTexture(gfx::Format format, int w, int h, int levels) + { + gfx::ITextureResource::Desc textureDesc = {}; + textureDesc.allowedStates.add(ResourceState::ShaderResource); + textureDesc.allowedStates.add(ResourceState::UnorderedAccess); + textureDesc.allowedStates.add(ResourceState::RenderTarget); + textureDesc.defaultState = ResourceState::RenderTarget; + textureDesc.format = format; + textureDesc.numMipLevels = levels; + textureDesc.type = gfx::IResource::Type::Texture2D; + textureDesc.size.width = w; + textureDesc.size.height = h; + textureDesc.size.depth = 1; + textureDesc.optimalClearValue = &kClearValue; + return gDevice->createTextureResource(textureDesc, nullptr); + } + ComPtr<gfx::ITextureResource> createDepthTexture() + { + gfx::ITextureResource::Desc textureDesc = {}; + textureDesc.allowedStates.add(ResourceState::DepthWrite); + textureDesc.defaultState = ResourceState::DepthWrite; + textureDesc.format = gfx::Format::D32_FLOAT; + textureDesc.numMipLevels = 1; + textureDesc.type = gfx::IResource::Type::Texture2D; + textureDesc.size.width = windowWidth; + textureDesc.size.height = windowHeight; + textureDesc.size.depth = 1; + ClearValue clearValue = {}; + textureDesc.optimalClearValue = &clearValue; + return gDevice->createTextureResource(textureDesc, nullptr); + } + ComPtr<gfx::IFramebuffer> createRenderTargetFramebuffer(IResourceView* tex) + { + IFramebuffer::Desc desc = {}; + desc.layout = gFramebufferLayout.get(); + desc.renderTargetCount = 1; + desc.renderTargetViews = &tex; + desc.depthStencilView = gDepthTextureView; + return gDevice->createFramebuffer(desc); + } + ComPtr<gfx::IResourceView> createRTV(ITextureResource* tex, Format f) + { + IResourceView::Desc rtvDesc = {}; + rtvDesc.type = IResourceView::Type::RenderTarget; + rtvDesc.subresourceRange.mipLevelCount = 1; + rtvDesc.format = f; + rtvDesc.renderTarget.shape = gfx::IResource::Type::Texture2D; + return gDevice->createTextureView(tex, 
rtvDesc); + } + ComPtr<gfx::IResourceView> createDSV(ITextureResource* tex) + { + IResourceView::Desc dsvDesc = {}; + dsvDesc.type = IResourceView::Type::DepthStencil; + dsvDesc.subresourceRange.mipLevelCount = 1; + dsvDesc.format = Format::D32_FLOAT; + dsvDesc.renderTarget.shape = gfx::IResource::Type::Texture2D; + return gDevice->createTextureView(tex, dsvDesc); + } + ComPtr<gfx::IResourceView> createSRV(ITextureResource* tex) + { + IResourceView::Desc rtvDesc = {}; + rtvDesc.type = IResourceView::Type::ShaderResource; + return gDevice->createTextureView(tex, rtvDesc); + } + ComPtr<gfx::IPipelineState> createRenderPipelineState( + IInputLayout* inputLayout, IShaderProgram* program) + { + GraphicsPipelineStateDesc desc; + desc.inputLayout = inputLayout; + desc.program = program; + desc.rasterizer.cullMode = gfx::CullMode::None; + desc.framebufferLayout = gFramebufferLayout; + auto pipelineState = gDevice->createGraphicsPipelineState(desc); + return pipelineState; + } + ComPtr<gfx::IPipelineState> createComputePipelineState(IShaderProgram* program) + { + ComputePipelineStateDesc desc = {}; + desc.program = program; + auto pipelineState = gDevice->createComputePipelineState(desc); + return pipelineState; + } + ComPtr<gfx::IResourceView> createUAV(IBufferResource* buffer) + { + IResourceView::Desc desc = {}; + desc.type = IResourceView::Type::UnorderedAccess; + desc.bufferElementSize = 0; + return gDevice->createBufferView(buffer, nullptr, desc); + } + ComPtr<gfx::IResourceView> createUAV(ITextureResource* texture, int level) + { + IResourceView::Desc desc = {}; + desc.type = IResourceView::Type::UnorderedAccess; + desc.subresourceRange.layerCount = 1; + desc.subresourceRange.mipLevel = level; + desc.subresourceRange.baseArrayLayer = 0; + return gDevice->createTextureView(texture,desc); + } + Slang::Result initialize() + { + initializeBase("autodiff-texture", 1024, 768); + srand(20421); + + gWindow->events.keyPress = [this](platform::KeyEventArgs& e) + { + if 
(e.keyChar == 'R' || e.keyChar == 'r') + resetLearntTexture = true; + }; + + kClearValue.color.floatValues[0] = 0.3f; + kClearValue.color.floatValues[1] = 0.5f; + kClearValue.color.floatValues[2] = 0.7f; + kClearValue.color.floatValues[3] = 1.0f; + + auto clientRect = gWindow->getClientRect(); + windowWidth = clientRect.width; + windowHeight = clientRect.height; + + InputElementDesc inputElements[] = { + {"POSITION", 0, Format::R32G32B32_FLOAT, offsetof(Vertex, position)}}; + auto inputLayout = gDevice->createInputLayout(sizeof(Vertex), &inputElements[0], 1); + if (!inputLayout) + return SLANG_FAIL; + + IBufferResource::Desc vertexBufferDesc; + vertexBufferDesc.type = IResource::Type::Buffer; + vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex); + vertexBufferDesc.defaultState = ResourceState::VertexBuffer; + gVertexBuffer = gDevice->createBufferResource(vertexBufferDesc, &kVertexData[0]); + if (!gVertexBuffer) + return SLANG_FAIL; + + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL( + loadRenderProgram(gDevice, "train", "fragmentMain", shaderProgram.writeRef())); + gRefPipelineState = createRenderPipelineState(inputLayout, shaderProgram); + } + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL( + loadRenderProgram(gDevice, "train", "diffFragmentMain", shaderProgram.writeRef())); + gIterPipelineState = createRenderPipelineState(inputLayout, shaderProgram); + } + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL( + loadRenderProgram(gDevice, "draw-quad", "fragmentMain", shaderProgram.writeRef())); + gDrawQuadPipelineState = createRenderPipelineState(inputLayout, shaderProgram); + } + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL( + loadComputeProgram(gDevice, "reconstruct", shaderProgram.writeRef())); + gReconstructPipelineState = createComputePipelineState(shaderProgram); + } + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL(loadComputeProgram(gDevice, "convert", 
shaderProgram.writeRef())); + gConvertPipelineState = createComputePipelineState(shaderProgram); + } + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL(loadComputeProgram(gDevice, "buildmip", shaderProgram.writeRef())); + gBuildMipPipelineState = createComputePipelineState(shaderProgram); + } + { + ComPtr<IShaderProgram> shaderProgram; + SLANG_RETURN_ON_FAIL(loadComputeProgram(gDevice, "learnmip", shaderProgram.writeRef())); + gLearnMipPipelineState = createComputePipelineState(shaderProgram); + } + + gTexView = createTextureFromFile("checkerboard.jpg", textureWidth, textureHeight); + initMipOffsets(textureWidth, textureHeight); + + gfx::IBufferResource::Desc bufferDesc = {}; + bufferDesc.allowedStates.add(ResourceState::ShaderResource); + bufferDesc.allowedStates.add(ResourceState::UnorderedAccess); + bufferDesc.allowedStates.add(ResourceState::General); + bufferDesc.sizeInBytes = mipMapOffset.getLast() * sizeof(uint32_t); + bufferDesc.type = IResource::Type::Buffer; + gAccumulateBuffer = gDevice->createBufferResource(bufferDesc); + gReconstructBuffer = gDevice->createBufferResource(bufferDesc); + + gAccumulateBufferView = createUAV(gAccumulateBuffer); + gReconstructBufferView = createUAV(gReconstructBuffer); + + int mipCount = 1 + Math::Log2Ceil(Math::Max(textureWidth, textureHeight)); + gLearningTexture = createRenderTargetTexture( + Format::R32G32B32A32_FLOAT, + textureWidth, + textureHeight, + mipCount); + gLearningTextureSRV = createSRV(gLearningTexture); + for (int i = 0; i < mipCount; i++) + gLearningTextureUAVs.add(createUAV(gLearningTexture, i)); + + gDiffTexture = createRenderTargetTexture( + Format::R32G32B32A32_FLOAT, + textureWidth, + textureHeight, + mipCount); + gDiffTextureSRV = createSRV(gDiffTexture); + for (int i = 0; i < mipCount; i++) + gDiffTextureUAVs.add(createUAV(gDiffTexture, i)); + + gfx::ISamplerState::Desc samplerDesc = {}; + //samplerDesc.maxLOD = 0.0f; + gSampler = gDevice->createSamplerState(samplerDesc); + + 
gDepthTexture = createDepthTexture(); + gDepthTextureView = createDSV(gDepthTexture); + + gRefImage = createRenderTargetTexture(Format::R8G8B8A8_UNORM, windowWidth, windowHeight, 1); + gRefImageRTV = createRTV(gRefImage, Format::R8G8B8A8_UNORM); + gRefImageSRV = createSRV(gRefImage); + + gIterImage = createRenderTargetTexture(Format::R8G8B8A8_UNORM, windowWidth, windowHeight, 1); + gIterImageRTV = createRTV(gIterImage, Format::R8G8B8A8_UNORM); + gIterImageSRV = createSRV(gIterImage); + + gRefFrameBuffer = createRenderTargetFramebuffer(gRefImageRTV); + gIterFrameBuffer = createRenderTargetFramebuffer(gIterImageRTV); + + { + ComPtr<ICommandBuffer> commandBuffer = gTransientHeaps[0]->createCommandBuffer(); + auto encoder = commandBuffer->encodeResourceCommands(); + encoder->textureBarrier(gLearningTexture, ResourceState::RenderTarget, ResourceState::UnorderedAccess); + encoder->textureBarrier(gDiffTexture, ResourceState::RenderTarget, ResourceState::UnorderedAccess); + encoder->textureBarrier(gRefImage, ResourceState::RenderTarget, ResourceState::ShaderResource); + encoder->textureBarrier(gIterImage, ResourceState::RenderTarget, ResourceState::ShaderResource); + for (int i = 0; i < gLearningTextureUAVs.getCount(); i++) + { + ClearValue clearValue = {}; + encoder->clearResourceView(gLearningTextureUAVs[i], &clearValue, ClearResourceViewFlags::None); + encoder->clearResourceView(gDiffTextureUAVs[i], &clearValue, ClearResourceViewFlags::None); + } + encoder->textureBarrier(gLearningTexture, ResourceState::UnorderedAccess, ResourceState::ShaderResource); + + encoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + } + + return SLANG_OK; + } + + void initMipOffsets(int w, int h) + { + int layers = 1 + Math::Log2Ceil(Math::Max(w, h)); + uint32_t offset = 0; + for (int i = 0; i < layers; i++) + { + auto lw = Math::Max(1, w >> i); + auto lh = Math::Max(1, h >> i); + mipMapOffset.add(offset); + offset += lw * lh * 4; + } + 
mipMapOffset.add(offset); + } + + glm::mat4x4 getTransformMatrix() + { + float rotX = (rand() / (float)RAND_MAX) * 0.3f; + float rotY = (rand() / (float)RAND_MAX) * 0.2f; + glm::mat4x4 matProj = glm::perspectiveRH_ZO( + glm::radians(60.0f), (float)windowWidth / (float)windowHeight, 0.1f, 1000.0f); + auto identity = glm::mat4(1.0f); + auto translate = glm::translate( + identity, + glm::vec3( + -0.6f + 0.2f * (rand() / (float)RAND_MAX), + -0.6f + 0.2f * (rand() / (float)RAND_MAX), + -1.0f)); + auto rot = glm::rotate(translate, -glm::pi<float>() * rotX, glm::vec3(1.0f, 0.0f, 0.0f)); + rot = glm::rotate(rot, -glm::pi<float>() * rotY, glm::vec3(0.0f, 1.0f, 0.0f)); + auto transformMatrix = matProj * rot; + transformMatrix = glm::transpose(transformMatrix); + return transformMatrix; + } + + template <typename SetupPipelineFunc> + void renderImage( + int transientHeapIndex, IFramebuffer* fb, const SetupPipelineFunc& setupPipeline) + { + ComPtr<ICommandBuffer> commandBuffer = + gTransientHeaps[transientHeapIndex]->createCommandBuffer(); + auto renderEncoder = commandBuffer->encodeRenderCommands(gRenderPass, fb); + + gfx::Viewport viewport = {}; + viewport.maxZ = 1.0f; + viewport.extentX = (float)windowWidth; + viewport.extentY = (float)windowHeight; + renderEncoder->setViewportAndScissor(viewport); + + setupPipeline(renderEncoder); + + renderEncoder->setVertexBuffer(0, gVertexBuffer); + renderEncoder->setPrimitiveTopology(PrimitiveTopology::TriangleStrip); + + renderEncoder->draw(4); + renderEncoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + } + + void renderReferenceImage(int transientHeapIndex, glm::mat4x4 transformMatrix) + { + { + ComPtr<ICommandBuffer> commandBuffer = gTransientHeaps[transientHeapIndex]->createCommandBuffer(); + auto encoder = commandBuffer->encodeResourceCommands(); + encoder->textureBarrier(gRefImage, ResourceState::ShaderResource, ResourceState::RenderTarget); + encoder->endEncoding(); + 
commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + } + + renderImage( + transientHeapIndex, + gRefFrameBuffer, + [&](IRenderCommandEncoder* encoder) + { + auto rootObject = encoder->bindPipeline(gRefPipelineState); + ShaderCursor rootCursor(rootObject); + rootCursor["Uniforms"]["modelViewProjection"].setData( + &transformMatrix, sizeof(float) * 16); + rootCursor["Uniforms"]["bwdTexture"]["texture"].setResource(gTexView); + rootCursor["Uniforms"]["sampler"].setSampler(gSampler); + rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount()); + rootCursor["Uniforms"]["texRef"].setResource(gTexView); + rootCursor["Uniforms"]["bwdTexture"]["accumulateBuffer"].setResource(gAccumulateBufferView); + }); + } + + virtual void renderFrame(int frameBufferIndex) override + { + static uint32_t frameCount = 0; + frameCount++; + auto transformMatrix = getTransformMatrix(); + renderReferenceImage(frameBufferIndex, transformMatrix); + + // Barriers. 
+ { + ComPtr<ICommandBuffer> commandBuffer = + gTransientHeaps[frameBufferIndex]->createCommandBuffer(); + auto resEncoder = commandBuffer->encodeResourceCommands(); + ClearValue clearValue = {}; + resEncoder->bufferBarrier(gAccumulateBuffer, ResourceState::Undefined, ResourceState::UnorderedAccess); + resEncoder->bufferBarrier(gReconstructBuffer, ResourceState::Undefined, ResourceState::UnorderedAccess); + resEncoder->textureBarrier(gRefImage, ResourceState::Present, ResourceState::ShaderResource); + resEncoder->textureBarrier(gIterImage, ResourceState::ShaderResource, ResourceState::RenderTarget); + resEncoder->clearResourceView(gAccumulateBufferView, &clearValue, ClearResourceViewFlags::None); + resEncoder->clearResourceView(gReconstructBufferView, &clearValue, ClearResourceViewFlags::None); + if (resetLearntTexture) + { + resEncoder->textureBarrier(gLearningTexture, ResourceState::ShaderResource, ResourceState::UnorderedAccess); + for (Index i =0; i <gLearningTextureUAVs.getCount(); i++) + resEncoder->clearResourceView(gLearningTextureUAVs[i], &clearValue, ClearResourceViewFlags::None); + resEncoder->textureBarrier(gLearningTexture, ResourceState::UnorderedAccess, ResourceState::ShaderResource); + resetLearntTexture = false; + } + resEncoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + } + + // Render image using backward propagate shader to obtain texture-space gradients. 
+ renderImage( + frameBufferIndex, + gIterFrameBuffer, + [&](IRenderCommandEncoder* encoder) + { + auto rootObject = encoder->bindPipeline(gIterPipelineState); + ShaderCursor rootCursor(rootObject); + + rootCursor["Uniforms"]["modelViewProjection"].setData( + &transformMatrix, sizeof(float) * 16); + rootCursor["Uniforms"]["bwdTexture"]["texture"].setResource(gLearningTextureSRV); + rootCursor["Uniforms"]["sampler"].setSampler(gSampler); + rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount()); + rootCursor["Uniforms"]["texRef"].setResource(gRefImageSRV); + rootCursor["Uniforms"]["bwdTexture"]["accumulateBuffer"].setResource(gAccumulateBufferView); + rootCursor["Uniforms"]["bwdTexture"]["minLOD"].setData(5.0); + + }); + + // Propagete gradients through mip map layers from top (lowest res) to bottom (highest res). + { + ComPtr<ICommandBuffer> commandBuffer = + gTransientHeaps[frameBufferIndex]->createCommandBuffer(); + auto encoder = commandBuffer->encodeComputeCommands(); + encoder->textureBarrier(gLearningTexture, ResourceState::ShaderResource, ResourceState::UnorderedAccess); + auto rootObject = encoder->bindPipeline(gReconstructPipelineState); + for (int i = (int)mipMapOffset.getCount() - 2; i >= 0; i--) + { + ShaderCursor rootCursor(rootObject); + rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount()); + rootCursor["Uniforms"]["dstLayer"].setData(i); + rootCursor["Uniforms"]["layerCount"].setData(mipMapOffset.getCount() - 1); + rootCursor["Uniforms"]["width"].setData(textureWidth); + rootCursor["Uniforms"]["height"].setData(textureHeight); + rootCursor["Uniforms"]["accumulateBuffer"].setResource(gAccumulateBufferView); + rootCursor["Uniforms"]["dstBuffer"].setResource(gReconstructBufferView); + encoder->dispatchCompute( + ((textureWidth >> i) + 15) / 16, ((textureHeight >> i) + 15) / 16, 1); + encoder->bufferBarrier(gReconstructBuffer, 
ResourceState::UnorderedAccess, ResourceState::UnorderedAccess); + } + + // Convert bottom layer mip from buffer to texture. + rootObject = encoder->bindPipeline(gConvertPipelineState); + ShaderCursor rootCursor(rootObject); + rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount()); + rootCursor["Uniforms"]["dstLayer"].setData(0); + rootCursor["Uniforms"]["width"].setData(textureWidth); + rootCursor["Uniforms"]["height"].setData(textureHeight); + rootCursor["Uniforms"]["srcBuffer"].setResource(gReconstructBufferView); + rootCursor["Uniforms"]["dstTexture"].setResource(gDiffTextureUAVs[0]); + encoder->dispatchCompute( + (textureWidth + 15) / 16, (textureHeight + 15) / 16, 1); + encoder->textureBarrier(gDiffTexture, ResourceState::UnorderedAccess, ResourceState::UnorderedAccess); + + // Build higher level mip map layers. + rootObject = encoder->bindPipeline(gBuildMipPipelineState); + for (int i = 1; i < (int)mipMapOffset.getCount() - 1; i++) + { + ShaderCursor rootCursor(rootObject); + rootCursor["Uniforms"]["dstWidth"].setData(textureWidth >> i); + rootCursor["Uniforms"]["dstHeight"].setData(textureHeight >> i); + rootCursor["Uniforms"]["srcTexture"].setResource(gDiffTextureUAVs[i-1]); + rootCursor["Uniforms"]["dstTexture"].setResource(gDiffTextureUAVs[i]); + encoder->dispatchCompute( + ((textureWidth >> i) + 15) / 16, ((textureHeight >> i) + 15) / 16, 1); + encoder->textureBarrier(gDiffTexture, ResourceState::UnorderedAccess, ResourceState::UnorderedAccess); + } + + // Accumulate gradients to learnt texture. 
+ rootObject = encoder->bindPipeline(gLearnMipPipelineState); + for (int i = 0; i < (int)mipMapOffset.getCount() - 1; i++) + { + ShaderCursor rootCursor(rootObject); + rootCursor["Uniforms"]["dstWidth"].setData(textureWidth >> i); + rootCursor["Uniforms"]["dstHeight"].setData(textureHeight >> i); + rootCursor["Uniforms"]["learningRate"].setData(0.1f); + rootCursor["Uniforms"]["srcTexture"].setResource(gDiffTextureUAVs[i]); + rootCursor["Uniforms"]["dstTexture"].setResource(gLearningTextureUAVs[i]); + encoder->dispatchCompute( + ((textureWidth >> i) + 15) / 16, ((textureHeight >> i) + 15) / 16, 1); + } + encoder->textureBarrier(gLearningTexture, ResourceState::UnorderedAccess, ResourceState::ShaderResource); + encoder->textureBarrier(gIterImage, ResourceState::Present, ResourceState::ShaderResource); + + encoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + } + + // Draw currently learnt texture. + { + ComPtr<ICommandBuffer> commandBuffer = + gTransientHeaps[frameBufferIndex]->createCommandBuffer(); + auto renderEncoder = commandBuffer->encodeRenderCommands(gRenderPass, gFramebuffers[frameBufferIndex]); + drawTexturedQuad(renderEncoder, 0, 0, textureWidth, textureHeight, gLearningTextureSRV); + int refImageWidth = windowWidth - textureWidth - 10; + int refImageHeight = refImageWidth * windowHeight / windowWidth; + drawTexturedQuad(renderEncoder, textureWidth + 10, 0, refImageWidth, refImageHeight, gRefImageSRV); + drawTexturedQuad(renderEncoder, textureWidth + 10, refImageHeight + 10, refImageWidth, refImageHeight, gIterImageSRV); + renderEncoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + } + + gSwapchain->present(); + } + + void drawTexturedQuad(IRenderCommandEncoder* renderEncoder, int x, int y, int w, int h, IResourceView* srv) + { + gfx::Viewport viewport = {}; + viewport.maxZ = 1.0f; + viewport.extentX = (float)windowWidth; + viewport.extentY = (float)windowHeight; + 
renderEncoder->setViewportAndScissor(viewport); + + auto root = renderEncoder->bindPipeline(gDrawQuadPipelineState); + ShaderCursor rootCursor(root); + rootCursor["Uniforms"]["x"].setData(x); + rootCursor["Uniforms"]["y"].setData(y); + rootCursor["Uniforms"]["width"].setData(w); + rootCursor["Uniforms"]["height"].setData(h); + rootCursor["Uniforms"]["viewWidth"].setData(windowWidth); + rootCursor["Uniforms"]["viewHeight"].setData(windowHeight); + rootCursor["Uniforms"]["texture"].setResource(srv); + rootCursor["Uniforms"]["sampler"].setSampler(gSampler); + renderEncoder->setVertexBuffer(0, gVertexBuffer); + renderEncoder->setPrimitiveTopology(PrimitiveTopology::TriangleStrip); + renderEncoder->draw(4); + } + +}; + +PLATFORM_UI_MAIN(innerMain<AutoDiffTexture>) diff --git a/examples/autodiff-texture/reconstruct.slang b/examples/autodiff-texture/reconstruct.slang new file mode 100644 index 000000000..c123010e5 --- /dev/null +++ b/examples/autodiff-texture/reconstruct.slang @@ -0,0 +1,48 @@ +// A compute shader to propagate gradients from high level mip(low-res) to lower level mip (high-res). 
cbuffer Uniforms
{
    // Start offset of every mip layer inside the flat buffers, packed four layers per
    // uint4: layer L is found at mipOffset[L / 4][L % 4].
    uint4 mipOffset[16];
    // Mip layer written by this dispatch.
    uint dstLayer;
    // Number of mip layers; the last one has no parent to pull from.
    uint layerCount;
    // Layer-0 dimensions; layer L is (width >> L) x (height >> L).
    uint width;
    uint height;
    // Fixed-point gradient sums, four ints per texel: xyz are scaled by 65536,
    // w is the number of contributions.
    RWStructuredBuffer<int> accumulateBuffer;
    // Reconstructed float values, four floats per texel.
    RWStructuredBuffer<float> dstBuffer;
}

// For every texel of layer `dstLayer`: average the accumulated fixed-point gradient,
// add one quarter of the value stored for the covering texel of layer `dstLayer + 1`,
// and write the result to `dstBuffer`.
//
// NOTE(review): the parent layer is read from `dstBuffer`, so this presumably has to be
// dispatched from the coarsest layer towards layer 0 — confirm against the host code.
[shader("compute")]
[numthreads(16, 16, 1)]
void computeMain(uint3 threadIdx : SV_DispatchThreadID)
{
    uint x = threadIdx.x;
    uint y = threadIdx.y;
    uint dstW = width >> dstLayer;
    uint dstH = height >> dstLayer;
    // Threads of partially covered 16x16 groups fall outside the layer.
    if (x >= dstW) return;
    if (y >= dstH) return;
    uint dstOffset = mipOffset[dstLayer / 4][dstLayer % 4] + (y * dstW + x) * 4;
    var dstVal = int4(accumulateBuffer[dstOffset], accumulateBuffer[dstOffset + 1], accumulateBuffer[dstOffset + 2], accumulateBuffer[dstOffset + 3]);
    var newDstValToAdd = float3(0.0);
    // Undo the 65536 fixed-point scale and divide by the contribution count.
    // Texels that received no contribution keep a zero gradient.
    if (dstVal.w > 0)
        newDstValToAdd = (float3)dstVal.xyz * float3(1.0 / (dstVal.w * 65536.0));

    float4 existingVal = 0.0;

    // Written as `dstLayer + 1 < layerCount` rather than `dstLayer < layerCount - 1`
    // so that layerCount == 0 cannot underflow the unsigned subtraction.
    if (dstLayer + 1 < layerCount)
    {
        uint parentOffset = mipOffset[(dstLayer + 1) / 4][(dstLayer + 1) % 4];
        uint parentW = dstW / 2;
        uint parentPixelLoc = parentOffset + ((y / 2) * parentW + (x / 2)) * 4;
        existingVal.x = dstBuffer[parentPixelLoc] * 0.25;
        existingVal.y = dstBuffer[parentPixelLoc + 1] * 0.25;
        existingVal.z = dstBuffer[parentPixelLoc + 2] * 0.25;
        existingVal.w = 0.0;
    }

    var newDstVal = existingVal + float4(newDstValToAdd, 0.0);
    dstBuffer[dstOffset] = newDstVal.x;
    dstBuffer[dstOffset + 1] = newDstVal.y;
    dstBuffer[dstOffset + 2] = newDstVal.z;
    dstBuffer[dstOffset + 3] = 1.0;
}

// diff --git a/examples/autodiff-texture/train.slang b/examples/autodiff-texture/train.slang
// new file mode 100644
// index 000000000..16f3be35c
// --- /dev/null
// +++ b/examples/autodiff-texture/train.slang
// @@ -0,0 +1,216 @@
// train.slang

// A texture with a hand-written backward derivative.
//
// The forward pass (`sample`) uses the hardware sampler. The backward pass
// (`bwd_sample`) scatters the incoming gradient into `accumulateBuffer` using an
// explicit trilinear footprint.
struct BwdTexture
{
    // Gradient accumulator: four ints per texel (xyz scaled by 65536, w = contribution count).
    RWStructuredBuffer<int> accumulateBuffer;
    // The texture being sampled / learnt.
    Texture2D texture;
    // Lower clamp applied to the computed level of detail.
    float minLOD;

    // Atomically adds `val` (converted to 16.16-style fixed point) to texel (x, y) of the
    // layer that starts at `offset` and is `layerW` texels wide, and bumps the texel's
    // contribution count by one. `layerH` is not used.
    void writeDiffToTexel(uint offset, uint layerW, uint layerH, float x, float y, float3 val)
    {
        int4 uval = int4(int3(val * 65536), 1);
        // All four components of a texel are adjacent; compute the base index once.
        uint texelBase = offset + ((uint)y * layerW + (uint)x) * 4;
        InterlockedAdd(accumulateBuffer[texelBase + 0], uval.x);
        InterlockedAdd(accumulateBuffer[texelBase + 1], uval.y);
        InterlockedAdd(accumulateBuffer[texelBase + 2], uval.z);
        InterlockedAdd(accumulateBuffer[texelBase + 3], uval.w);
    }

    // Distributes `diff` over the four texels of mip level `lod` that bilinear filtering
    // at `uv` would read, weighted by the bilinear weights. `w` and `h` are the level-0
    // dimensions; coordinates wrap around at the layer edges.
    void broadcastDiffToLayer(uint lod, float3 diff, float2 uv, uint w, uint h)
    {
        var offset = mipOffset[lod / 4][lod % 4];
        w >>= lod;
        h >>= lod;
        // Keep only the fractional part of uv (wrap addressing).
        uv = uv - floor(uv);
        // Texel centers sit at half-integer positions.
        float2 loc = uv * float2(w, h) - float2(0.5);
        float x0 = floor(loc.x);
        float y0 = floor(loc.y);
        float fracX = loc.x - x0;
        float fracY = loc.y - y0;
        float x1 = x0 + 1;
        float y1 = y0 + 1;
        // Wrap neighbours that fall off either edge.
        if (x0 < 0) x0 += w;
        if (y0 < 0) y0 += h;
        if (x1 >= w) x1 -= w;
        if (y1 >= h) y1 -= h;
        float weight0 = 1.0f - fracY;
        float weight1 = fracY;
        float weight00 = weight0 * (1.0f - fracX);
        float weight01 = weight0 * fracX;
        float weight10 = weight1 * (1.0f - fracX);
        float weight11 = weight1 * fracX;

        writeDiffToTexel(offset, w, h, x0, y0, weight00 * diff);
        writeDiffToTexel(offset, w, h, x1, y0, weight01 * diff);
        writeDiffToTexel(offset, w, h, x0, y1, weight10 * diff);
        writeDiffToTexel(offset, w, h, x1, y1, weight11 * diff);
    }

    // Backward pass of a trilinear lookup: picks the two mip levels from the screen-space
    // uv derivatives `dX` / `dY` and scatters `dOut.xyz` into both.
    void sampleTexture_trilinear_bwd(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY, float4 dOut)
    {
        // Convert uv derivatives to texel units.
        dX = dX * float2(w, h);
        dY = dY * float2(w, h);

        // Isotropic filter.
        float lengthX = length(dX);
        float lengthY = length(dY);
        float LOD = log2(max(lengthX, lengthY));
        float maxLOD = levels - 1;
        float clampedLOD = max(minLOD, (min(maxLOD, LOD)));

        float lodFrac = clampedLOD - floor(clampedLOD);
        uint lod0 = (uint)floor(clampedLOD);
        uint lod1 = min(levels - 1, lod0 + 1);
        float weightLod0 = 1.0 - lodFrac;
        float weightLod1 = lodFrac;
        // NOTE(review): this overrides the (1 - lodFrac) weight computed above, so the
        // finer level always receives the full gradient. The forward
        // sampleTexture_trilinear does not do this — confirm that it is intentional.
        weightLod0 = 1.0;
        broadcastDiffToLayer(lod0, dOut.xyz * weightLod0, uv, w, h);
        broadcastDiffToLayer(lod1, dOut.xyz * weightLod1, uv, w, h);
    }

    // Manual bilinear fetch from mip level `lod` with wrap addressing; uses the same
    // footprint and weights as broadcastDiffToLayer. `w` and `h` are level-0 dimensions.
    float4 sampleTexture_linear(uint lod, float2 uv, uint w, uint h)
    {
        w >>= lod;
        h >>= lod;
        uv = uv - floor(uv);
        float2 loc = uv * float2(w, h) - float2(0.5);
        float x0 = floor(loc.x);
        float y0 = floor(loc.y);
        float fracX = loc.x - x0;
        float fracY = loc.y - y0;
        float x1 = x0 + 1;
        float y1 = y0 + 1;
        if (x0 < 0) x0 += w;
        if (y0 < 0) y0 += h;
        if (x1 >= w) x1 -= w;
        if (y1 >= h) y1 -= h;
        float weight0 = 1.0f - fracY;
        float weight1 = fracY;
        float weight00 = weight0 * (1.0f - fracX);
        float weight01 = weight0 * fracX;
        float weight10 = weight1 * (1.0f - fracX);
        float weight11 = weight1 * fracX;
        return texture.Load(int3(int(x0), int(y0), int(lod)), int2(0)) * weight00 +
               texture.Load(int3(int(x1), int(y0), int(lod)), int2(0)) * weight01 +
               texture.Load(int3(int(x0), int(y1), int(lod)), int2(0)) * weight10 +
               texture.Load(int3(int(x1), int(y1), int(lod)), int2(0)) * weight11;
    }

    // Manual trilinear fetch: blends the bilinear results of the two mip levels selected
    // from the uv derivatives `dX` / `dY`.
    float4 sampleTexture_trilinear(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY)
    {
        // Convert uv derivatives to texel units.
        dX = dX * float2(w, h);
        dY = dY * float2(w, h);

        // Isotropic filter.
        float lengthX = length(dX);
        float lengthY = length(dY);
        float LOD = log2(max(lengthX, lengthY));
        float maxLOD = levels - 1;
        float clampedLOD = max(minLOD, (min(maxLOD, LOD)));

        float lodFrac = clampedLOD - floor(clampedLOD);
        uint lod0 = (uint)floor(clampedLOD);
        uint lod1 = min(levels - 1, lod0 + 1);
        float weightLod0 = 1.0 - lodFrac;
        float weightLod1 = lodFrac;

        let v0 = sampleTexture_linear(lod0, uv, w, h) * weightLod0;
        let v1 = sampleTexture_linear(lod1, uv, w, h) * weightLod1;
        return v0 + v1;
    }

    // Forward sample: a plain hardware texture lookup.
    static float4 sample(BwdTexture t, SamplerState s, no_diff float2 uv)
    {
        return t.texture.Sample(s, uv);
    }

    // Forward-mode derivative of `sample`; always returns zero.
    [ForwardDerivativeOf(BwdTexture.sample)]
    static float4 fwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv)
    {
        return float4(0.0);
    }

    // Backward-mode derivative of `sample`: scatters `dOut` into the accumulate buffer
    // over the trilinear footprint derived from the coarse screen-space derivatives of uv.
    [BackwardDerivativeOf(BwdTexture.sample)]
    static void bwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv, float4 dOut)
    {
        float2 dX = ddx_coarse(uv);
        float2 dY = ddy_coarse(uv);
        uint w;
        uint h;
        uint levels;
        t.texture.GetDimensions(0, w, h, levels);
        t.sampleTexture_trilinear_bwd(w, h, levels, uv, dX, dY, dOut);
    }
}

cbuffer Uniforms
{
    float4x4 modelViewProjection;
    // Start offset of every mip layer in the accumulate buffer, four layers per uint4.
    uint4 mipOffset[16];

    // Reference image the loss is measured against.
    Texture2D texRef;
    SamplerState sampler;
    BwdTexture bwdTexture;
}

struct AssembledVertex
{
    float3 position : POSITION;
};

struct Fragment
{
    float4 color;
};

struct VertexStageOutput
{
    float2 uv : UV;
    float4 sv_position : SV_Position;
};

// Shades a fragment by sampling the learnt texture at twice the incoming uv; alpha is 1.
[BackwardDifferentiable]
float4 shadeFragment(no_diff float2 uv)
{
    float3 color = BwdTexture.sample(bwdTexture, sampler, uv * 2).xyz;
    return float4(color, 1.0);
}

// Per-channel squared difference between the shaded color and the reference texel at
// the fragment's screen position.
[BackwardDifferentiable]
float3 loss(no_diff float2 uv, no_diff float4 screenPos)
{
    float3 refColor = (no_diff texRef.Load(int3(int2(screenPos.xy), 0))).xyz;
    float3 rs = shadeFragment(uv).xyz - refColor;
    rs *= rs;
    return rs;
}

// Vertex stage: the xy of the input position doubles as the uv coordinate.
[shader("vertex")]
VertexStageOutput vertexMain(
    AssembledVertex assembledVertex)
{
    VertexStageOutput output;

    float3 position = assembledVertex.position;

    output.uv = position.xy;
    output.sv_position = mul(modelViewProjection, float4(position, 1.0));

    return output;
}

// Component-wise square.
float3 sqr(float3 v) { return v * v; }

// Fragment stage that just renders the learnt texture.
[shader("fragment")]
float4 fragmentMain(
    float2 uv : UV) : SV_Target
{
    return shadeFragment(uv);
}

// Fragment stage used for training: back-propagates a unit gradient through `loss`
// (which accumulates into bwdTexture.accumulateBuffer via bwd_sample) and outputs the
// loss value itself as the color.
[shader("fragment")]
float4 diffFragmentMain(
    float2 uv : UV,
    float4 screenPos : SV_POSITION) : SV_Target
{
    __bwd_diff(loss)(uv, screenPos, float3(1.0));
    return float4(loss(uv, screenPos), 1.0);
}
