diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-28 15:39:52 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-28 12:39:52 -0700 |
| commit | 8f03af5e5b580170fab3fd2fe6144f92038c7701 (patch) | |
| tree | 0ac3464f74d5a5490f42a8c051e0a4d9839aef78 | |
| parent | e22b4dbb1bed1393fc028b87b8ff6ff30e1b73f3 (diff) | |
AD: Warped-Area-Sampling test works now. (#2742)
* Create render.slang
* Added higher-order differentiability decorators for built-ins + preliminary tests
* Update diff.meta.slang
* Copy over conformance synthesis code to `DifferentiableTypeConformanceContext`
* Update render.slang
* Fixed 1D warped-area sampling test
* Update warped-sampling-1d.slang
* Remove commented line.
* Change WAS test to use fixed point
* Replaced InterlockedCmpExchange with InterlockedAdd
* Increase fixed point precision
* Reduce floating-point precision by 2 digits to avoid platform-specific problems
* Dropped another digit (just to be safe)
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 268 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 12 | ||||
| -rw-r--r-- | tests/autodiff/was/warped-sampling-1d.slang | 287 | ||||
| -rw-r--r-- | tests/autodiff/was/warped-sampling-1d.slang.expected.txt | 11 |
5 files changed, 580 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index bbdb01290..91a1601fb 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1896,9 +1896,11 @@ struct DiffTransposePass auto var = builder->emitVar(arg->getDataType()); auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()); + auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()); + SLANG_ASSERT(zeroMethod); auto diffZero = builder->emitCallInst( diffType, - diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), + zeroMethod, List<IRInst*>()); // Initialize this var to (arg.primal, 0). diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index ddffd0e21..10c751d52 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -411,7 +411,7 @@ IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* t IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) { - if (auto conformance = lookUpConformanceForType(origType)) + if (auto conformance = tryGetDifferentiableWitness(builder, origType)) return _lookupWitness(builder, conformance, key); return nullptr; } @@ -445,6 +445,272 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() } } +IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* builder, IRInst* primalType) +{ + if (auto ptrType = as<IRPtrTypeBase>(primalType)) + return builder->getPtrType( + primalType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return differentiateType(builder, primalType); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + builder, + getDiffTypeFromPairType(builder, primalPairType), + getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_DifferentialPairUserCodeType: + { + auto primalPairType = as<IRDifferentialPairUserCodeType>(primalType); + return builder->getDifferentialPairUserCodeType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), + getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + { + SLANG_UNIMPLEMENTED_X("Impl"); + } + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; + + case kIROp_ExtractExistentialType: + { + SLANG_UNIMPLEMENTED_X("Impl"); + } + + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)getDifferentialForType(builder, (IRType*)primalType); + } +} + +IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) +{ + IRInst* witness = lookUpConformanceForType((IRType*)primalType); + if (witness) + { + SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType)); + } + + if (!witness) + { + SLANG_RELEASE_ASSERT(primalType); + if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) + { + witness = getOrCreateDifferentiablePairWitness(builder, primalPairType); + } + else if (auto arrayType = as<IRArrayType>(primalType)) + { + witness = getArrayWitness(builder, arrayType); + } + else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) + { + witness = getExtractExistensialTypeWitness(builder, extractExistential); + } + } + return witness; +} + +IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) +{ + return builder->getDifferentialPairType( + (IRType*)primalType, + witness); +} + +IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType) +{ + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false; + + // Fill in differential method implementations. + auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType(); + auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness(); + + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); + b.emitBlock(); + auto p0 = b.emitParam(diffDiffPairType); + auto p1 = b.emitParam(diffDiffPairType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); + IRInst* argsPrimal[2] = { + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; + auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); + IRInst* argsDiff[2] = { + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; + auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) + : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); + b.emitReturn(retVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); + b.emitBlock(); + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) + : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); + b.emitReturn(retVal); + } + + // Record this in the context for future lookups + differentiableWitnessDictionary[(IRType*)pairType] = table; + + return table; +} + +IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType) +{ + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType); + + if (!diffArrayType) + return nullptr; + + auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType()); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffArrayType, diffArrayType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); + b.emitBlock(); + auto p0 = b.emitParam(diffArrayType); + auto p1 = b.emitParam(diffArrayType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); + auto resultVar = b.emitVar(diffArrayType); + IRBlock* loopBodyBlock = nullptr; + IRBlock* loopBreakBlock = nullptr; + auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); + b.setInsertBefore(loopBodyBlock->getTerminator()); + + IRInst* args[2] = { + b.emitElementExtract(p0, loopCounter), + b.emitElementExtract(p1, loopCounter) }; + auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); + auto addr = b.emitElementAddress(resultVar, loopCounter); + b.emitStore(addr, elementResult); + b.setInsertInto(loopBreakBlock); + b.emitReturn(b.emitLoad(resultVar)); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); + b.emitBlock(); + + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); + b.emitReturn(retVal); + } + + // Record this in the context for future lookups + differentiableWitnessDictionary[(IRType*)arrayType] = table; + + return table; +} + +IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(IRBuilder*, IRExtractExistentialType*) +{ + SLANG_UNIMPLEMENTED_X("TODO: Implement"); +} + void stripDerivativeDecorations(IRInst* inst) { for (auto decor = inst->getFirstDecoration(); decor; ) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index d49babc52..da0cdc755 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -158,6 +158,18 @@ struct DifferentiableTypeConformanceContext IRInst* lookUpConformanceForType(IRInst* type); IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + + IRType* differentiateType(IRBuilder* builder, IRInst* primalType); + + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); + + IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType); + + IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType); + + IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType); + + IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType); diff --git a/tests/autodiff/was/warped-sampling-1d.slang b/tests/autodiff/was/warped-sampling-1d.slang new file mode 100644 index 000000000..c7caca579 --- /dev/null +++ b/tests/autodiff/was/warped-sampling-1d.slang @@ -0,0 +1,287 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out,name=endpointDifferentialBuffer +RWStructuredBuffer<float> endpointDifferentialBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=endpointDifferentialBufferInt +RWStructuredBuffer<int> endpointDifferentialBufferInt; + +//TEST_INPUT:ubuffer(data=[0.3 0.7 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):name=endpointBuffer +RWStructuredBuffer<float> endpointBuffer; +//TEST_INPUT:ubuffer(data=[1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):name=colorBuffer +RWStructuredBuffer<float> colorBuffer; + +typedef float Color; + +struct PRNG +{ + __init(uint seed) + { + this.state = seed; + } + + [mutating] uint next() + { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + return state; + } + + [mutating] float nextFloat1D() + { + return float(next()) / float(4294967295.0); + } + + uint state; +}; + +struct LineSegment : IDifferentiable +{ + float x0; + float x1; + + Color color; + + [BackwardDifferentiable] + __init(float _x0, float _x1, Color _color) + { + x0 = _x0; + x1 = _x1; + color = _color; + } +}; + +struct Intersection : IDifferentiable +{ + LineSegment ls; + float x; + bool isIntersected; + float wt; + + [BackwardDifferentiable] + __init(LineSegment _ls, float _x, bool _isIntersected, float _wt) + { + this.ls = _ls; + this.x = _x; + this.isIntersected = _isIntersected; + this.wt = _wt; + } +}; + +[BackwardDerivative(d_loadLineSegment)] +[ForwardDerivative(fwd_loadLineSegment)] +LineSegment loadLineSegment(uint id) +{ + return {endpointBuffer[id * 2], endpointBuffer[id * 2 + 1], colorBuffer[id]}; +} + +[BackwardDerivative(d_fwd_loadLineSegment)] +DifferentialPair<LineSegment> fwd_loadLineSegment(uint id) +{ + return DifferentialPair<LineSegment>(loadLineSegment(id), LineSegment.dzero()); +} + +void accumulateDifferentialFixedPoint( + RWStructuredBuffer<int> buffer, + uint index, + float.Differential df, + float scale = 1000000.f) +{ + InterlockedAdd(buffer[index], (int)round(df * scale)); +} + +void d_loadLineSegment(uint id, LineSegment.Differential d_ls) +{ + accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2, d_ls.x0); + accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2 + 1, d_ls.x1); +} + +void d_fwd_loadLineSegment(uint id, DifferentialPair<LineSegment>.Differential dp_ls) +{ + accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2, dp_ls.p.x0); + accumulateDifferentialFixedPoint(endpointDifferentialBufferInt, id * 2 + 1, dp_ls.p.x1); +} + +int getIntersectionID(float x) +{ + // Line segments are ordered by z-index so return the first intersection. + for (int id = 0; id < 2; id++) + { + LineSegment ls = loadLineSegment(id); + if (x > ls.x0 && x < ls.x1) + return id; + } + return -1; +} + +[BackwardDifferentiable] +Intersection intersect(float x) +{ + int id = getIntersectionID(x); + if (id >= 0) + return Intersection(loadLineSegment((uint)id), x, true, 1.0); + + return Intersection(LineSegment(0, 0, 0), x, false, 0.0); +} + +[BackwardDifferentiable] +float shadeIntersection(Intersection isect) +{ + return isect.ls.color; +} + +float sample1DNormal(inout PRNG prng, float mu, float sigma) +{ + float u = prng.nextFloat1D(); + float v = prng.nextFloat1D(); + return mu + (sqrt(-2 * log(u))*cos(2*3.1415*v) * sigma); +} + +[BackwardDifferentiable] +float pdf1DNormal(no_diff float x, float mu, no_diff float sigma) +{ + float k = ((x - mu) / sigma); + return exp(-0.5 * (k * k)) / (sigma * 2.506628); +} + +float boundaryTerm(Intersection isect) +{ + if (!isect.isIntersected) + return 100.0; // Large default value for missed rays. + + float leftDist = abs(isect.x - isect.ls.x0); + float rightDist = abs(isect.ls.x1 - isect.x); + + if (leftDist > rightDist) + return rightDist * 30.f; + else + return leftDist * 30.f; +} + +[BackwardDifferentiable] +DifferentialPair<float> infinitesimal(DifferentialPair<float> x) +{ + return diffPair(x.p - detach(x.p), x.d - detach(x.d)); +} + +[BackwardDifferentiable] +float harmonicWeight(Intersection isect, no_diff Intersection aux_isect) +{ + float x_dist = isect.x - aux_isect.x; + float k = 1.0 / (((x_dist * x_dist) + no_diff(boundaryTerm(aux_isect)))); + return k; +} + +[BackwardDifferentiable] +float attachToGeometry(Intersection isect) +{ + float leftWt = detach(isect.ls.x1 - isect.x); + float rightWt = detach(isect.x - isect.ls.x0); + + return (leftWt * isect.ls.x0 + rightWt * isect.ls.x1) / (leftWt + rightWt); +} + +[BackwardDifferentiable] +float warp(Intersection isect, inout PRNG prng) +{ + float totalWeight = 0.f; + float totalWarpedPoint = 0.f; + + float aux_sigma = 0.01; + + for (int i = 0; i < 32; i++) + { + float y = no_diff(sample1DNormal(prng, isect.x, aux_sigma)); + float y_flipped = 2 * isect.x - y; + + Intersection aux_isect_left = intersect(y); + + if (aux_isect_left.isIntersected) + { + float pdf = pdf1DNormal(y, isect.x, aux_sigma); + float wt = harmonicWeight(isect, aux_isect_left) * (pdf / detach(pdf)); + totalWarpedPoint += attachToGeometry(aux_isect_left) * wt; + totalWeight += wt; + } + + Intersection aux_isect_right = intersect(detach(y_flipped)); + + if (aux_isect_right.isIntersected) + { + float pdf = pdf1DNormal(y_flipped, isect.x, aux_sigma); + float wt = harmonicWeight(isect, aux_isect_right) * (pdf / detach(pdf)); + totalWarpedPoint += attachToGeometry(aux_isect_right) * wt; + totalWeight += wt; + } + } + + return totalWarpedPoint / totalWeight; +} + +[BackwardDifferentiable] +Intersection warpedIntersect(float x, inout PRNG prng) +{ + // TODO: For now the jacobian here is 1.0, + // but we will need to adjust the warp by the jacobian for + // more complex intersection models. + // + Intersection isect = intersect(x); + + Intersection.Differential d_isect = Intersection.Differential.dzero(); + d_isect.x = 1.0; + + var dpwarp = infinitesimal( + __fwd_diff(warp)(diffPair(isect, d_isect), prng)); + + isect.x = detach(isect.x) + dpwarp.p; + isect.wt = isect.wt * (1 + dpwarp.d); + + return isect; +} + +[BackwardDifferentiable] +float renderSample(inout PRNG prng) +{ + float u = no_diff(prng.nextFloat1D()); + + float leftBound = 0.0; + float rightBound = 1.0; + + float sample = leftBound * u + rightBound * (1 - u); + float weight = 1.0/(rightBound - leftBound); + + Intersection isect = warpedIntersect(sample, prng); + + return shadeIntersection(isect) * isect.wt; +} + +[numthreads(1000, 1, 1)] +void computeMain(uint3 threadIdx : SV_DispatchThreadID,) +{ + uint seed = (threadIdx.x * threadIdx.x) * 30 + 3; + PRNG prng = PRNG(seed); + + float d_color = 1.0 / 1000.0; + __bwd_diff(renderSample)(prng, d_color); + + AllMemoryBarrierWithGroupSync(); + + // Convert to floating point (but with 2 fewer digits of precision to + // avoid platform-specific differences in floating point precision) + // + if (threadIdx.x < 10) + endpointDifferentialBuffer[threadIdx.x] = + ((endpointDifferentialBufferInt[threadIdx.x]/1000) / 1000000.f) * 1000.f; + +// Note that this specific derivative estimation method is biased, so the +// expected results are approximate. (We've fixed the RNG seed to generate +// repeatable results) +// +// Expect: Approximately -1.0 in endpointDifferentialBuffer[0] +// Expect: Approximately 1.0 in endpointDifferentialBuffer[1] +// +// Expect: Approximately 0.0 in endpointDifferentialBuffer[2] +// Expect: Approximately 0.0 in endpointDifferentialBuffer[3] +// +}
\ No newline at end of file diff --git a/tests/autodiff/was/warped-sampling-1d.slang.expected.txt b/tests/autodiff/was/warped-sampling-1d.slang.expected.txt new file mode 100644 index 000000000..84272b50a --- /dev/null +++ b/tests/autodiff/was/warped-sampling-1d.slang.expected.txt @@ -0,0 +1,11 @@ +type: float +-0.954000 +0.950000 +-0.000000 +0.004000 +0.000000 +0.000000 +0.000000 +0.000000 +0.000000 +0.000000
\ No newline at end of file |
