diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-11 13:46:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-11 13:46:21 -0400 |
| commit | 977e4b21d69406b0b68c5963f50489d7433db830 (patch) | |
| tree | af5e68f910c1887ef3c638d8620449ab66d9faf3 | |
| parent | c3061afffdc078914cc522538749c6e0c5e36f65 (diff) | |
Fix issue with synthesizing `Differential` type for self-differential generic types (#4602)
* Fix issue with synthesizing `Differential` type for self-differential generic types
The problem was that we were using the type that was performing the lookup for `.Differential` which can have substitutions based on the local context where the decl is being referenced.
We need to synthesize the type local to the decl itself
* Update auto-differential-type-generic.slang
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 7 | ||||
| -rw-r--r-- | tests/autodiff/auto-differential-type-generic.slang | 64 | ||||
| -rw-r--r-- | tests/autodiff/auto-differential-type-generic.slang.expected.txt | 6 |
3 files changed, 76 insertions, 1 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ee36a21fb..f60ead1e9 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -638,7 +638,12 @@ namespace Slang auto typeDef = m_astBuilder->create<TypeAliasDecl>(); typeDef->nameAndLoc.name = item.declRef.getName(); typeDef->parentDecl = parent; - typeDef->type.type = subType; + + // Compute the decl's type as if it is referred to from itself. This is important because + // subType may have substitutions from the context it is used in, while this synthesis step + // is local to the decl. + // + typeDef->type.type = calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); synthesizedDecl = parent; diff --git a/tests/autodiff/auto-differential-type-generic.slang b/tests/autodiff/auto-differential-type-generic.slang new file mode 100644 index 000000000..060495e56 --- /dev/null +++ b/tests/autodiff/auto-differential-type-generic.slang @@ -0,0 +1,64 @@ +// Tests automatic synthesis of Differential type requirement for generic types. +// +// This specifically tests a synthesis path that occurs when the lookup of the Differential type happens before the conformance-check. +// If this path doesn't construct the generic differential type correctly, it will throw an error when constructing the array +// in this line: Feature<3>.Differential b = {0.2, 0.3, 0.4}; +// + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +__generic<let C : int> +struct Feature: IDifferentiable +{ + float vals[C]; +} + + +struct Linear<let C : int> +{ + typedef Feature<C> Input; + typedef Feature<C> Output; + + [BackwardDerivative(eval_bwd)] + Output eval(Input in_feature) + { + Output out_feature; + for (int i = 0; i < C; i++) + { + out_feature.vals[i] = in_feature.vals[i] * 2.0; + } + return out_feature; + } + + void eval_bwd(inout DifferentialPair<Input> in_feature_pair, Feature<C>.Differential d_output) + { + /* empty.. doesn't really matter */ + } +} + +[Differentiable] +Feature<3> f(Feature<3> a, Linear<3> layer) +{ + return layer.eval(a); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + Feature<3> a = {1.0, 2.0, 3.0}; + Feature<3>.Differential b = {0.2, 0.3, 0.4}; + + Linear<3> layer; + + var dpA = diffPair(a, b); + + var result = fwd_diff(f)(dpA, layer).d; + + outputBuffer[0] = result.vals[0]; + outputBuffer[1] = result.vals[1]; + outputBuffer[2] = result.vals[2]; +}
\ No newline at end of file diff --git a/tests/autodiff/auto-differential-type-generic.slang.expected.txt b/tests/autodiff/auto-differential-type-generic.slang.expected.txt new file mode 100644 index 000000000..c0d274b2b --- /dev/null +++ b/tests/autodiff/auto-differential-type-generic.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.400000 +0.600000 +0.800000 +0.000000 +0.000000 |
