summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/dynamic-dispatch-material.slang
diff options
context:
space:
mode:
Diffstat (limited to 'tests/autodiff/dynamic-dispatch-material.slang')
-rw-r--r--tests/autodiff/dynamic-dispatch-material.slang142
1 files changed, 142 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-dispatch-material.slang b/tests/autodiff/dynamic-dispatch-material.slang
new file mode 100644
index 000000000..1185a92e7
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-material.slang
@@ -0,0 +1,142 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+//TEST_INPUT: set g_materials = new StructuredBuffer<MaterialDataBlob>[new MaterialDataBlob{new MaterialHeader{[0, 0, 0, 0]}, new MaterialPayload{[1.0, 1.2, 0.3, 0.5]}}];
+RWStructuredBuffer<MaterialDataBlob> g_materials;
+
+public struct ShadingInput
+{
+ public float scale;
+}
+
+struct MaterialHeader
+{
+ uint4 header;
+};
+struct MaterialPayload
+{
+ float4 data;
+};
+struct MaterialDataBlob
+{
+ MaterialHeader header; // 16B
+ MaterialPayload payload; // 16B
+};
+
+interface IMaterial : IDifferentiable
+{
+ associatedtype MaterialInstance : IMaterialInstance;
+
+ [Differentiable]
+ MaterialInstance setupMaterialInstance( ShadingInput input );
+}
+
+interface IMaterialInstance : IDifferentiable
+{
+ [Differentiable]
+ float eval( float x );
+}
+
+
+[BackwardDerivative(getMaterial_bwd)]
+IMaterial getMaterial(int id)
+{
+ return createDynamicObject<IMaterial, MaterialDataBlob>(id, g_materials[id]);
+}
+
+void getMaterial_bwd(int id, IDifferentiable d)
+{
+ // Something random
+ outputBuffer[id] = 2.f;
+}
+
+struct Material1: IMaterial
+{
+ typedef MaterialInstance1 MaterialInstance;
+
+ MaterialHeader header;
+ float a;
+ float b;
+ float c;
+
+ [Differentiable]
+ MaterialInstance1 setupMaterialInstance( ShadingInput input )
+ {
+ MaterialInstance1 instance;
+ instance.a = a * input.scale;
+ instance.b = b * input.scale;
+ instance.c = c * input.scale;
+ return instance;
+ }
+
+}
+struct MaterialInstance1: IMaterialInstance
+{
+ float a;
+ float b;
+ float c;
+
+ [Differentiable]
+ float eval( float x )
+ {
+ return a * x * x + b * x + c;
+ }
+}
+
+struct Material2: IMaterial
+{
+ typedef MaterialInstance2 MaterialInstance;
+
+ MaterialHeader header;
+ float a;
+ float b;
+
+ [Differentiable]
+ MaterialInstance2 setupMaterialInstance( ShadingInput input )
+ {
+ MaterialInstance2 instance;
+ instance.a = a * input.scale * input.scale;
+ instance.b = b * input.scale * input.scale;
+ return instance;
+ }
+
+}
+public struct MaterialInstance2: IMaterialInstance
+{
+ float a;
+ float b;
+
+ [Differentiable]
+ public float eval( float x )
+ {
+ return a * x + b;
+ }
+}
+
+[Differentiable]
+public float shade(int material, ShadingInput input, float x)
+{
+ IMaterial m = getMaterial(material);
+ IMaterialInstance mi = m.setupMaterialInstance(input);
+ return mi.eval(x);
+}
+
+//TEST_INPUT: type_conformance Material1:IMaterial = 0
+//TEST_INPUT: type_conformance Material2:IMaterial = 1
+
+[shader("compute")]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ outputBuffer[0] = shade(0, {0.5}, 0.6);
+
+ // TODO: VERIFY
+ DifferentialPair<float> dpx = diffPair(3.0);
+ bwd_diff(shade)(0, {0.5}, dpx, 1.0);
+
+ outputBuffer[3] = dpx.d;
+} \ No newline at end of file