summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-03-28 15:39:52 -0400
committerGitHub <noreply@github.com>2023-03-28 12:39:52 -0700
commit8f03af5e5b580170fab3fd2fe6144f92038c7701 (patch)
tree0ac3464f74d5a5490f42a8c051e0a4d9839aef78
parente22b4dbb1bed1393fc028b87b8ff6ff30e1b73f3 (diff)
AD: Warped-Area-Sampling test works now. (#2742)
* Create render.slang * Added higher-order differentiability decorators for built-ins + preliminary tests * Update diff.meta.slang * Copy over conformance synthesis code to `DifferentiableTypeConformanceContext` * Update render.slang * Fixed 1D warped-area sampling test * Update warped-sampling-1d.slang * Remove commented line. * Change WAS test to use fixed point * Replaced InterlockedCmpExchange with InterlockedAdd * Increase fixed point precision * Reduce floating-point precision by 2 digits to avoid platform-specific problems * Dropped another digit (just to be safe) --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h4
-rw-r--r--source/slang/slang-ir-autodiff.cpp268
-rw-r--r--source/slang/slang-ir-autodiff.h12
-rw-r--r--tests/autodiff/was/warped-sampling-1d.slang287
-rw-r--r--tests/autodiff/was/warped-sampling-1d.slang.expected.txt11
5 files changed, 580 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index bbdb01290..91a1601fb 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1896,9 +1896,11 @@ struct DiffTransposePass
auto var = builder->emitVar(arg->getDataType());
auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType());
+ auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, pairType->getValueType());
+ SLANG_ASSERT(zeroMethod);
auto diffZero = builder->emitCallInst(
diffType,
- diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
+ zeroMethod,
List<IRInst*>());
// Initialize this var to (arg.primal, 0).
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index ddffd0e21..10c751d52 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -411,7 +411,7 @@ IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* t
IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
{
- if (auto conformance = lookUpConformanceForType(origType))
+ if (auto conformance = tryGetDifferentiableWitness(builder, origType))
return _lookupWitness(builder, conformance, key);
return nullptr;
}
@@ -445,6 +445,272 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
}
}
+IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* builder, IRInst* primalType)
+{
+ if (auto ptrType = as<IRPtrTypeBase>(primalType))
+ return builder->getPtrType(
+ primalType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return differentiateType(builder, primalType);
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ builder,
+ getDiffTypeFromPairType(builder, primalPairType),
+ getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_DifferentialPairUserCodeType:
+ {
+ auto primalPairType = as<IRDifferentialPairUserCodeType>(primalType);
+ return builder->getDifferentialPairUserCodeType(
+ (IRType*)getDiffTypeFromPairType(builder, primalPairType),
+ getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ {
+ SLANG_UNIMPLEMENTED_X("Impl");
+ }
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_ExtractExistentialType:
+ {
+ SLANG_UNIMPLEMENTED_X("Impl");
+ }
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)getDifferentialForType(builder, (IRType*)primalType);
+ }
+}
+
+IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType)
+{
+ IRInst* witness = lookUpConformanceForType((IRType*)primalType);
+ if (witness)
+ {
+ SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType));
+ }
+
+ if (!witness)
+ {
+ SLANG_RELEASE_ASSERT(primalType);
+ if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
+ {
+ witness = getOrCreateDifferentiablePairWitness(builder, primalPairType);
+ }
+ else if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ witness = getArrayWitness(builder, arrayType);
+ }
+ else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
+ {
+ witness = getExtractExistensialTypeWitness(builder, extractExistential);
+ }
+ }
+ return witness;
+}
+
+IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
+{
+ return builder->getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+}
+
+IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType)
+{
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
+
+ // Fill in differential method implementations.
+ auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType();
+ auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness();
+
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffDiffPairType);
+ auto p1 = b.emitParam(diffDiffPairType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
+ IRInst* argsPrimal[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
+ auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
+ IRInst* argsDiff[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
+ auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
+ : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
+ b.emitReturn(retVal);
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
+ b.emitBlock();
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
+ : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
+ b.emitReturn(retVal);
+ }
+
+ // Record this in the context for future lookups
+ differentiableWitnessDictionary[(IRType*)pairType] = table;
+
+ return table;
+}
+
+IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType)
+{
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType);
+
+ if (!diffArrayType)
+ return nullptr;
+
+ auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType());
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
+
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffArrayType, diffArrayType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffArrayType);
+ auto p1 = b.emitParam(diffArrayType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
+ auto resultVar = b.emitVar(diffArrayType);
+ IRBlock* loopBodyBlock = nullptr;
+ IRBlock* loopBreakBlock = nullptr;
+ auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
+ b.setInsertBefore(loopBodyBlock->getTerminator());
+
+ IRInst* args[2] = {
+ b.emitElementExtract(p0, loopCounter),
+ b.emitElementExtract(p1, loopCounter) };
+ auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
+ auto addr = b.emitElementAddress(resultVar, loopCounter);
+ b.emitStore(addr, elementResult);
+ b.setInsertInto(loopBreakBlock);
+ b.emitReturn(b.emitLoad(resultVar));
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
+ b.emitBlock();
+
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
+ b.emitReturn(retVal);
+ }
+
+ // Record this in the context for future lookups
+ differentiableWitnessDictionary[(IRType*)arrayType] = table;
+
+ return table;
+}
+
+IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(IRBuilder*, IRExtractExistentialType*)
+{
+ SLANG_UNIMPLEMENTED_X("TODO: Implement");
+}
+
void stripDerivativeDecorations(IRInst* inst)
{
for (auto decor = inst->getFirstDecoration(); decor; )
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index d49babc52..da0cdc755 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -158,6 +158,18 @@ struct DifferentiableTypeConformanceContext
IRInst* lookUpConformanceForType(IRInst* type);
IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
+
+ IRType* differentiateType(IRBuilder* builder, IRInst* primalType);
+
+ IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);
+
+ IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType);
+
+ IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType);
+
+ IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType);
+
+ IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType);
diff --git a/tests/autodiff/was/warped-sampling-1d.slang b/tests/autodiff/was/warped-sampling-1d.slang
new file mode 100644
index 000000000..c7caca579
--- /dev/null
+++ b/tests/autodiff/was/warped-sampling-1d.slang
@@ -0,0 +1,287 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out,name=endpointDifferentialBuffer
+RWStructuredBuffer<float> endpointDifferentialBuffer;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=endpointDifferentialBufferInt
+RWStructuredBuffer<int> endpointDifferentialBufferInt;
+
+//TEST_INPUT:ubuffer(data=[0.3 0.7 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):name=endpointBuffer
+RWStructuredBuffer<float> endpointBuffer;
+//TEST_INPUT:ubuffer(data=[1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):name=colorBuffer
+RWStructuredBuffer<float> colorBuffer;
+
+typedef float Color;
+
+struct PRNG
+{
+ __init(uint seed)
+ {
+ this.state = seed;
+ }
+
+ [mutating] uint next()
+ {
+ state ^= state << 13;
+ state ^= state >> 7;
+ state ^= state << 17;
+ return state;
+ }
+
+ [mutating] float nextFloat1D()
+ {
+ return float(next()) / float(4294967295.0);
+ }
+
+ uint state;
+};
+
+struct LineSegment : IDifferentiable
+{
+ float x0;
+ float x1;
+
+ Color color;
+
+ [BackwardDifferentiable]
+ __init(float _x0, float _x1, Color _color)
+ {
+ x0 = _x0;
+ x1 = _x1;
+ color = _color;
+ }
+};
+
+struct Intersection : IDifferentiable
+{
+ LineSegment ls;
+ float x;
+ bool isIntersected;
+ float wt;
+
+ [BackwardDifferentiable]
+ __init(LineSegment _ls, float _x, bool _isIntersected, float _wt)
+ {
+ this.ls = _ls;
+ this.x = _x;
+ this.isIntersected = _isIntersected;
+ this.wt = _wt;
+ }
+};
+
+[BackwardDerivative(d_loadLineSegment)]
+[ForwardDerivative(fwd_loadLineSegment)]
+LineSegment loadLineSegment(uint id)
+{
+ return {endpointBuffer[id * 2], endpointBuffer[id * 2 + 1], colorBuffer[id]};
+}
+
+[BackwardDerivative(d_fwd_loadLineSegment)]
+DifferentialPair<LineSegment> fwd_loadLineSegment(uint id)
+{
+ return DifferentialPair<LineSegment>(loadLineSegment(id), LineSegment.dzero());
+}
+
+void accumulateDifferentialFixedPoint(
+ RWStructuredBuffer<int> buffer,
+ uint index,
+ float.Differential df,
+ float scale = 1000000.f)
+{
+ InterlockedAdd(buffer[index], (int)round(df * scale));
+}
+
+void d_loadLineSegment(uint id, LineSegment.Differential d_ls)
+{
+ accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2, d_ls.x0);
+ accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2 + 1, d_ls.x1);
+}
+
+void d_fwd_loadLineSegment(uint id, DifferentialPair<LineSegment>.Differential dp_ls)
+{
+ accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2, dp_ls.p.x0);
+ accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2 + 1, dp_ls.p.x1);
+}
+
+int getIntersectionID(float x)
+{
+ // Line segments are ordered by z-index so return the first intersection.
+ for (int id = 0; id < 2; id++)
+ {
+ LineSegment ls = loadLineSegment(id);
+ if (x > ls.x0 && x < ls.x1)
+ return id;
+ }
+ return -1;
+}
+
+[BackwardDifferentiable]
+Intersection intersect(float x)
+{
+ int id = getIntersectionID(x);
+ if (id >= 0)
+ return Intersection(loadLineSegment((uint)id), x, true, 1.0);
+
+ return Intersection(LineSegment(0, 0, 0), x, false, 0.0);
+}
+
+[BackwardDifferentiable]
+float shadeIntersection(Intersection isect)
+{
+ return isect.ls.color;
+}
+
+float sample1DNormal(inout PRNG prng, float mu, float sigma)
+{
+ float u = prng.nextFloat1D();
+ float v = prng.nextFloat1D();
+ return mu + (sqrt(-2 * log(u))*cos(2*3.1415*v) * sigma);
+}
+
+[BackwardDifferentiable]
+float pdf1DNormal(no_diff float x, float mu, no_diff float sigma)
+{
+ float k = ((x - mu) / sigma);
+ return exp(-0.5 * (k * k)) / (sigma * 2.506628);
+}
+
+float boundaryTerm(Intersection isect)
+{
+ if (!isect.isIntersected)
+ return 100.0; // Large default value for missed rays.
+
+ float leftDist = abs(isect.x - isect.ls.x0);
+ float rightDist = abs(isect.ls.x1 - isect.x);
+
+ if (leftDist > rightDist)
+ return rightDist * 30.f;
+ else
+ return leftDist * 30.f;
+}
+
+[BackwardDifferentiable]
+DifferentialPair<float> infinitesimal(DifferentialPair<float> x)
+{
+ return diffPair(x.p - detach(x.p), x.d - detach(x.d));
+}
+
+[BackwardDifferentiable]
+float harmonicWeight(Intersection isect, no_diff Intersection aux_isect)
+{
+ float x_dist = isect.x - aux_isect.x;
+ float k = 1.0 / (((x_dist * x_dist) + no_diff(boundaryTerm(aux_isect))));
+ return k;
+}
+
+[BackwardDifferentiable]
+float attachToGeometry(Intersection isect)
+{
+ float leftWt = detach(isect.ls.x1 - isect.x);
+ float rightWt = detach(isect.x - isect.ls.x0);
+
+ return (leftWt * isect.ls.x0 + rightWt * isect.ls.x1) / (leftWt + rightWt);
+}
+
+[BackwardDifferentiable]
+float warp(Intersection isect, inout PRNG prng)
+{
+ float totalWeight = 0.f;
+ float totalWarpedPoint = 0.f;
+
+ float aux_sigma = 0.01;
+
+ for (int i = 0; i < 32; i++)
+ {
+ float y = no_diff(sample1DNormal(prng, isect.x, aux_sigma));
+ float y_flipped = 2 * isect.x - y;
+
+ Intersection aux_isect_left = intersect(y);
+
+ if (aux_isect_left.isIntersected)
+ {
+ float pdf = pdf1DNormal(y, isect.x, aux_sigma);
+ float wt = harmonicWeight(isect, aux_isect_left) * (pdf / detach(pdf));
+ totalWarpedPoint += attachToGeometry(aux_isect_left) * wt;
+ totalWeight += wt;
+ }
+
+ Intersection aux_isect_right = intersect(detach(y_flipped));
+
+ if (aux_isect_right.isIntersected)
+ {
+ float pdf = pdf1DNormal(y_flipped, isect.x, aux_sigma);
+ float wt = harmonicWeight(isect, aux_isect_right) * (pdf / detach(pdf));
+ totalWarpedPoint += attachToGeometry(aux_isect_right) * wt;
+ totalWeight += wt;
+ }
+ }
+
+ return totalWarpedPoint / totalWeight;
+}
+
+[BackwardDifferentiable]
+Intersection warpedIntersect(float x, inout PRNG prng)
+{
+ // TODO: For now the jacobian here is 1.0,
+ // but we will need to adjust the warp by the jacobian for
+ // more complex intersection models.
+ //
+ Intersection isect = intersect(x);
+
+ Intersection.Differential d_isect = Intersection.Differential.dzero();
+ d_isect.x = 1.0;
+
+ var dpwarp = infinitesimal(
+ __fwd_diff(warp)(diffPair(isect, d_isect), prng));
+
+ isect.x = detach(isect.x) + dpwarp.p;
+ isect.wt = isect.wt * (1 + dpwarp.d);
+
+ return isect;
+}
+
+[BackwardDifferentiable]
+float renderSample(inout PRNG prng)
+{
+ float u = no_diff(prng.nextFloat1D());
+
+ float leftBound = 0.0;
+ float rightBound = 1.0;
+
+ float sample = leftBound * u + rightBound * (1 - u);
+ float weight = 1.0/(rightBound - leftBound);
+
+ Intersection isect = warpedIntersect(sample, prng);
+
+ return shadeIntersection(isect) * isect.wt;
+}
+
+[numthreads(1000, 1, 1)]
+void computeMain(uint3 threadIdx : SV_DispatchThreadID,)
+{
+ uint seed = (threadIdx.x * threadIdx.x) * 30 + 3;
+ PRNG prng = PRNG(seed);
+
+ float d_color = 1.0 / 1000.0;
+ __bwd_diff(renderSample)(prng, d_color);
+
+ AllMemoryBarrierWithGroupSync();
+
+ // Convert to floating point (but with 2 fewer digits of precision to
+ // avoid platform-specific differences in floating point precision)
+ //
+ if (threadIdx.x < 10)
+ endpointDifferentialBuffer[threadIdx.x] =
+ ((endpointDifferentialBufferInt[threadIdx.x]/1000) / 1000000.f) * 1000.f;
+
+// Note that this specific derivative estimation method is biased, so the
+// expected results are approximate. (We've fixed the RNG seed to generate
+// repeatable results)
+//
+// Expect: Approximately -1.0 in endpointDifferentialBuffer[0]
+// Expect: Approximately 1.0 in endpointDifferentialBuffer[1]
+//
+// Expect: Approximately 0.0 in endpointDifferentialBuffer[2]
+// Expect: Approximately 0.0 in endpointDifferentialBuffer[3]
+//
+} \ No newline at end of file
diff --git a/tests/autodiff/was/warped-sampling-1d.slang.expected.txt b/tests/autodiff/was/warped-sampling-1d.slang.expected.txt
new file mode 100644
index 000000000..84272b50a
--- /dev/null
+++ b/tests/autodiff/was/warped-sampling-1d.slang.expected.txt
@@ -0,0 +1,11 @@
+type: float
+-0.954000
+0.950000
+-0.000000
+0.004000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000 \ No newline at end of file