diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-03-14 17:15:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-03-15 00:15:36 +0000 |
| commit | 78517dc392f0d2ebba25f0ac3f4d4e004b0f0ab0 (patch) | |
| tree | 104b48da3fc54e43cd7c5ce51cc66b4e2dc26d55 /tests/autodiff | |
| parent | c8c9e424e91e72e718529ed76df14f7586624cd6 (diff) | |
Fix lowering of associated types in generic interfaces (#6600)
* Fix lowering of associated types in generic interfaces.
* Update diff-assoctype-generic-interface.slang
* Fix-up lowering of differentiable witnesses for implicit ops
* Update slang-ir-autodiff-transcriber-base.cpp
* Fix issue with differentiating type-packs
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/autopybind-printf.slang | 47 | ||||
| -rw-r--r-- | tests/autodiff/diff-assoctype-generic-interface.slang | 110 |
2 files changed, 157 insertions, 0 deletions
diff --git a/tests/autodiff/autopybind-printf.slang b/tests/autodiff/autopybind-printf.slang new file mode 100644 index 000000000..add1923ef --- /dev/null +++ b/tests/autodiff/autopybind-printf.slang @@ -0,0 +1,47 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +// CUDA: __device__ void s_primal_ctx_myKernel_0( +// CUDA: printf("%f\n", +// CUDA: __global__ void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// CUDA: __global__ void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// CUDA: __global__ void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) + +[AutoPyBindCUDA] +[Differentiable] +[CudaKernel] +void myKernel(DiffTensorView inValues, DiffTensorView outValues) +{ + if (cudaThreadIdx().x > 0) + return; + printf("%f\n", inValues[cudaThreadIdx().x]); + outValues[cudaThreadIdx().x] = sin(inValues[cudaThreadIdx().x]); +} + +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple<std::tuple<const char*, const char*, const char*, const char*>, std::tuple<const char*, const char*>, const char*, const char*> __funcinfo__myKernel() +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel_fwd_diff(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel_bwd_diff(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple<std::tuple<const char*, const char*>, std::tuple<const char*, const char*>> __typeinfo__DiffTensorView() +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple<std::tuple<const char*>, std::tuple<const char*>> __typeinfo__AtomicAdd() +//
\ No newline at end of file diff --git a/tests/autodiff/diff-assoctype-generic-interface.slang b/tests/autodiff/diff-assoctype-generic-interface.slang new file mode 100644 index 000000000..79e0eff08 --- /dev/null +++ b/tests/autodiff/diff-assoctype-generic-interface.slang @@ -0,0 +1,110 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly + +//TEST_INPUT:ubuffer(data=[2 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IGetter : IDifferentiable +{ + [Differentiable] + float get(uint id); +} + +struct GetterImpl : IGetter +{ + float[8] data; + + __init(float[8] data) + { this.data = data; } + + [Differentiable] + float get(uint id) + { + return data[id]; + } +} +interface IFoo<int N> +{ + associatedtype Params : IGetter; + + [Differentiable] + Params bar(); +} + +[BackwardDerivative(load_bwd)] +[ForwardDerivative(load_fwd)] +float load(uint id) +{ + return outputBuffer[id] + 2; +} + +DifferentialPair<float> load_fwd(uint id) +{ + return DifferentialPair<float>(load(id), 3.f); +} + +void load_bwd(uint id, float.Differential dOut) +{ + outputBuffer[id + 8] = dOut; +} + +struct FooImpl1: IFoo<8> +{ + typealias Params = GetterImpl; + + __init() + { } + + [Differentiable] + Params bar() + { + float x = load(0); + return GetterImpl({x, x+1, x+2, x+3, x+4, x+5, x+6, x+7}); + } +} + +/* +// There's a slight issue with dynamic dispatch over generic interfaces. Uncomment after that is fixed. + +struct FooImpl2: IFoo<8> +{ + typealias Params = GetterImpl; + + __init() + { } + + [Differentiable] + Params bar() + { + float x = 2 * load(0); + return GetterImpl({x, x+5, x+7, x+9, x+11, x+13, x+15, x+17}); + } +} +*/ + +IFoo<8> getFoo(uint id) +{ + /*if (id == 0) + return FooImpl1(); + else + return FooImpl2();*/ + return FooImpl1(); +} + +[Differentiable] +float doThing(uint id) +{ + IFoo<8> foo = getFoo(id); + return foo.bar().get(0); +} + +[shader("compute")] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = doThing(0); // CHECK: 2.0 + outputBuffer[1] = doThing(1); // CHECK: 4.0 + + outputBuffer[2] = fwd_diff(doThing)(0).d; // CHECK: 3.0 + outputBuffer[3] = fwd_diff(doThing)(1).d; // CHECK: 3.0 +}
\ No newline at end of file |
