1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
// Tests automatic synthesis of the Differential type requirement: PathData conforms to IDifferentiable without declaring its own Differential type.
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -dx12
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
// Result buffer for this test. computeMain writes the forward-mode primal and derivative
// x components to slots 0 and 1; __bwd_d_getAlbedo accumulates dOut.x into slot 2.
RWStructuredBuffer<float> outputBuffer;
// Per-path state threaded through the loop in tracePath.
// Conforms to IDifferentiable; no explicit Differential type is declared in this struct.
struct PathData : IDifferentiable
{
    float3 thp;      // running product of albedo values (presumably path throughput)
    uint length;     // number of hits handled so far; incremented by handleHit
    bool terminated; // once true, the loop in tracePath stops
    bool isHit;      // result of the most recent traceRayInline query

    // Builds a fresh path: not hit, not terminated, zero length, thp of all ones.
    [BackwardDifferentiable]
    __init()
    {
        this.isHit = false;
        this.terminated = false;
        this.length = 0;
        this.thp = float3(1.f, 1.f, 1.f);
    }
}
// Hit test used by tracePath: returns true while `length` is below 2, false otherwise.
bool traceRayInline(uint length)
{
    return length < 2;
}
// Returns the constant albedo (0.9, 1, 1); the `length` argument is not consulted.
// Its derivatives come from __fwd_d_getAlbedo and __bwd_d_getAlbedo, which are
// registered against this function by attribute.
float3 getAlbedo(uint length)
{
    float3 albedo = float3(0.9f, 1.f, 1.f);
    return albedo;
}
// Returns the constant (1, 0, 0) that __fwd_d_getAlbedo uses as the derivative part
// of getAlbedo; the `length` argument is not consulted.
float3 getAlbedoDerivative(uint length)
{
    float3 dAlbedo = float3(1.f, 0.f, 0.f);
    return dAlbedo;
}
// Custom forward-mode derivative registered for getAlbedo.
// Returns getAlbedo's value paired with the value of getAlbedoDerivative.
[ForwardDerivativeOf(getAlbedo)]
[TreatAsDifferentiable]
DifferentialPair<float3> __fwd_d_getAlbedo(uint length)
{
    // The derivative helper is called under no_diff, exactly as a plain value lookup.
    float3 tangent = no_diff getAlbedoDerivative(length);
    float3 primal = getAlbedo(length);
    return DifferentialPair<float3>(primal, tangent);
}
// Custom backward-mode derivative registered for getAlbedo.
// Instead of propagating a gradient to a parameter, it adds the x component of the
// incoming gradient `dOut` to outputBuffer[2].
[BackwardDerivativeOf(getAlbedo)]
[TreatAsDifferentiable]
void __bwd_d_getAlbedo(uint length, float3 dOut)
{
    outputBuffer[2] = outputBuffer[2] + dOut.x;
}
// Advances the path by one hit.
// If pathData.length has already reached 2, the path is flagged as terminated and nothing
// else changes. Otherwise thp is multiplied by getAlbedo's result and length is incremented.
[BackwardDifferentiable]
void handleHit(inout PathData pathData)
{
// Length cap: terminate and leave early without touching thp or length.
if (pathData.length >= 2)
{
pathData.terminated = true;
return;
}
// getAlbedo has user-supplied forward and backward derivatives registered in this file.
float3 albedo = getAlbedo(pathData.length);
pathData.thp *= albedo;
pathData.length++;
}
// Traces one path and returns its final thp.
// A fresh PathData is created, an initial traceRayInline query decides whether there is a
// hit to process, and the loop then alternates handleHit with a new traceRayInline query
// until the path is marked terminated.
// With the helpers in this file the loop body runs three times: two hits are handled, then
// the miss branch sets terminated.
[BackwardDifferentiable]
[PreferRecompute]
float3 tracePath()
{
PathData pathData = PathData();
// Initial query, made with the freshly constructed length of 0.
if (traceRayInline(pathData.length))
{
pathData.isHit = true;
}
else
{
pathData.terminated = true;
pathData.isHit = false;
}
// Declares an iteration bound of 4 for the loop below.
[MaxIters(4)]
while (!pathData.terminated)
{
if (pathData.isHit)
{
handleHit(pathData);
// NOTE(review): the direct-assignment form on the next line is left disabled in favour of
// the explicit branch that follows; the reason is not recorded in this file.
//pathData.isHit = traceRayInline(pathData.length);
if (!traceRayInline(pathData.length)) pathData.isHit = false;
else pathData.isHit = true;
}
else
{
// No hit on the previous query: end the path.
pathData.terminated = true;
}
}
return pathData.thp;
}
// Compute entry point, declared with a 1x1x1 thread group.
// Runs tracePath in forward mode and stores the x components of the primal and the
// derivative in outputBuffer[0] and outputBuffer[1], then runs the backward pass with
// an output gradient of (1, 0, 0).
[numthreads(1, 1, 1)]
void computeMain(uint3 dispathThreadID: SV_DispatchThreadID)
{
    DifferentialPair<float3> fwdResult = __fwd_diff(tracePath)();
    outputBuffer[0] = fwdResult.p.x;
    outputBuffer[1] = fwdResult.d.x;

    float3 dThp = float3(1.f, 0.f, 0.f);
    __bwd_diff(tracePath)(dThp);
}
|