summaryrefslogtreecommitdiffstats
path: root/tests/autodiff
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-03-14 17:15:36 -0700
committerGitHub <noreply@github.com>2025-03-15 00:15:36 +0000
commit78517dc392f0d2ebba25f0ac3f4d4e004b0f0ab0 (patch)
tree104b48da3fc54e43cd7c5ce51cc66b4e2dc26d55 /tests/autodiff
parentc8c9e424e91e72e718529ed76df14f7586624cd6 (diff)
Fix lowering of associated types in generic interfaces (#6600)
* Fix lowering of associated types in generic interfaces. * Update diff-assoctype-generic-interface.slang * Fix-up lowering of differentiable witnesses for implicit ops * Update slang-ir-autodiff-transcriber-base.cpp * Fix issue with differentiating type-packs
Diffstat (limited to 'tests/autodiff')
-rw-r--r--tests/autodiff/autopybind-printf.slang47
-rw-r--r--tests/autodiff/diff-assoctype-generic-interface.slang110
2 files changed, 157 insertions, 0 deletions
diff --git a/tests/autodiff/autopybind-printf.slang b/tests/autodiff/autopybind-printf.slang
new file mode 100644
index 000000000..add1923ef
--- /dev/null
+++ b/tests/autodiff/autopybind-printf.slang
@@ -0,0 +1,47 @@
+//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
+//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none
+
+// CUDA: __device__ void s_primal_ctx_myKernel_0(
+// CUDA: printf("%f\n",
+// CUDA: __global__ void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
+// CUDA: __global__ void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
+// CUDA: __global__ void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
+
+[AutoPyBindCUDA]
+[Differentiable]
+[CudaKernel]
+void myKernel(DiffTensorView inValues, DiffTensorView outValues)
+{
+ if (cudaThreadIdx().x > 0)
+ return;
+ printf("%f\n", inValues[cudaThreadIdx().x]);
+ outValues[cudaThreadIdx().x] = sin(inValues[cudaThreadIdx().x]);
+}
+
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void myKernel(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}})
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: std::tuple<std::tuple<const char*, const char*, const char*, const char*>, std::tuple<const char*, const char*>, const char*, const char*> __funcinfo__myKernel()
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void myKernel_fwd_diff(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}})
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void myKernel_bwd_diff(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}})
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: std::tuple<std::tuple<const char*, const char*>, std::tuple<const char*, const char*>> __typeinfo__DiffTensorView()
+//
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: std::tuple<std::tuple<const char*>, std::tuple<const char*>> __typeinfo__AtomicAdd()
+// \ No newline at end of file
diff --git a/tests/autodiff/diff-assoctype-generic-interface.slang b/tests/autodiff/diff-assoctype-generic-interface.slang
new file mode 100644
index 000000000..79e0eff08
--- /dev/null
+++ b/tests/autodiff/diff-assoctype-generic-interface.slang
@@ -0,0 +1,110 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly
+
+//TEST_INPUT:ubuffer(data=[2 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IGetter : IDifferentiable
+{
+ [Differentiable]
+ float get(uint id);
+}
+
+struct GetterImpl : IGetter
+{
+ float[8] data;
+
+ __init(float[8] data)
+ { this.data = data; }
+
+ [Differentiable]
+ float get(uint id)
+ {
+ return data[id];
+ }
+}
+interface IFoo<int N>
+{
+ associatedtype Params : IGetter;
+
+ [Differentiable]
+ Params bar();
+}
+
+[BackwardDerivative(load_bwd)]
+[ForwardDerivative(load_fwd)]
+float load(uint id)
+{
+ return outputBuffer[id] + 2;
+}
+
+DifferentialPair<float> load_fwd(uint id)
+{
+ return DifferentialPair<float>(load(id), 3.f);
+}
+
+void load_bwd(uint id, float.Differential dOut)
+{
+ outputBuffer[id + 8] = dOut;
+}
+
+struct FooImpl1: IFoo<8>
+{
+ typealias Params = GetterImpl;
+
+ __init()
+ { }
+
+ [Differentiable]
+ Params bar()
+ {
+ float x = load(0);
+ return GetterImpl({x, x+1, x+2, x+3, x+4, x+5, x+6, x+7});
+ }
+}
+
+/*
+// There's a slight issue with dynamic dispatch over generic interfaces. Uncomment after that is fixed.
+
+struct FooImpl2: IFoo<8>
+{
+ typealias Params = GetterImpl;
+
+ __init()
+ { }
+
+ [Differentiable]
+ Params bar()
+ {
+ float x = 2 * load(0);
+ return GetterImpl({x, x+5, x+7, x+9, x+11, x+13, x+15, x+17});
+ }
+}
+*/
+
+IFoo<8> getFoo(uint id)
+{
+ /*if (id == 0)
+ return FooImpl1();
+ else
+ return FooImpl2();*/
+ return FooImpl1();
+}
+
+[Differentiable]
+float doThing(uint id)
+{
+ IFoo<8> foo = getFoo(id);
+ return foo.bar().get(0);
+}
+
+[shader("compute")]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ outputBuffer[0] = doThing(0); // CHECK: 2.0
+ outputBuffer[1] = doThing(1); // CHECK: 4.0
+
+ outputBuffer[2] = fwd_diff(doThing)(0).d; // CHECK: 3.0
+ outputBuffer[3] = fwd_diff(doThing)(1).d; // CHECK: 3.0
+} \ No newline at end of file