summaryrefslogtreecommitdiffstats
path: root/examples/autodiff-texture
diff options
context:
space:
mode:
Diffstat (limited to 'examples/autodiff-texture')
-rw-r--r--examples/autodiff-texture/README.md4
-rw-r--r--examples/autodiff-texture/buildmip.slang25
-rw-r--r--examples/autodiff-texture/checkerboard.jpgbin0 -> 51262 bytes
-rw-r--r--examples/autodiff-texture/convert.slang25
-rw-r--r--examples/autodiff-texture/draw-quad.slang52
-rw-r--r--examples/autodiff-texture/learnmip.slang22
-rw-r--r--examples/autodiff-texture/main.cpp662
-rw-r--r--examples/autodiff-texture/reconstruct.slang48
-rw-r--r--examples/autodiff-texture/train.slang216
9 files changed, 1054 insertions, 0 deletions
diff --git a/examples/autodiff-texture/README.md b/examples/autodiff-texture/README.md
new file mode 100644
index 000000000..3c7328fa7
--- /dev/null
+++ b/examples/autodiff-texture/README.md
@@ -0,0 +1,4 @@
+Autodiff Texture Example
+===========================
+
+The goal of this example is to demonstrate how to use custom backward derivative functions to propagate derivatives backwards to a texture.
\ No newline at end of file
diff --git a/examples/autodiff-texture/buildmip.slang b/examples/autodiff-texture/buildmip.slang
new file mode 100644
index 000000000..ebec7d754
--- /dev/null
+++ b/examples/autodiff-texture/buildmip.slang
@@ -0,0 +1,25 @@
+// Compute shader that fills one mip level by averaging 2x2 texel blocks of
+// the source texture (box filter).
+
+cbuffer Uniforms
+{
+    uint dstWidth;          // Width of the level being written.
+    uint dstHeight;         // Height of the level being written.
+    RWTexture2D dstTexture; // Level being written.
+    RWTexture2D srcTexture; // Texture read at twice the destination coordinates.
+}
+
+[shader("compute")]
+[numthreads(16, 16, 1)]
+void computeMain(uint3 threadIdx : SV_DispatchThreadID)
+{
+    uint2 dstCoord = threadIdx.xy;
+
+    // Threads that fall outside the destination level have nothing to write.
+    if (dstCoord.x >= dstWidth || dstCoord.y >= dstHeight)
+        return;
+
+    // Accumulate the four source texels in the same order as a plain sum.
+    uint2 srcBase = dstCoord * 2;
+    var sum = srcTexture[srcBase];
+    sum = sum + srcTexture[srcBase + uint2(1, 0)];
+    sum = sum + srcTexture[srcBase + uint2(0, 1)];
+    sum = sum + srcTexture[srcBase + uint2(1, 1)];
+
+    dstTexture[dstCoord] = sum / 4;
+}
diff --git a/examples/autodiff-texture/checkerboard.jpg b/examples/autodiff-texture/checkerboard.jpg
new file mode 100644
index 000000000..5181d78cd
--- /dev/null
+++ b/examples/autodiff-texture/checkerboard.jpg
Binary files differ
diff --git a/examples/autodiff-texture/convert.slang b/examples/autodiff-texture/convert.slang
new file mode 100644
index 000000000..364cc286e
--- /dev/null
+++ b/examples/autodiff-texture/convert.slang
@@ -0,0 +1,25 @@
+// Compute shader that copies one mip layer stored in a flat float buffer
+// (four consecutive floats per texel) into a texture.
+cbuffer Uniforms
+{
+    uint4 mipOffset[16];                 // Start offset of each layer, four layers per uint4.
+    uint dstLayer;                       // Layer to copy.
+    uint width;                          // Layer-0 width; shifted right by dstLayer below.
+    uint height;                         // Layer-0 height; shifted right by dstLayer below.
+    RWStructuredBuffer<float> srcBuffer; // Source values.
+    RWTexture2D dstTexture;              // Destination texture.
+}
+
+[shader("compute")]
+[numthreads(16, 16, 1)]
+void computeMain(uint3 threadIdx : SV_DispatchThreadID)
+{
+    uint2 texel = threadIdx.xy;
+    uint layerW = width >> dstLayer;
+    uint layerH = height >> dstLayer;
+
+    // Skip threads outside the layer.
+    if (texel.x >= layerW || texel.y >= layerH)
+        return;
+
+    // Locate the texel's four floats inside the buffer.
+    uint layerStart = mipOffset[dstLayer / 4][dstLayer % 4];
+    uint base = layerStart + (texel.y * layerW + texel.x) * 4;
+
+    dstTexture[texel] = float4(
+        srcBuffer[base],
+        srcBuffer[base + 1],
+        srcBuffer[base + 2],
+        srcBuffer[base + 3]);
+}
diff --git a/examples/autodiff-texture/draw-quad.slang b/examples/autodiff-texture/draw-quad.slang
new file mode 100644
index 000000000..55a33b46b
--- /dev/null
+++ b/examples/autodiff-texture/draw-quad.slang
@@ -0,0 +1,52 @@
+// Vertex and fragment shader to draw a textured quad on screen.
+
+cbuffer Uniforms
+{
+ int x; // Horizontal offset added to position.x * width in vertexMain.
+ int y; // Vertical offset added to position.y * height in vertexMain.
+ int width; // Scale applied to position.x in vertexMain.
+ int height; // Scale applied to position.y in vertexMain.
+ int viewWidth; // Divisor used to normalize the horizontal coordinate in vertexMain.
+ int viewHeight; // Divisor used to normalize the vertical coordinate in vertexMain.
+ Texture2D texture; // Sampled in fragmentMain.
+ SamplerState sampler; // Sampler used for that lookup.
+}
+
+struct AssembledVertex // Input of vertexMain.
+{
+ float3 position : POSITION;
+};
+
+struct Fragment // Not referenced by the entry points in this file.
+{
+ float4 color;
+};
+
+struct VertexStageOutput // Return type of vertexMain.
+{
+ float2 uv : UV;
+ float4 sv_position : SV_Position;
+};
+
+// Places a quad corner on screen: position.xy is scaled by (width, height),
+// offset by (x, y), normalized by the view size and mapped to [-1, 1] with
+// the vertical axis flipped. position.xy is also forwarded as the UV.
+[shader("vertex")]
+VertexStageOutput vertexMain(
+    AssembledVertex assembledVertex)
+{
+    float3 corner = assembledVertex.position;
+
+    float clipX = (x + corner.x * width) / (float)viewWidth * 2.0f - 1.0f;
+    float clipY = -((y + corner.y * height) / (float)viewHeight * 2.0f - 1.0f);
+
+    VertexStageOutput result;
+    result.uv = corner.xy;
+    result.sv_position.x = clipX;
+    result.sv_position.y = clipY;
+    result.sv_position.z = 0.5;
+    result.sv_position.w = 1.0;
+    return result;
+}
+
+// Samples the bound texture at the interpolated UV and outputs its RGB with
+// alpha forced to 1.
+[shader("fragment")]
+float4 fragmentMain(
+    float2 uv : UV) : SV_Target
+{
+    float3 rgb = texture.Sample(sampler, uv).xyz;
+    return float4(rgb, 1.0);
+}
diff --git a/examples/autodiff-texture/learnmip.slang b/examples/autodiff-texture/learnmip.slang
new file mode 100644
index 000000000..1434f1a66
--- /dev/null
+++ b/examples/autodiff-texture/learnmip.slang
@@ -0,0 +1,22 @@
+// Compute shader that applies one update step to a texture level:
+// dst.xyz = (dst - src * learningRate).xyz, with dst.w written as 1.
+
+cbuffer Uniforms
+{
+    uint dstWidth;          // Width of the level being updated.
+    uint dstHeight;         // Height of the level being updated.
+    float learningRate;     // Scale applied to the source value before subtracting.
+    RWTexture2D dstTexture; // Texture being updated in place.
+    RWTexture2D srcTexture; // Values subtracted from dstTexture.
+}
+
+[shader("compute")]
+[numthreads(16, 16, 1)]
+void computeMain(uint3 threadIdx : SV_DispatchThreadID)
+{
+    uint2 texel = threadIdx.xy;
+
+    // Skip threads outside the level.
+    if (texel.x >= dstWidth || texel.y >= dstHeight)
+        return;
+
+    var step = srcTexture[texel];
+    var updated = dstTexture[texel] - step * learningRate;
+    dstTexture[texel] = float4(updated.xyz, 1.0);
+}
diff --git a/examples/autodiff-texture/main.cpp b/examples/autodiff-texture/main.cpp
new file mode 100644
index 000000000..7d1f809ee
--- /dev/null
+++ b/examples/autodiff-texture/main.cpp
@@ -0,0 +1,662 @@
+#include "examples/example-base/example-base.h"
+#include "gfx-util/shader-cursor.h"
+#include "slang-com-ptr.h"
+#include "slang-gfx.h"
+#include "source/core/slang-basic.h"
+#include "tools/platform/vector-math.h"
+#include "tools/platform/window.h"
+#include <slang.h>
+
+using namespace gfx;
+using namespace Slang;
+
+struct Vertex // Vertex layout used for the quad: a single 3-float position.
+{
+ float position[3];
+};
+
+static const int kVertexCount = 4; // Number of entries in kVertexData.
+static const Vertex kVertexData[kVertexCount] = { // Corners of the unit square in the z = 0 plane.
+ {{0, 0, 0}},
+ {{0, 1, 0}},
+ {{1, 0, 0}},
+ {{1, 1, 0}},
+};
+
+struct AutoDiffTexture : public WindowedAppBase
+{
+
+ List<uint32_t> mipMapOffset; // Filled by initMipOffsets: start offset of every mip layer, followed by the total size.
+ int textureWidth; // Passed to createTextureFromFile in initialize(); presumably filled in as an output - confirm.
+ int textureHeight; // Same as textureWidth, for the height.
+
+    // Prints the text held by a Slang diagnostics blob to stdout.
+    // A null blob is ignored.
+    void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob)
+    {
+        if (!diagnosticsBlob)
+            return;
+
+        const char* message = (const char*)diagnosticsBlob->getBufferPointer();
+        printf("%s", message);
+    }
+
+    // Loads the Slang module `fileName`, composes it with its "vertexMain" entry point and
+    // the fragment entry point named `fragmentShader`, and creates a gfx shader program
+    // from the result.
+    //
+    // Any diagnostics produced while loading or composing are printed via diagnoseIfNeeded.
+    // Returns SLANG_FAIL if the module cannot be loaded, the failing result of any later
+    // step, or SLANG_OK with the program stored in `outProgram`.
+    gfx::Result loadRenderProgram(
+        gfx::IDevice* device, const char* fileName, const char* fragmentShader, gfx::IShaderProgram** outProgram)
+    {
+        ComPtr<slang::ISession> slangSession;
+        slangSession = device->getSlangSession();
+
+        ComPtr<slang::IBlob> diagnosticsBlob;
+        slang::IModule* module = slangSession->loadModule(fileName, diagnosticsBlob.writeRef());
+        diagnoseIfNeeded(diagnosticsBlob);
+        if (!module)
+            return SLANG_FAIL;
+
+        ComPtr<slang::IEntryPoint> vertexEntryPoint;
+        SLANG_RETURN_ON_FAIL(
+            module->findEntryPointByName("vertexMain", vertexEntryPoint.writeRef()));
+        ComPtr<slang::IEntryPoint> fragmentEntryPoint;
+        SLANG_RETURN_ON_FAIL(
+            module->findEntryPointByName(fragmentShader, fragmentEntryPoint.writeRef()));
+
+        // Components are composed in this order: module, vertex entry point, fragment entry point.
+        Slang::List<slang::IComponentType*> componentTypes;
+        componentTypes.add(module);
+        componentTypes.add(vertexEntryPoint);
+        componentTypes.add(fragmentEntryPoint);
+
+        ComPtr<slang::IComponentType> linkedProgram;
+        SlangResult result = slangSession->createCompositeComponentType(
+            componentTypes.getBuffer(),
+            componentTypes.getCount(),
+            linkedProgram.writeRef(),
+            diagnosticsBlob.writeRef());
+        // Print diagnostics before bailing out so a failure is explained.
+        diagnoseIfNeeded(diagnosticsBlob);
+        SLANG_RETURN_ON_FAIL(result);
+
+        gfx::IShaderProgram::Desc programDesc = {};
+        programDesc.slangGlobalScope = linkedProgram;
+        SLANG_RETURN_ON_FAIL(device->createProgram(programDesc, outProgram));
+
+        return SLANG_OK;
+    }
+
+    // Loads the Slang module `fileName`, composes it with its "computeMain" entry point and
+    // creates a gfx shader program from the result. Diagnostics are printed via
+    // diagnoseIfNeeded. Returns SLANG_FAIL if the module cannot be loaded, otherwise the
+    // result of the first failing step or SLANG_OK.
+    gfx::Result loadComputeProgram(
+        gfx::IDevice* device, const char* fileName, gfx::IShaderProgram** outProgram)
+    {
+        ComPtr<slang::ISession> session;
+        session = device->getSlangSession();
+
+        ComPtr<slang::IBlob> diagnostics;
+        slang::IModule* loadedModule = session->loadModule(fileName, diagnostics.writeRef());
+        diagnoseIfNeeded(diagnostics);
+        if (loadedModule == nullptr)
+            return SLANG_FAIL;
+
+        ComPtr<slang::IEntryPoint> entryPoint;
+        SLANG_RETURN_ON_FAIL(
+            loadedModule->findEntryPointByName("computeMain", entryPoint.writeRef()));
+
+        // Module first, then its entry point.
+        Slang::List<slang::IComponentType*> components;
+        components.add(loadedModule);
+        components.add(entryPoint);
+
+        ComPtr<slang::IComponentType> composite;
+        SlangResult composeResult = session->createCompositeComponentType(
+            components.getBuffer(),
+            components.getCount(),
+            composite.writeRef(),
+            diagnostics.writeRef());
+        diagnoseIfNeeded(diagnostics);
+        SLANG_RETURN_ON_FAIL(composeResult);
+
+        gfx::IShaderProgram::Desc desc = {};
+        desc.slangGlobalScope = composite;
+        SLANG_RETURN_ON_FAIL(device->createProgram(desc, outProgram));
+
+        return SLANG_OK;
+    }
+
+ ComPtr<gfx::IPipelineState> gRefPipelineState; // "train" module with fragmentMain (see initialize).
+ ComPtr<gfx::IPipelineState> gIterPipelineState; // "train" module with diffFragmentMain.
+ ComPtr<gfx::IPipelineState> gReconstructPipelineState; // "reconstruct" compute program.
+ ComPtr<gfx::IPipelineState> gConvertPipelineState; // "convert" compute program.
+ ComPtr<gfx::IPipelineState> gBuildMipPipelineState; // "buildmip" compute program.
+ ComPtr<gfx::IPipelineState> gLearnMipPipelineState; // "learnmip" compute program.
+ ComPtr<gfx::IPipelineState> gDrawQuadPipelineState; // "draw-quad" module with fragmentMain.
+
+ ComPtr<gfx::ITextureResource> gLearningTexture; // Mip-mapped texture updated by the learnmip pass in renderFrame.
+ ComPtr<gfx::IResourceView> gLearningTextureSRV;
+ List<ComPtr<gfx::IResourceView>> gLearningTextureUAVs; // One UAV per mip level.
+
+ ComPtr<gfx::ITextureResource> gDiffTexture; // Mip-mapped texture written by the convert and buildmip passes.
+ ComPtr<gfx::IResourceView> gDiffTextureSRV;
+ List<ComPtr<gfx::IResourceView>> gDiffTextureUAVs; // One UAV per mip level.
+
+ ComPtr<gfx::IBufferResource> gVertexBuffer; // Holds kVertexData.
+ ComPtr<gfx::IResourceView> gTexView; // View returned by createTextureFromFile("checkerboard.jpg", ...).
+ ComPtr<gfx::ISamplerState> gSampler;
+ ComPtr<gfx::IFramebuffer> gRefFrameBuffer; // Renders into gRefImage.
+ ComPtr<gfx::IFramebuffer> gIterFrameBuffer; // Renders into gIterImage.
+
+ ComPtr<gfx::ITextureResource> gDepthTexture;
+ ComPtr<gfx::IResourceView> gDepthTextureView; // Depth-stencil view shared by both framebuffers above.
+
+ ComPtr<gfx::IResourceView> gIterImageDSV;
+
+ ComPtr<gfx::ITextureResource> gIterImage;
+ ComPtr<gfx::IResourceView> gIterImageSRV;
+ ComPtr<gfx::IResourceView> gIterImageRTV;
+
+ ComPtr<gfx::ITextureResource> gRefImage;
+ ComPtr<gfx::IResourceView> gRefImageSRV;
+ ComPtr<gfx::IResourceView> gRefImageRTV;
+
+ ComPtr<gfx::IBufferResource> gAccumulateBuffer; // Cleared every frame in renderFrame, then bound to the shaders as accumulateBuffer.
+ ComPtr<gfx::IBufferResource> gReconstructBuffer; // Written by the reconstruct pass, read by the convert pass.
+ ComPtr<gfx::IResourceView> gAccumulateBufferView;
+ ComPtr<gfx::IResourceView> gReconstructBufferView;
+
+ ClearValue kClearValue; // Color is assigned in initialize(); used as optimalClearValue by createRenderTargetTexture.
+ bool resetLearntTexture = false; // Set by the 'R' key handler; consumed and cleared in renderFrame.
+
+    // Creates a w x h 2D texture with `levels` mip levels in the given format. The texture
+    // may be used as shader resource, unordered-access resource and render target, starts
+    // in the render-target state and uses kClearValue as its optimal clear value.
+    ComPtr<gfx::ITextureResource> createRenderTargetTexture(gfx::Format format, int w, int h, int levels)
+    {
+        gfx::ITextureResource::Desc desc = {};
+
+        // Shape and format.
+        desc.type = gfx::IResource::Type::Texture2D;
+        desc.format = format;
+        desc.numMipLevels = levels;
+        desc.size.width = w;
+        desc.size.height = h;
+        desc.size.depth = 1;
+
+        // Usage.
+        desc.allowedStates.add(ResourceState::ShaderResource);
+        desc.allowedStates.add(ResourceState::UnorderedAccess);
+        desc.allowedStates.add(ResourceState::RenderTarget);
+        desc.defaultState = ResourceState::RenderTarget;
+        desc.optimalClearValue = &kClearValue;
+
+        return gDevice->createTextureResource(desc, nullptr);
+    }
+    // Creates a window-sized, single-level D32_FLOAT texture that lives in the
+    // depth-write state, with a zero-initialized optimal clear value.
+    ComPtr<gfx::ITextureResource> createDepthTexture()
+    {
+        ClearValue zeroClear = {};
+
+        gfx::ITextureResource::Desc desc = {};
+        desc.type = gfx::IResource::Type::Texture2D;
+        desc.format = gfx::Format::D32_FLOAT;
+        desc.numMipLevels = 1;
+        desc.size.width = windowWidth;
+        desc.size.height = windowHeight;
+        desc.size.depth = 1;
+        desc.allowedStates.add(ResourceState::DepthWrite);
+        desc.defaultState = ResourceState::DepthWrite;
+        desc.optimalClearValue = &zeroClear;
+
+        return gDevice->createTextureResource(desc, nullptr);
+    }
+    // Creates a framebuffer with `tex` as its only render target and
+    // gDepthTextureView as its depth-stencil view, using gFramebufferLayout.
+    ComPtr<gfx::IFramebuffer> createRenderTargetFramebuffer(IResourceView* tex)
+    {
+        IFramebuffer::Desc fbDesc = {};
+        fbDesc.layout = gFramebufferLayout.get();
+        fbDesc.depthStencilView = gDepthTextureView;
+        // A single render target: the desc points at the parameter itself.
+        fbDesc.renderTargetViews = &tex;
+        fbDesc.renderTargetCount = 1;
+        return gDevice->createFramebuffer(fbDesc);
+    }
+    // Creates a render-target view of `tex` with format `f`, covering one mip level.
+    ComPtr<gfx::IResourceView> createRTV(ITextureResource* tex, Format f)
+    {
+        IResourceView::Desc viewDesc = {};
+        viewDesc.type = IResourceView::Type::RenderTarget;
+        viewDesc.format = f;
+        viewDesc.renderTarget.shape = gfx::IResource::Type::Texture2D;
+        viewDesc.subresourceRange.mipLevelCount = 1;
+        return gDevice->createTextureView(tex, viewDesc);
+    }
+    // Creates a D32_FLOAT depth-stencil view of `tex`, covering one mip level.
+    ComPtr<gfx::IResourceView> createDSV(ITextureResource* tex)
+    {
+        IResourceView::Desc viewDesc = {};
+        viewDesc.type = IResourceView::Type::DepthStencil;
+        viewDesc.format = Format::D32_FLOAT;
+        viewDesc.renderTarget.shape = gfx::IResource::Type::Texture2D;
+        viewDesc.subresourceRange.mipLevelCount = 1;
+        return gDevice->createTextureView(tex, viewDesc);
+    }
+    // Creates a shader-resource view of `tex`; all other view settings stay zero-initialized.
+    ComPtr<gfx::IResourceView> createSRV(ITextureResource* tex)
+    {
+        IResourceView::Desc srvDesc = {};
+        srvDesc.type = IResourceView::Type::ShaderResource;
+        return gDevice->createTextureView(tex, srvDesc);
+    }
+    // Creates a graphics pipeline for `program` with the given input layout,
+    // culling disabled, targeting gFramebufferLayout.
+    ComPtr<gfx::IPipelineState> createRenderPipelineState(
+        IInputLayout* inputLayout, IShaderProgram* program)
+    {
+        GraphicsPipelineStateDesc pipelineDesc;
+        pipelineDesc.program = program;
+        pipelineDesc.inputLayout = inputLayout;
+        pipelineDesc.framebufferLayout = gFramebufferLayout;
+        pipelineDesc.rasterizer.cullMode = gfx::CullMode::None;
+        return gDevice->createGraphicsPipelineState(pipelineDesc);
+    }
+    // Creates a compute pipeline for `program`.
+    ComPtr<gfx::IPipelineState> createComputePipelineState(IShaderProgram* program)
+    {
+        ComputePipelineStateDesc pipelineDesc = {};
+        pipelineDesc.program = program;
+        return gDevice->createComputePipelineState(pipelineDesc);
+    }
+    // Creates an unordered-access view of `buffer` with element size 0 and no counter buffer.
+    ComPtr<gfx::IResourceView> createUAV(IBufferResource* buffer)
+    {
+        IResourceView::Desc uavDesc = {};
+        uavDesc.bufferElementSize = 0;
+        uavDesc.type = IResourceView::Type::UnorderedAccess;
+        return gDevice->createBufferView(buffer, nullptr, uavDesc);
+    }
+    // Creates an unordered-access view of mip level `level` of `texture`
+    // (array layer 0, one layer).
+    ComPtr<gfx::IResourceView> createUAV(ITextureResource* texture, int level)
+    {
+        IResourceView::Desc uavDesc = {};
+        uavDesc.type = IResourceView::Type::UnorderedAccess;
+        uavDesc.subresourceRange.mipLevel = level;
+        uavDesc.subresourceRange.baseArrayLayer = 0;
+        uavDesc.subresourceRange.layerCount = 1;
+        return gDevice->createTextureView(texture, uavDesc);
+    }
+ Slang::Result initialize() // Creates the window, pipelines and GPU resources. Returns SLANG_FAIL if the input layout or vertex buffer cannot be created, the failing result of a shader load, or SLANG_OK.
+ {
+ initializeBase("autodiff-texture", 1024, 768);
+ srand(20421); // Fixed seed for the rand() calls in getTransformMatrix.
+
+ gWindow->events.keyPress = [this](platform::KeyEventArgs& e) // 'R' / 'r' requests a reset of the learnt texture; renderFrame performs it.
+ {
+ if (e.keyChar == 'R' || e.keyChar == 'r')
+ resetLearntTexture = true;
+ };
+
+ kClearValue.color.floatValues[0] = 0.3f; // Clear color (0.3, 0.5, 0.7, 1.0).
+ kClearValue.color.floatValues[1] = 0.5f;
+ kClearValue.color.floatValues[2] = 0.7f;
+ kClearValue.color.floatValues[3] = 1.0f;
+
+ auto clientRect = gWindow->getClientRect(); // Use the actual client area as the window size from here on.
+ windowWidth = clientRect.width;
+ windowHeight = clientRect.height;
+
+ InputElementDesc inputElements[] = { // Single POSITION attribute matching struct Vertex.
+ {"POSITION", 0, Format::R32G32B32_FLOAT, offsetof(Vertex, position)}};
+ auto inputLayout = gDevice->createInputLayout(sizeof(Vertex), &inputElements[0], 1);
+ if (!inputLayout)
+ return SLANG_FAIL;
+
+ IBufferResource::Desc vertexBufferDesc; // Vertex buffer initialized with kVertexData.
+ vertexBufferDesc.type = IResource::Type::Buffer;
+ vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex);
+ vertexBufferDesc.defaultState = ResourceState::VertexBuffer;
+ gVertexBuffer = gDevice->createBufferResource(vertexBufferDesc, &kVertexData[0]);
+ if (!gVertexBuffer)
+ return SLANG_FAIL;
+
+ { // Reference rendering pipeline.
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(
+ loadRenderProgram(gDevice, "train", "fragmentMain", shaderProgram.writeRef()));
+ gRefPipelineState = createRenderPipelineState(inputLayout, shaderProgram);
+ }
+ { // Training pipeline: same module, diffFragmentMain as fragment entry point.
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(
+ loadRenderProgram(gDevice, "train", "diffFragmentMain", shaderProgram.writeRef()));
+ gIterPipelineState = createRenderPipelineState(inputLayout, shaderProgram);
+ }
+ { // Pipeline used by drawTexturedQuad.
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(
+ loadRenderProgram(gDevice, "draw-quad", "fragmentMain", shaderProgram.writeRef()));
+ gDrawQuadPipelineState = createRenderPipelineState(inputLayout, shaderProgram);
+ }
+ { // Compute pipelines, one per module.
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(
+ loadComputeProgram(gDevice, "reconstruct", shaderProgram.writeRef()));
+ gReconstructPipelineState = createComputePipelineState(shaderProgram);
+ }
+ {
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(loadComputeProgram(gDevice, "convert", shaderProgram.writeRef()));
+ gConvertPipelineState = createComputePipelineState(shaderProgram);
+ }
+ {
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(loadComputeProgram(gDevice, "buildmip", shaderProgram.writeRef()));
+ gBuildMipPipelineState = createComputePipelineState(shaderProgram);
+ }
+ {
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(loadComputeProgram(gDevice, "learnmip", shaderProgram.writeRef()));
+ gLearnMipPipelineState = createComputePipelineState(shaderProgram);
+ }
+
+ gTexView = createTextureFromFile("checkerboard.jpg", textureWidth, textureHeight); // Presumably fills textureWidth/textureHeight with the image size - confirm against the helper. The result is not checked here.
+ initMipOffsets(textureWidth, textureHeight);
+
+ gfx::IBufferResource::Desc bufferDesc = {}; // Shared description of the accumulate and reconstruct buffers.
+ bufferDesc.allowedStates.add(ResourceState::ShaderResource);
+ bufferDesc.allowedStates.add(ResourceState::UnorderedAccess);
+ bufferDesc.allowedStates.add(ResourceState::General);
+ bufferDesc.sizeInBytes = mipMapOffset.getLast() * sizeof(uint32_t); // Last offset is the total element count over all layers; 4 bytes per element.
+ bufferDesc.type = IResource::Type::Buffer;
+ gAccumulateBuffer = gDevice->createBufferResource(bufferDesc);
+ gReconstructBuffer = gDevice->createBufferResource(bufferDesc);
+
+ gAccumulateBufferView = createUAV(gAccumulateBuffer);
+ gReconstructBufferView = createUAV(gReconstructBuffer);
+
+ int mipCount = 1 + Math::Log2Ceil(Math::Max(textureWidth, textureHeight)); // Same layer count formula as initMipOffsets.
+ gLearningTexture = createRenderTargetTexture(
+ Format::R32G32B32A32_FLOAT,
+ textureWidth,
+ textureHeight,
+ mipCount);
+ gLearningTextureSRV = createSRV(gLearningTexture);
+ for (int i = 0; i < mipCount; i++) // One UAV per mip level.
+ gLearningTextureUAVs.add(createUAV(gLearningTexture, i));
+
+ gDiffTexture = createRenderTargetTexture(
+ Format::R32G32B32A32_FLOAT,
+ textureWidth,
+ textureHeight,
+ mipCount);
+ gDiffTextureSRV = createSRV(gDiffTexture);
+ for (int i = 0; i < mipCount; i++) // One UAV per mip level.
+ gDiffTextureUAVs.add(createUAV(gDiffTexture, i));
+
+ gfx::ISamplerState::Desc samplerDesc = {}; // Default sampler settings.
+ //samplerDesc.maxLOD = 0.0f;
+ gSampler = gDevice->createSamplerState(samplerDesc);
+
+ gDepthTexture = createDepthTexture(); // Must exist before the framebuffers below, which use gDepthTextureView.
+ gDepthTextureView = createDSV(gDepthTexture);
+
+ gRefImage = createRenderTargetTexture(Format::R8G8B8A8_UNORM, windowWidth, windowHeight, 1);
+ gRefImageRTV = createRTV(gRefImage, Format::R8G8B8A8_UNORM);
+ gRefImageSRV = createSRV(gRefImage);
+
+ gIterImage = createRenderTargetTexture(Format::R8G8B8A8_UNORM, windowWidth, windowHeight, 1);
+ gIterImageRTV = createRTV(gIterImage, Format::R8G8B8A8_UNORM);
+ gIterImageSRV = createSRV(gIterImage);
+
+ gRefFrameBuffer = createRenderTargetFramebuffer(gRefImageRTV);
+ gIterFrameBuffer = createRenderTargetFramebuffer(gIterImageRTV);
+
+ { // One-time command buffer: move the textures out of their default render-target state and zero both mip chains.
+ ComPtr<ICommandBuffer> commandBuffer = gTransientHeaps[0]->createCommandBuffer();
+ auto encoder = commandBuffer->encodeResourceCommands();
+ encoder->textureBarrier(gLearningTexture, ResourceState::RenderTarget, ResourceState::UnorderedAccess);
+ encoder->textureBarrier(gDiffTexture, ResourceState::RenderTarget, ResourceState::UnorderedAccess);
+ encoder->textureBarrier(gRefImage, ResourceState::RenderTarget, ResourceState::ShaderResource);
+ encoder->textureBarrier(gIterImage, ResourceState::RenderTarget, ResourceState::ShaderResource);
+ for (int i = 0; i < gLearningTextureUAVs.getCount(); i++) // Both UAV lists hold mipCount entries.
+ {
+ ClearValue clearValue = {};
+ encoder->clearResourceView(gLearningTextureUAVs[i], &clearValue, ClearResourceViewFlags::None);
+ encoder->clearResourceView(gDiffTextureUAVs[i], &clearValue, ClearResourceViewFlags::None);
+ }
+ encoder->textureBarrier(gLearningTexture, ResourceState::UnorderedAccess, ResourceState::ShaderResource); // renderFrame expects the learning texture in the shader-resource state; gDiffTexture stays in unordered access.
+
+ encoder->endEncoding();
+ commandBuffer->close();
+ gQueue->executeCommandBuffer(commandBuffer);
+ }
+
+ return SLANG_OK;
+ }
+
+    // Appends to mipMapOffset the start offset of every mip layer of a w x h
+    // texture, followed by one final entry holding the total size. Each texel
+    // occupies four elements and every layer dimension is at least 1.
+    void initMipOffsets(int w, int h)
+    {
+        const int layerCount = 1 + Math::Log2Ceil(Math::Max(w, h));
+
+        uint32_t runningOffset = 0;
+        for (int layer = 0; layer < layerCount; layer++)
+        {
+            mipMapOffset.add(runningOffset);
+
+            auto layerW = Math::Max(1, w >> layer);
+            auto layerH = Math::Max(1, h >> layer);
+            runningOffset += layerW * layerH * 4;
+        }
+
+        // Sentinel: one past the last layer, i.e. the total element count.
+        mipMapOffset.add(runningOffset);
+    }
+
+    // Builds a randomly perturbed transform for the quad: a perspective projection
+    // (60 degree field of view, window aspect ratio) applied to a random translation
+    // followed by random rotations about the X and Y axes. The result is transposed
+    // before being returned.
+    //
+    // Consumes exactly four values from rand(), in this order: X rotation,
+    // Y rotation, X translation, Y translation.
+    glm::mat4x4 getTransformMatrix()
+    {
+        // Each rand() call gets its own statement. Calling rand() inside a function
+        // argument list leaves the call order up to the compiler, which would make the
+        // seeded random sequence differ between toolchains.
+        const float rotX = (rand() / (float)RAND_MAX) * 0.3f;
+        const float rotY = (rand() / (float)RAND_MAX) * 0.2f;
+        const float offsetX = -0.6f + 0.2f * (rand() / (float)RAND_MAX);
+        const float offsetY = -0.6f + 0.2f * (rand() / (float)RAND_MAX);
+
+        glm::mat4x4 matProj = glm::perspectiveRH_ZO(
+            glm::radians(60.0f), (float)windowWidth / (float)windowHeight, 0.1f, 1000.0f);
+
+        auto identity = glm::mat4(1.0f);
+        auto translate = glm::translate(identity, glm::vec3(offsetX, offsetY, -1.0f));
+        auto rot = glm::rotate(translate, -glm::pi<float>() * rotX, glm::vec3(1.0f, 0.0f, 0.0f));
+        rot = glm::rotate(rot, -glm::pi<float>() * rotY, glm::vec3(0.0f, 1.0f, 0.0f));
+
+        auto transformMatrix = matProj * rot;
+        transformMatrix = glm::transpose(transformMatrix);
+        return transformMatrix;
+    }
+
+    // Records and submits one command buffer that draws the 4-vertex quad into `fb`
+    // with a window-sized viewport. `setupPipeline` is called with the render encoder
+    // after the viewport is set and before the vertex buffer is bound; it is expected
+    // to bind a pipeline and its parameters.
+    template <typename SetupPipelineFunc>
+    void renderImage(
+        int transientHeapIndex, IFramebuffer* fb, const SetupPipelineFunc& setupPipeline)
+    {
+        ComPtr<ICommandBuffer> cmdBuffer =
+            gTransientHeaps[transientHeapIndex]->createCommandBuffer();
+        auto encoder = cmdBuffer->encodeRenderCommands(gRenderPass, fb);
+
+        // Full-window viewport and scissor.
+        gfx::Viewport fullWindow = {};
+        fullWindow.extentX = (float)windowWidth;
+        fullWindow.extentY = (float)windowHeight;
+        fullWindow.maxZ = 1.0f;
+        encoder->setViewportAndScissor(fullWindow);
+
+        setupPipeline(encoder);
+
+        encoder->setVertexBuffer(0, gVertexBuffer);
+        encoder->setPrimitiveTopology(PrimitiveTopology::TriangleStrip);
+        encoder->draw(4);
+
+        encoder->endEncoding();
+        cmdBuffer->close();
+        gQueue->executeCommandBuffer(cmdBuffer);
+    }
+
+    // Renders the reference image into gRefFrameBuffer using gRefPipelineState, the
+    // texture loaded from file (gTexView) and the given transform.
+    void renderReferenceImage(int transientHeapIndex, glm::mat4x4 transformMatrix)
+    {
+        // First submission: move gRefImage from shader-resource to render-target state.
+        {
+            ComPtr<ICommandBuffer> barrierCommands =
+                gTransientHeaps[transientHeapIndex]->createCommandBuffer();
+            auto resourceEncoder = barrierCommands->encodeResourceCommands();
+            resourceEncoder->textureBarrier(
+                gRefImage, ResourceState::ShaderResource, ResourceState::RenderTarget);
+            resourceEncoder->endEncoding();
+            barrierCommands->close();
+            gQueue->executeCommandBuffer(barrierCommands);
+        }
+
+        // Second submission: draw the quad with the reference pipeline.
+        auto bindReferencePipeline = [&](IRenderCommandEncoder* encoder)
+        {
+            auto rootObject = encoder->bindPipeline(gRefPipelineState);
+            ShaderCursor rootCursor(rootObject);
+            rootCursor["Uniforms"]["modelViewProjection"].setData(
+                &transformMatrix, sizeof(float) * 16);
+            rootCursor["Uniforms"]["bwdTexture"]["texture"].setResource(gTexView);
+            rootCursor["Uniforms"]["sampler"].setSampler(gSampler);
+            rootCursor["Uniforms"]["mipOffset"].setData(
+                mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount());
+            rootCursor["Uniforms"]["texRef"].setResource(gTexView);
+            rootCursor["Uniforms"]["bwdTexture"]["accumulateBuffer"].setResource(gAccumulateBufferView);
+        };
+        renderImage(transientHeapIndex, gRefFrameBuffer, bindReferencePipeline);
+    }
+
+    // Runs one training iteration and presents the result:
+    //   1. render the reference image with a fresh random transform;
+    //   2. clear the accumulate/reconstruct buffers (and the learnt texture on request);
+    //   3. render with the differentiable fragment shader, binding the accumulate buffer;
+    //   4. reconstruct per-layer values from the coarsest layer down, convert layer 0 to
+    //      a texture, rebuild its mip chain, and apply the update to the learnt texture;
+    //   5. draw the learnt texture, the reference image and the current image, then present.
+    virtual void renderFrame(int frameBufferIndex) override
+    {
+        auto transformMatrix = getTransformMatrix();
+        renderReferenceImage(frameBufferIndex, transformMatrix);
+
+        // Number of mip layers; mipMapOffset holds one extra trailing entry.
+        // Stored as a 32-bit value so it can be uploaded into the shader's `uint layerCount`.
+        const uint32_t layerCount = static_cast<uint32_t>(mipMapOffset.getCount() - 1);
+
+        // Barriers.
+        {
+            ComPtr<ICommandBuffer> commandBuffer =
+                gTransientHeaps[frameBufferIndex]->createCommandBuffer();
+            auto resEncoder = commandBuffer->encodeResourceCommands();
+            ClearValue clearValue = {};
+            resEncoder->bufferBarrier(gAccumulateBuffer, ResourceState::Undefined, ResourceState::UnorderedAccess);
+            resEncoder->bufferBarrier(gReconstructBuffer, ResourceState::Undefined, ResourceState::UnorderedAccess);
+            resEncoder->textureBarrier(gRefImage, ResourceState::Present, ResourceState::ShaderResource);
+            resEncoder->textureBarrier(gIterImage, ResourceState::ShaderResource, ResourceState::RenderTarget);
+            resEncoder->clearResourceView(gAccumulateBufferView, &clearValue, ClearResourceViewFlags::None);
+            resEncoder->clearResourceView(gReconstructBufferView, &clearValue, ClearResourceViewFlags::None);
+            if (resetLearntTexture)
+            {
+                // Requested via the 'R' key: zero every mip level of the learnt texture.
+                resEncoder->textureBarrier(gLearningTexture, ResourceState::ShaderResource, ResourceState::UnorderedAccess);
+                for (Index i = 0; i < gLearningTextureUAVs.getCount(); i++)
+                    resEncoder->clearResourceView(gLearningTextureUAVs[i], &clearValue, ClearResourceViewFlags::None);
+                resEncoder->textureBarrier(gLearningTexture, ResourceState::UnorderedAccess, ResourceState::ShaderResource);
+                resetLearntTexture = false;
+            }
+            resEncoder->endEncoding();
+            commandBuffer->close();
+            gQueue->executeCommandBuffer(commandBuffer);
+        }
+
+        // Render image using backward propagate shader to obtain texture-space gradients.
+        renderImage(
+            frameBufferIndex,
+            gIterFrameBuffer,
+            [&](IRenderCommandEncoder* encoder)
+            {
+                auto rootObject = encoder->bindPipeline(gIterPipelineState);
+                ShaderCursor rootCursor(rootObject);
+
+                rootCursor["Uniforms"]["modelViewProjection"].setData(
+                    &transformMatrix, sizeof(float) * 16);
+                rootCursor["Uniforms"]["bwdTexture"]["texture"].setResource(gLearningTextureSRV);
+                rootCursor["Uniforms"]["sampler"].setSampler(gSampler);
+                rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount());
+                rootCursor["Uniforms"]["texRef"].setResource(gRefImageSRV);
+                rootCursor["Uniforms"]["bwdTexture"]["accumulateBuffer"].setResource(gAccumulateBufferView);
+                // `minLOD` is declared `float` in the shader's BwdTexture, so the value must
+                // be a float literal: a double literal would upload a value of the wrong
+                // type and size.
+                rootCursor["Uniforms"]["bwdTexture"]["minLOD"].setData(5.0f);
+            });
+
+        // Propagate gradients through mip map layers from top (lowest res) to bottom (highest res).
+        {
+            ComPtr<ICommandBuffer> commandBuffer =
+                gTransientHeaps[frameBufferIndex]->createCommandBuffer();
+            auto encoder = commandBuffer->encodeComputeCommands();
+            encoder->textureBarrier(gLearningTexture, ResourceState::ShaderResource, ResourceState::UnorderedAccess);
+            auto rootObject = encoder->bindPipeline(gReconstructPipelineState);
+            for (int i = static_cast<int>(layerCount) - 1; i >= 0; i--)
+            {
+                ShaderCursor rootCursor(rootObject);
+                rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount());
+                rootCursor["Uniforms"]["dstLayer"].setData(i);
+                // Upload exactly 32 bits so the write cannot spill into the field that
+                // follows `layerCount` in the uniform block.
+                rootCursor["Uniforms"]["layerCount"].setData(layerCount);
+                rootCursor["Uniforms"]["width"].setData(textureWidth);
+                rootCursor["Uniforms"]["height"].setData(textureHeight);
+                rootCursor["Uniforms"]["accumulateBuffer"].setResource(gAccumulateBufferView);
+                rootCursor["Uniforms"]["dstBuffer"].setResource(gReconstructBufferView);
+                encoder->dispatchCompute(
+                    ((textureWidth >> i) + 15) / 16, ((textureHeight >> i) + 15) / 16, 1);
+                // Each layer reads the layer written by the previous dispatch.
+                encoder->bufferBarrier(gReconstructBuffer, ResourceState::UnorderedAccess, ResourceState::UnorderedAccess);
+            }
+
+            // Convert bottom layer mip from buffer to texture.
+            rootObject = encoder->bindPipeline(gConvertPipelineState);
+            ShaderCursor rootCursor(rootObject);
+            rootCursor["Uniforms"]["mipOffset"].setData(mipMapOffset.getBuffer(), sizeof(uint32_t) * mipMapOffset.getCount());
+            rootCursor["Uniforms"]["dstLayer"].setData(0);
+            rootCursor["Uniforms"]["width"].setData(textureWidth);
+            rootCursor["Uniforms"]["height"].setData(textureHeight);
+            rootCursor["Uniforms"]["srcBuffer"].setResource(gReconstructBufferView);
+            rootCursor["Uniforms"]["dstTexture"].setResource(gDiffTextureUAVs[0]);
+            encoder->dispatchCompute(
+                (textureWidth + 15) / 16, (textureHeight + 15) / 16, 1);
+            encoder->textureBarrier(gDiffTexture, ResourceState::UnorderedAccess, ResourceState::UnorderedAccess);
+
+            // Build higher level mip map layers.
+            rootObject = encoder->bindPipeline(gBuildMipPipelineState);
+            for (int i = 1; i < static_cast<int>(layerCount); i++)
+            {
+                ShaderCursor rootCursor(rootObject);
+                rootCursor["Uniforms"]["dstWidth"].setData(textureWidth >> i);
+                rootCursor["Uniforms"]["dstHeight"].setData(textureHeight >> i);
+                rootCursor["Uniforms"]["srcTexture"].setResource(gDiffTextureUAVs[i - 1]);
+                rootCursor["Uniforms"]["dstTexture"].setResource(gDiffTextureUAVs[i]);
+                encoder->dispatchCompute(
+                    ((textureWidth >> i) + 15) / 16, ((textureHeight >> i) + 15) / 16, 1);
+                // Level i + 1 reads level i.
+                encoder->textureBarrier(gDiffTexture, ResourceState::UnorderedAccess, ResourceState::UnorderedAccess);
+            }
+
+            // Accumulate gradients to learnt texture.
+            rootObject = encoder->bindPipeline(gLearnMipPipelineState);
+            for (int i = 0; i < static_cast<int>(layerCount); i++)
+            {
+                ShaderCursor rootCursor(rootObject);
+                rootCursor["Uniforms"]["dstWidth"].setData(textureWidth >> i);
+                rootCursor["Uniforms"]["dstHeight"].setData(textureHeight >> i);
+                rootCursor["Uniforms"]["learningRate"].setData(0.1f);
+                rootCursor["Uniforms"]["srcTexture"].setResource(gDiffTextureUAVs[i]);
+                rootCursor["Uniforms"]["dstTexture"].setResource(gLearningTextureUAVs[i]);
+                encoder->dispatchCompute(
+                    ((textureWidth >> i) + 15) / 16, ((textureHeight >> i) + 15) / 16, 1);
+            }
+            encoder->textureBarrier(gLearningTexture, ResourceState::UnorderedAccess, ResourceState::ShaderResource);
+            encoder->textureBarrier(gIterImage, ResourceState::Present, ResourceState::ShaderResource);
+
+            encoder->endEncoding();
+            commandBuffer->close();
+            gQueue->executeCommandBuffer(commandBuffer);
+        }
+
+        // Draw currently learnt texture.
+        {
+            ComPtr<ICommandBuffer> commandBuffer =
+                gTransientHeaps[frameBufferIndex]->createCommandBuffer();
+            auto renderEncoder = commandBuffer->encodeRenderCommands(gRenderPass, gFramebuffers[frameBufferIndex]);
+            drawTexturedQuad(renderEncoder, 0, 0, textureWidth, textureHeight, gLearningTextureSRV);
+            // Reference and current images go to the right of the learnt texture,
+            // stacked vertically, scaled to keep the window aspect ratio.
+            int refImageWidth = windowWidth - textureWidth - 10;
+            int refImageHeight = refImageWidth * windowHeight / windowWidth;
+            drawTexturedQuad(renderEncoder, textureWidth + 10, 0, refImageWidth, refImageHeight, gRefImageSRV);
+            drawTexturedQuad(renderEncoder, textureWidth + 10, refImageHeight + 10, refImageWidth, refImageHeight, gIterImageSRV);
+            renderEncoder->endEncoding();
+            commandBuffer->close();
+            gQueue->executeCommandBuffer(commandBuffer);
+        }
+
+        gSwapchain->present();
+    }
+
+    // Draws `srv` as a textured quad with the draw-quad pipeline. (x, y) and (w, h) are
+    // forwarded to the shader as the quad's offset and size, together with the window
+    // size as the view size.
+    void drawTexturedQuad(IRenderCommandEncoder* renderEncoder, int x, int y, int w, int h, IResourceView* srv)
+    {
+        // Full-window viewport and scissor.
+        gfx::Viewport fullWindow = {};
+        fullWindow.extentX = (float)windowWidth;
+        fullWindow.extentY = (float)windowHeight;
+        fullWindow.maxZ = 1.0f;
+        renderEncoder->setViewportAndScissor(fullWindow);
+
+        auto rootObject = renderEncoder->bindPipeline(gDrawQuadPipelineState);
+        ShaderCursor cursor(rootObject);
+
+        // Quad placement.
+        cursor["Uniforms"]["x"].setData(x);
+        cursor["Uniforms"]["y"].setData(y);
+        cursor["Uniforms"]["width"].setData(w);
+        cursor["Uniforms"]["height"].setData(h);
+        cursor["Uniforms"]["viewWidth"].setData(windowWidth);
+        cursor["Uniforms"]["viewHeight"].setData(windowHeight);
+
+        // Texture and sampler.
+        cursor["Uniforms"]["texture"].setResource(srv);
+        cursor["Uniforms"]["sampler"].setSampler(gSampler);
+
+        renderEncoder->setVertexBuffer(0, gVertexBuffer);
+        renderEncoder->setPrimitiveTopology(PrimitiveTopology::TriangleStrip);
+        renderEncoder->draw(4);
+    }
+
+};
+
+PLATFORM_UI_MAIN(innerMain<AutoDiffTexture>)
diff --git a/examples/autodiff-texture/reconstruct.slang b/examples/autodiff-texture/reconstruct.slang
new file mode 100644
index 000000000..c123010e5
--- /dev/null
+++ b/examples/autodiff-texture/reconstruct.slang
@@ -0,0 +1,48 @@
+// A compute shader to propagate gradients from a higher mip level (low-res) to the next lower mip level (high-res).
+
+// Parameters for one dispatch of the gradient reconstruction pass.
+cbuffer Uniforms
+{
+ // Base element offset of each mip layer inside the buffers below, packed four per
+ // uint4: layer `i` lives at mipOffset[i / 4][i % 4].
+ uint4 mipOffset[16];
+ // Mip layer written by this dispatch.
+ uint dstLayer;
+ // Total number of mip layers; the last layer (layerCount - 1) has no parent to read.
+ uint layerCount;
+ // Dimensions that are shifted right by `dstLayer` to obtain the layer's size.
+ uint width;
+ uint height;
+ // Four ints per texel: xyz are sums divided by 65536 on read, w is the divisor
+ // (used as a contribution count).
+ RWStructuredBuffer<int> accumulateBuffer;
+ // Four floats per texel: xyz receive the reconstructed gradient, w is set to 1.0.
+ RWStructuredBuffer<float> dstBuffer;
+}
+
+// Entry point: one thread per texel of mip layer `dstLayer`.
+//
+// The texel's gradient is the average of what was accumulated for it (fixed-point sum
+// divided by count * 65536) plus a quarter of the value already stored in `dstBuffer`
+// for the texel's parent in layer `dstLayer + 1`.
+// NOTE(review): because the parent's `dstBuffer` entry is read here, the host
+// presumably dispatches the coarsest layer first - confirm in main.cpp.
+[shader("compute")]
+[numthreads(16, 16, 1)]
+void computeMain(uint3 threadIdx : SV_DispatchThreadID)
+{
+    uint2 texel = threadIdx.xy;
+    uint layerW = width >> dstLayer;
+    uint layerH = height >> dstLayer;
+    // Threads outside the layer have nothing to do.
+    if (texel.x >= layerW || texel.y >= layerH)
+        return;
+
+    uint texelBase = mipOffset[dstLayer / 4][dstLayer % 4] + (texel.y * layerW + texel.x) * 4;
+
+    // Average the accumulated fixed-point gradient; texels nobody wrote to stay at zero.
+    int3 gradSum = int3(accumulateBuffer[texelBase], accumulateBuffer[texelBase + 1], accumulateBuffer[texelBase + 2]);
+    int contributionCount = accumulateBuffer[texelBase + 3];
+    float3 ownGrad = float3(0.0);
+    if (contributionCount > 0)
+        ownGrad = (float3)gradSum * float3(1.0 / (contributionCount * 65536.0));
+
+    // Every layer except the last one inherits a quarter of its parent texel's value.
+    float3 inheritedGrad = float3(0.0);
+    if (dstLayer < layerCount - 1)
+    {
+        uint parentLayer = dstLayer + 1;
+        uint parentBase = mipOffset[parentLayer / 4][parentLayer % 4];
+        uint parentW = layerW / 2;
+        uint parentTexel = parentBase + ((texel.y / 2) * parentW + (texel.x / 2)) * 4;
+        inheritedGrad = float3(dstBuffer[parentTexel], dstBuffer[parentTexel + 1], dstBuffer[parentTexel + 2]) * 0.25;
+    }
+
+    float3 combinedGrad = inheritedGrad + ownGrad;
+    dstBuffer[texelBase] = combinedGrad.x;
+    dstBuffer[texelBase + 1] = combinedGrad.y;
+    dstBuffer[texelBase + 2] = combinedGrad.z;
+    dstBuffer[texelBase + 3] = 1.0;
+}
diff --git a/examples/autodiff-texture/train.slang b/examples/autodiff-texture/train.slang
new file mode 100644
index 000000000..16f3be35c
--- /dev/null
+++ b/examples/autodiff-texture/train.slang
@@ -0,0 +1,216 @@
+// train.slang
+
+// The 2x2 texel footprint of a bilinear lookup on one mip layer.
+// Coordinates are whole numbers stored as floats, already wrapped into the layer.
+struct BilinearTaps
+{
+    float x0;
+    float y0;
+    float x1;
+    float y1;
+    float weight00; // weight of texel (x0, y0)
+    float weight01; // weight of texel (x1, y0)
+    float weight10; // weight of texel (x0, y1)
+    float weight11; // weight of texel (x1, y1)
+}
+
+// The two adjacent mip levels picked for a trilinear lookup and their blend weights.
+struct MipBlend
+{
+    uint lod0;
+    uint lod1;
+    float weightLod0;
+    float weightLod1;
+}
+
+// A texture whose `sample` operation has a custom backward derivative: gradients
+// arriving at a sample are scattered into `accumulateBuffer` at the texels (and mip
+// levels) a software trilinear filter would have read.
+struct BwdTexture
+{
+    // Four ints per texel: xyz hold the gradient sum scaled by 65536, w counts writes.
+    RWStructuredBuffer<int> accumulateBuffer;
+    // The texture that is sampled in the forward pass.
+    Texture2D texture;
+    // Lower clamp applied to the computed level of detail.
+    float minLOD;
+
+    // Atomically adds `val` (converted to fixed point with scale 65536) to texel
+    // (x, y) of the layer that starts at `offset` and is `layerW` texels wide, and
+    // bumps the texel's write counter by one. `layerH` is not used.
+    void writeDiffToTexel(uint offset, uint layerW, uint layerH, float x, float y, float3 val)
+    {
+        int4 uval = int4(int3(val * 65536), 1);
+        uint texelBase = offset + ((uint)y * layerW + (uint)x) * 4;
+        InterlockedAdd(accumulateBuffer[texelBase + 0], uval.x);
+        InterlockedAdd(accumulateBuffer[texelBase + 1], uval.y);
+        InterlockedAdd(accumulateBuffer[texelBase + 2], uval.z);
+        InterlockedAdd(accumulateBuffer[texelBase + 3], uval.w);
+    }
+
+    // Computes the bilinear footprint of `uv` on a layer of `layerW` x `layerH` texels.
+    // `uv` is wrapped into [0, 1) and taps that fall off an edge wrap to the opposite
+    // side. Shared by the forward lookup and the gradient scatter so both always
+    // address the same texels with the same weights.
+    BilinearTaps computeBilinearTaps(float2 uv, uint layerW, uint layerH)
+    {
+        uv = uv - floor(uv);
+        float2 loc = uv * float2(layerW, layerH) - float2(0.5);
+
+        BilinearTaps taps;
+        taps.x0 = floor(loc.x);
+        taps.y0 = floor(loc.y);
+        // Fractions and the +1 neighbours are derived before wrapping.
+        float fracX = loc.x - taps.x0;
+        float fracY = loc.y - taps.y0;
+        taps.x1 = taps.x0 + 1;
+        taps.y1 = taps.y0 + 1;
+        if (taps.x0 < 0) taps.x0 += layerW;
+        if (taps.y0 < 0) taps.y0 += layerH;
+        if (taps.x1 >= layerW) taps.x1 -= layerW;
+        if (taps.y1 >= layerH) taps.y1 -= layerH;
+
+        float weight0 = 1.0f - fracY;
+        float weight1 = fracY;
+        taps.weight00 = weight0 * (1.0f - fracX);
+        taps.weight01 = weight0 * fracX;
+        taps.weight10 = weight1 * (1.0f - fracX);
+        taps.weight11 = weight1 * fracX;
+        return taps;
+    }
+
+    // Picks the two mip levels and blend weights for a lookup whose UV derivatives are
+    // `dX` / `dY`, on a texture of `w` x `h` texels with `levels` mip levels.
+    // The level of detail is log2 of the longer derivative in texel units, clamped to
+    // [minLOD, levels - 1].
+    MipBlend selectMipBlend(uint w, uint h, uint levels, float2 dX, float2 dY)
+    {
+        dX = dX * float2(w, h);
+        dY = dY * float2(w, h);
+
+        // Isotropic filter.
+        float lengthX = length(dX);
+        float lengthY = length(dY);
+        float LOD = log2(max(lengthX, lengthY));
+        float maxLOD = levels - 1;
+        float clampedLOD = max(minLOD, (min(maxLOD, LOD)));
+
+        float lodFrac = clampedLOD - floor(clampedLOD);
+        MipBlend blend;
+        blend.lod0 = (uint)floor(clampedLOD);
+        blend.lod1 = min(levels - 1, blend.lod0 + 1);
+        blend.weightLod0 = 1.0 - lodFrac;
+        blend.weightLod1 = lodFrac;
+        return blend;
+    }
+
+    // Scatters `diff` into mip level `lod`, split across the four texels of the
+    // bilinear footprint of `uv`. `w` and `h` are the level-0 dimensions.
+    void broadcastDiffToLayer(uint lod, float3 diff, float2 uv, uint w, uint h)
+    {
+        var offset = mipOffset[lod / 4][lod % 4];
+        w >>= lod;
+        h >>= lod;
+        BilinearTaps taps = computeBilinearTaps(uv, w, h);
+
+        writeDiffToTexel(offset, w, h, taps.x0, taps.y0, taps.weight00 * diff);
+        writeDiffToTexel(offset, w, h, taps.x1, taps.y0, taps.weight01 * diff);
+        writeDiffToTexel(offset, w, h, taps.x0, taps.y1, taps.weight10 * diff);
+        writeDiffToTexel(offset, w, h, taps.x1, taps.y1, taps.weight11 * diff);
+    }
+
+    // Backward counterpart of a trilinear lookup: scatters the xyz of `dOut` into the
+    // two mip levels selected for `uv`.
+    void sampleTexture_trilinear_bwd(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY, float4 dOut)
+    {
+        MipBlend blend = selectMipBlend(w, h, levels, dX, dY);
+        // NOTE(review): the finer level always receives the full gradient instead of
+        // (1 - lodFrac), unlike the forward weights in sampleTexture_trilinear. This is
+        // kept exactly as it was; confirm whether it is intentional.
+        blend.weightLod0 = 1.0;
+        broadcastDiffToLayer(blend.lod0, dOut.xyz * blend.weightLod0, uv, w, h);
+        broadcastDiffToLayer(blend.lod1, dOut.xyz * blend.weightLod1, uv, w, h);
+    }
+
+    // Software bilinear lookup of mip level `lod` at `uv` using explicit texel loads.
+    // `w` and `h` are the level-0 dimensions.
+    float4 sampleTexture_linear(uint lod, float2 uv, uint w, uint h)
+    {
+        w >>= lod;
+        h >>= lod;
+        BilinearTaps taps = computeBilinearTaps(uv, w, h);
+
+        return texture.Load(int3(int(taps.x0), int(taps.y0), int(lod)), int2(0)) * taps.weight00 +
+               texture.Load(int3(int(taps.x1), int(taps.y0), int(lod)), int2(0)) * taps.weight01 +
+               texture.Load(int3(int(taps.x0), int(taps.y1), int(lod)), int2(0)) * taps.weight10 +
+               texture.Load(int3(int(taps.x1), int(taps.y1), int(lod)), int2(0)) * taps.weight11;
+    }
+
+    // Software trilinear lookup: blends the bilinear results of the two selected mip
+    // levels with weights (1 - lodFrac) and lodFrac.
+    float4 sampleTexture_trilinear(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY)
+    {
+        MipBlend blend = selectMipBlend(w, h, levels, dX, dY);
+
+        let v0 = sampleTexture_linear(blend.lod0, uv, w, h) * blend.weightLod0;
+        let v1 = sampleTexture_linear(blend.lod1, uv, w, h) * blend.weightLod1;
+        return v0 + v1;
+    }
+
+    // Forward pass: samples `t.texture` through the sampler `s`.
+    static float4 sample(BwdTexture t, SamplerState s, no_diff float2 uv)
+    {
+        return t.texture.Sample(s, uv);
+    }
+
+    // Forward-mode derivative registered for `sample`; it always yields zero.
+    [ForwardDerivativeOf(BwdTexture.sample)]
+    static float4 fwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv)
+    {
+        return float4(0.0);
+    }
+
+    // Backward derivative registered for `sample`: takes the coarse screen-space
+    // derivatives of `uv`, queries the texture size and mip count, and scatters
+    // `dOut` into the accumulation buffer.
+    [BackwardDerivativeOf(BwdTexture.sample)]
+    static void bwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv, float4 dOut)
+    {
+        float2 dX = ddx_coarse(uv);
+        float2 dY = ddy_coarse(uv);
+        uint w;
+        uint h;
+        uint levels;
+        t.texture.GetDimensions(0, w, h, levels);
+        t.sampleTexture_trilinear_bwd(w, h, levels, uv, dX, dY, dOut);
+    }
+}
+
+// Shader parameters shared by the vertex and fragment entry points in this file.
+cbuffer Uniforms
+{
+ // Transform applied to the vertex position in vertexMain.
+ float4x4 modelViewProjection;
+ // Base offset of each mip level inside BwdTexture.accumulateBuffer, packed four per
+ // uint4: level `lod` lives at mipOffset[lod / 4][lod % 4].
+ uint4 mipOffset[16];
+
+ // Reference image; loss() loads it at the fragment's screen position.
+ Texture2D texRef;
+ // Sampler passed to BwdTexture.sample.
+ SamplerState sampler;
+ // The texture that is sampled by shadeFragment and receives gradients.
+ BwdTexture bwdTexture;
+}
+
+// Per-vertex input consumed by vertexMain.
+struct AssembledVertex
+{
+ // Vertex position; its xy components are also used as the texture coordinate.
+ float3 position : POSITION;
+};
+
+// Fragment color wrapper. Not referenced by any entry point in train.slang.
+struct Fragment
+{
+ float4 color;
+};
+
+// Output of vertexMain.
+struct VertexStageOutput
+{
+ // Texture coordinate, taken from the vertex position's xy.
+ float2 uv : UV;
+ // Vertex position after the modelViewProjection transform.
+ float4 sv_position : SV_Position;
+};
+
+// Shades one fragment: samples `bwdTexture` at twice the incoming UV and returns the
+// sampled RGB with alpha forced to 1. Differentiable through BwdTexture.sample.
+[BackwardDifferentiable]
+float4 shadeFragment(no_diff float2 uv)
+{
+    float4 sampled = BwdTexture.sample(bwdTexture, sampler, uv * 2);
+    return float4(sampled.xyz, 1.0);
+}
+
+// Per-channel squared error between the shaded fragment and the reference image.
+// The reference texel is loaded at the integer screen position and is excluded from
+// differentiation.
+[BackwardDifferentiable]
+float3 loss(no_diff float2 uv, no_diff float4 screenPos)
+{
+    float3 refColor = (no_diff texRef.Load(int3(int2(screenPos.xy), 0))).xyz;
+    float3 residual = shadeFragment(uv).xyz - refColor;
+    return residual * residual;
+}
+
+// Vertex entry point: transforms the position by `modelViewProjection` and forwards
+// the position's xy as the texture coordinate.
+[shader("vertex")]
+VertexStageOutput vertexMain(
+    AssembledVertex assembledVertex)
+{
+    VertexStageOutput result;
+    result.uv = assembledVertex.position.xy;
+    result.sv_position = mul(modelViewProjection, float4(assembledVertex.position, 1.0));
+    return result;
+}
+
+// Component-wise square of `v`.
+float3 sqr(float3 v)
+{
+    return v * v;
+}
+
+// Fragment entry point that outputs the shaded color only; no loss or gradient
+// computation happens here.
+[shader("fragment")]
+float4 fragmentMain(
+ float2 uv : UV) : SV_Target
+{
+ return shadeFragment(uv);
+}
+
+// Fragment entry point for training: runs the backward pass of `loss` with an output
+// gradient of 1 per channel (which scatters texture gradients through
+// BwdTexture.bwd_sample), then outputs the loss value itself with alpha 1.
+[shader("fragment")]
+float4 diffFragmentMain(
+    float2 uv : UV,
+    float4 screenPos : SV_POSITION) : SV_Target
+{
+    float3 lossSeed = float3(1.0);
+    __bwd_diff(loss)(uv, screenPos, lossSeed);
+
+    float3 lossValue = loss(uv, screenPos);
+    return float4(lossValue, 1.0);
+}