diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-27 12:19:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-27 12:19:30 -0700 |
| commit | 0cbef6fd6d7924d37ef3ea5ec7c848c80947d13f (patch) | |
| tree | 173fa18c39638e7d41ae092b9012554cb867a31b | |
| parent | 351e78f3abc54f114237d4af64f8199476ebf176 (diff) | |
Rename `__jvp`-->`__fwd_diff`. (#2471)
Co-authored-by: Yong He <yhe@nvidia.com>
29 files changed, 82 insertions, 82 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index fca628a49..f2a72703e 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -437,12 +437,12 @@ class HigherOrderInvokeExpr : public Expr Expr* baseFunction; }; - /// An expression of the form `__jvp(fn)` to access the + /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// -class JVPDifferentiateExpr: public HigherOrderInvokeExpr +class ForwardDifferentiateExpr: public HigherOrderInvokeExpr { - SLANG_AST_CLASS(JVPDifferentiateExpr) + SLANG_AST_CLASS(ForwardDifferentiateExpr) }; /// A type expression of the form `__TaggedUnion(A, ...)`. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index d1e737720..0975de985 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,7 +1971,7 @@ namespace Slang return jvpType; } - Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 31075c3e8..af1173051 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1987,7 +1987,7 @@ namespace Slang Expr* visitPointerTypeExpr(PointerTypeExpr* expr); Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); - Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr); + Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index eadf2f63d..ef067c06c 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1548,11 +1548,11 @@ namespace Slang // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an // if-else ladder. - if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr)) + if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr)) { if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type)) { - // Case: __jvp(name-resolved-to-decl-ref) + // Case: __fwd_diff(name-resolved-to-decl-ref) auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>(); SLANG_ASSERT(baseFuncDeclRef); @@ -1567,7 +1567,7 @@ namespace Slang } else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type)) { - // Case: __jvp(name-resolved-to-multiple-decl-ref) + // Case: __fwd_diff(name-resolved-to-multiple-decl-ref) if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction)) { @@ -1595,7 +1595,7 @@ namespace Slang } else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>()) { - // Case: __jvp(name-resolved-to-generic-decl) + // Case: __fwd_diff(name-resolved-to-generic-decl) // Get inner function DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index 3f2c6c789..34e7e3de0 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -34,8 +34,8 @@ struct DerivativeCallProcessContext do { auto nextChild = child->getNextInst(); - // Look for IRJVPDifferentiate - if (auto derivOf = as<IRJVPDifferentiate>(child)) + // Look for IRForwardDifferentiate + if (auto derivOf = as<IRForwardDifferentiate>(child)) { processDifferentiate(derivOf); } @@ -50,14 +50,14 @@ struct DerivativeCallProcessContext // Perform forward-mode automatic differentiation on // the intstructions. - void processDifferentiate(IRJVPDifferentiate* derivOfInst) + void processDifferentiate(IRForwardDifferentiate* derivOfInst) { IRInst* jvpCallable = nullptr; // First get base function auto origCallable = derivOfInst->getBaseFn(); - // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc)) + // Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc)) // Check the specialize inst for a reference to the derivative fn. // if (auto origSpecialize = as<IRSpecialize>(origCallable)) @@ -68,7 +68,7 @@ struct DerivativeCallProcessContext } } - // Resolve the derivative function for an IRJVPDifferentiate(IRFunc) + // Resolve the derivative function for an IRForwardDifferentiate(IRFunc) // // Check for the 'JVPDerivativeReference' decorator on the // base function. diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 1a86506b3..7e6fd30dd 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1125,7 +1125,7 @@ struct JVPTranscriber { // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. - diffCallee = builder->emitJVPDifferentiateInst( + diffCallee = builder->emitForwardDifferentiateInst( differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), primalCallee); } @@ -1357,8 +1357,8 @@ struct JVPTranscriber // logic here. We simple clone the original lookup which points to the original function, // or the cloned version in case we're inside a generic scope. // The differentiation logic is inserted later when this is used in an IRCall. - // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table)) - // rather than have Lookup(JVPDifferentiate(Table)) + // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table)) + // rather than have Lookup(ForwardDifferentiate(Table)) // auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); return InstPair(diffLookup, diffLookup); @@ -1937,7 +1937,7 @@ struct JVPDerivativeContext IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; - // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by + // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by // generating derivative code for the referenced function. // bool modified = processReferencedFunctions(builder); @@ -1962,7 +1962,7 @@ struct JVPDerivativeContext return nullptr; } - // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate), + // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), // then check that the referenced function is marked correctly for differentiation. // bool processReferencedFunctions(IRBuilder* builder) @@ -1979,13 +1979,13 @@ struct JVPDerivativeContext { // Either the child instruction has more children (func/block etc..) // and we add it to the work list for further processing, or - // it's an ordinary inst in which case we check if it's a JVPDifferentiate + // it's an ordinary inst in which case we check if it's a ForwardDifferentiate // instruction. // if (child->getFirstChild() != nullptr) workQueue->push(child); - if (auto jvpDiffInst = as<IRJVPDifferentiate>(child)) + if (auto jvpDiffInst = as<IRForwardDifferentiate>(child)) { auto baseInst = jvpDiffInst->getBaseFn(); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c59286116..ccde80476 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -758,7 +758,7 @@ INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) INST(CastPtrToBool, CastPtrToBool, 1, 0) INST(IsType, IsType, 3, 0) -INST(JVPDifferentiate, jvpDifferentiate, 1, 0) +INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5a9c14038..95202d9d0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -579,17 +579,17 @@ struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration // An instruction that replaces the function symbol // with it's derivative function. -struct IRJVPDifferentiate : IRInst +struct IRForwardDifferentiate : IRInst { enum { - kOp = kIROp_JVPDifferentiate + kOp = kIROp_ForwardDifferentiate }; // The base function for the call. IRUse base; IRInst* getBaseFn() { return getOperand(0); } - IR_LEAF_ISA(JVPDifferentiate) + IR_LEAF_ISA(ForwardDifferentiate) }; // Dictionary item mapping a type with a corresponding @@ -2486,7 +2486,7 @@ public: IRInst* emitExtractExistentialWitnessTable( IRInst* existentialValue); - IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn); + IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2aaeb4ac3..3a59eb6c9 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3065,11 +3065,11 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitJVPDifferentiateInst(IRType* type, IRInst* baseFn) + IRInst* IRBuilder::emitForwardDifferentiateInst(IRType* type, IRInst* baseFn) { - auto inst = createInst<IRJVPDifferentiate>( + auto inst = createInst<IRForwardDifferentiate>( this, - kIROp_JVPDifferentiate, + kIROp_ForwardDifferentiate, type, baseFn); addInst(inst); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ff66caa90..3766a1a5e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3042,13 +3042,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // of the inner func-expr. This will be resolved // to a concrete function during the derivative // pass. - LoweredValInfo visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + LoweredValInfo visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); return LoweredValInfo::simple( - getBuilder()->emitJVPDifferentiateInst( + getBuilder()->emitForwardDifferentiateInst( lowerType(context, expr->type), baseVal.val)); } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index f2284a121..93f2fcdcb 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2089,11 +2089,11 @@ namespace Slang { return parseTaggedUnionType(parser); } - /// Parse an expression of the form __jvp(fn) where fn is an + /// Parse an expression of the form __fwd_diff(fn) where fn is an /// identifier pointing to a function. - static Expr* parseJVPDifferentiate(Parser* parser) + static Expr* parseForwardDifferentiate(Parser* parser) { - JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create<JVPDifferentiateExpr>(); + ForwardDifferentiateExpr* jvpExpr = parser->astBuilder->create<ForwardDifferentiateExpr>(); parser->ReadToken(TokenType::LParent); @@ -2104,9 +2104,9 @@ namespace Slang return jvpExpr; } - static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */) + static NodeBase* parseForwardDifferentiate(Parser* parser, void* /* unused */) { - return parseJVPDifferentiate(parser); + return parseForwardDifferentiate(parser); } /// Parse a `This` type expression @@ -6634,7 +6634,7 @@ namespace Slang _makeParseExpr("none", parseNoneExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), - _makeParseExpr("__jvp", parseJVPDifferentiate) + _makeParseExpr("__fwd_diff", parseForwardDifferentiate) }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang index ddd1a4aa9..0c7dd039d 100644 --- a/tests/autodiff/arithmetic-jvp.slang +++ b/tests/autodiff/arithmetic-jvp.slang @@ -43,10 +43,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); dpfloat dpb = dpfloat(1.5, 1.0); - outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 1 - outputBuffer[1] = __jvp(f)(dpfloat(dpa.p(), 0.0)).d(); // Expect: 0 - outputBuffer[2] = __jvp(g)(dpa).d(); // Expect: 2 - outputBuffer[3] = __jvp(h)(dpa, dpb).d(); // Expect: 8 - outputBuffer[4] = __jvp(j)(dpa, dpb).d(); // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpa).d(); // Expect: 1 + outputBuffer[1] = __fwd_diff(f)(dpfloat(dpa.p(), 0.0)).d(); // Expect: 0 + outputBuffer[2] = __fwd_diff(g)(dpa).d(); // Expect: 2 + outputBuffer[3] = __fwd_diff(h)(dpa, dpb).d(); // Expect: 8 + outputBuffer[4] = __fwd_diff(j)(dpa, dpb).d(); // Expect: 1 } } diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang index f15fb6417..b551db4ab 100644 --- a/tests/autodiff/auto-differential-type.slang +++ b/tests/autodiff/auto-differential-type.slang @@ -14,7 +14,7 @@ struct A : IDifferentiable [__unsafeForceInlineEarly] static Differential zero() { - Differential b = {0.0, 0.0}; + Differential b = {0.0, float.zero()}; return b; } @@ -53,6 +53,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpA dpa = dpA(a, b); - outputBuffer[0] = __jvp(f)(dpa).d().x; // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpa).d().x; // Expect: 1 } } diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index 8ce354edc..02f6541f5 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -108,12 +108,12 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); outputBuffer[0] = f(dpa.p()); // Expect: 7.389056 - outputBuffer[1] = __jvp(f)(dpa).d(); // Expect: 7.389056 + outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 7.389056 // g() needs additional handling of IRMakeDifferentialPair(PtrType). This needs to // generate a new var, load from the individual vars and store into the pair var. //outputBuffer[2] = g(dpa.p()); // Expect: 1.381773 - //outputBuffer[3] = __jvp(g)(dpa).d(); // Expect: -0.301168 + //outputBuffer[3] = __fwd_diff(g)(dpa).d(); // Expect: -0.301168 } }
\ No newline at end of file diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 53957fd91..3ecd636e9 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -40,7 +40,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) A a = {1.0, 2.0}; A.Differential b = {0.2}; dpA dpa = dpA(a, b); - outputBuffer[0] = __jvp(f)(dpa).d().b.x; // Expect: 0 + outputBuffer[0] = __fwd_diff(f)(dpa).d().b.x; // Expect: 0 outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4 outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2 } diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang index 6c7ecffbe..614de54f6 100644 --- a/tests/autodiff/dstdlib.slang +++ b/tests/autodiff/dstdlib.slang @@ -28,10 +28,10 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); outputBuffer[0] = f(dpa.p()); // Expect: 7.389056 - outputBuffer[1] = __jvp(f)(dpa).d(); // Expect: 7.389056 + outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 7.389056 outputBuffer[2] = g(dpa.p()); // Expect: 0.909297 - outputBuffer[3] = __jvp(g)(dpa).d(); // Expect: -0.416146 + outputBuffer[3] = __fwd_diff(g)(dpa).d(); // Expect: -0.416146 outputBuffer[4] = h(dpa.p()); // Expect: -0.416146 - outputBuffer[5] = __jvp(h)(dpa).d(); // Expect: -0.909297 + outputBuffer[5] = __fwd_diff(h)(dpa).d(); // Expect: -0.909297 } }
\ No newline at end of file diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang index 3f0d85b60..f0b8d3898 100644 --- a/tests/autodiff/generic-custom-jvp.slang +++ b/tests/autodiff/generic-custom-jvp.slang @@ -27,8 +27,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(5.0, 1.0); dpfloat dpn = dpfloat(2, 0.0); - outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0 - outputBuffer[1] = __jvp(_pow)( + outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[1] = __fwd_diff(_pow)( dpfloat(dpa.p(), 0.0), dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 } diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index fe4ffc426..e14f851ac 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -283,8 +283,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5))); outputBuffer[0] = f(dpa.p()); // Expect: 22.0 - outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 - outputBuffer[2] = __jvp(f)(dpf4).d().val.values[3]; // Expect: 27.5 - outputBuffer[3] = __jvp(f)(dpf3).d().val.values[1]; // Expect: 40.5 + outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 + outputBuffer[2] = __fwd_diff(f)(dpf4).d().val.values[3]; // Expect: 27.5 + outputBuffer[3] = __fwd_diff(f)(dpf3).d().val.values[1]; // Expect: 40.5 } } diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index bcd5e764e..365be45aa 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -189,8 +189,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), myfloat3(float3(0.5, 1.5, 2.5))); outputBuffer[0] = f(dpa.p()); // Expect: 22.0 - outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 - outputBuffer[2] = __jvp(f)(dpf4).d().val.w; // Expect: 27.5 - outputBuffer[3] = __jvp(f)(dpf3).d().val.y; // Expect: 40.5 + outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 + outputBuffer[2] = __fwd_diff(f)(dpf4).d().val.w; // Expect: 27.5 + outputBuffer[3] = __fwd_diff(f)(dpf3).d().val.y; // Expect: 40.5 } } diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index c19a3f6bb..08816c5bc 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -66,8 +66,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpA dpa = dpA(a, b); - outputBuffer[0] = __jvp(f)(dpa).d().z.z; // Expect: 0.5 - outputBuffer[1] = __jvp(f)(dpa).d().k[5]; // Expect: 1 - outputBuffer[2] = __jvp(f)(dpa).d().k[2]; // Expect: 1.5 + outputBuffer[0] = __fwd_diff(f)(dpa).d().z.z; // Expect: 0.5 + outputBuffer[1] = __fwd_diff(f)(dpa).d().k[5]; // Expect: 1 + outputBuffer[2] = __fwd_diff(f)(dpa).d().k[2]; // Expect: 1.5 } } diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 0e8cac13b..2f385b87f 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -59,6 +59,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpA dpa = dpA(a, b); - outputBuffer[0] = __jvp(f)(dpa).d().z; // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpa).d().z; // Expect: 1 } } diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang index ee8bdf51d..8adcdee25 100644 --- a/tests/autodiff/imported-custom-jvp.slang +++ b/tests/autodiff/imported-custom-jvp.slang @@ -20,6 +20,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); dpfloat dpb = dpfloat(1.5, 1.0); - outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 2 + outputBuffer[0] = __fwd_diff(f)(dpa).d(); // Expect: 2 } } diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang index ba04c6b65..e53e5db7c 100644 --- a/tests/autodiff/inout-parameters-jvp.slang +++ b/tests/autodiff/inout-parameters-jvp.slang @@ -33,12 +33,12 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpz = dpfloat(z, dz); - __jvp(h)(dpfloat(x, dx), dpfloat(y, dy), dpz); + __fwd_diff(h)(dpfloat(x, dx), dpfloat(y, dy), dpz); outputBuffer[0] = dpz.d(); // Expect: 12.0 outputBuffer[1] = dpz.p(); // Expect: 6.75 - __jvp(g)(dpfloat(x, dx), dpfloat(y, dy), dpz); + __fwd_diff(g)(dpfloat(x, dx), dpfloat(y, dy), dpz); outputBuffer[2] = dpz.d(); // Expect: 21.5 outputBuffer[3] = dpz.p(); // Expect: 12.5 diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang index 6241a8bf5..79b90bd16 100644 --- a/tests/autodiff/local-redecl-custom-jvp.slang +++ b/tests/autodiff/local-redecl-custom-jvp.slang @@ -25,8 +25,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(5.0, 1.0); dpfloat dpn = dpfloat(2, 0.0); - outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0 - outputBuffer[1] = __jvp(_pow)( + outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[1] = __fwd_diff(_pow)( dpfloat(dpa.p(), 0.0), dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 } diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang index baebeee56..96648d861 100644 --- a/tests/autodiff/nested-jvp.slang +++ b/tests/autodiff/nested-jvp.slang @@ -58,7 +58,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float3 d_f90 = float3(0.9, 0.9, 0.9); float d_cosTheta = 1.0; - outputBuffer[0] = __jvp(fresnel)( + outputBuffer[0] = __fwd_diff(fresnel)( dpfloat3(f0, d_f0), dpfloat3(f90, d_f90), dpfloat(cosTheta, d_cosTheta)).d().y; // Expect: -0.031250 @@ -71,14 +71,14 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float db = -1.0; float dc = 0.2; - outputBuffer[1] = __jvp(g)( + outputBuffer[1] = __fwd_diff(g)( dpfloat(a, da), dpfloat(b, db), dpfloat(c, dc)).d(); // Expect: -0.24375 outputBuffer[2] = g(a, b, c); // Expect: 0.95625 - outputBuffer[3] = __jvp(g)( + outputBuffer[3] = __fwd_diff(g)( dpfloat(a, da), dpfloat(b, db), dpfloat(3.0, dc)).d(); // Expect: -0.4; diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang index b243d4fb5..9a311ed31 100644 --- a/tests/autodiff/out-parameters-jvp.slang +++ b/tests/autodiff/out-parameters-jvp.slang @@ -23,7 +23,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float dy = 0.5; dpfloat dresult; - __jvp(h)(dpfloat(x, dx), dpfloat(y, dy), dresult); + __fwd_diff(h)(dpfloat(x, dx), dpfloat(y, dy), dresult); outputBuffer[0] = dresult.d(); // Expect: 9.5 diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang index 26b5c0076..95b9cadd3 100644 --- a/tests/autodiff/overloads-jvp.slang +++ b/tests/autodiff/overloads-jvp.slang @@ -33,8 +33,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[0] = f(dpa.p()); // Expect: 6 outputBuffer[1] = f(dpf3.p()); // Expect: 8 - outputBuffer[2] = __jvp(f)(dpf3).d(); // Expect: 5.5 - outputBuffer[3] = __jvp(f)(dpa).d(); // Expect: 5 - outputBuffer[4] = __jvp(g)(dpa).d(); // Expect: 11.0 + outputBuffer[2] = __fwd_diff(f)(dpf3).d(); // Expect: 5.5 + outputBuffer[3] = __fwd_diff(f)(dpa).d(); // Expect: 5 + outputBuffer[4] = __fwd_diff(g)(dpa).d(); // Expect: 11.0 } } diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index e05d94733..b79b3e764 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -48,17 +48,17 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float4 a4 = float4(2.0, 1.0, 0.0, 2.0); float4 b4 = float4(1.5, -2.0, 1.0, 1.5); - outputBuffer[0] = __jvp(f)(dpfloat3(a, da)).d().z; // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d().z; // Expect: 1 - outputBuffer[1] = __jvp(g)( + outputBuffer[1] = __fwd_diff(g)( dpfloat3(a, da), dpfloat3(b, float3(2.0, 1.0, 0.0))).d().y; // Expect: 8 - outputBuffer[2] = __jvp(h)( + outputBuffer[2] = __fwd_diff(h)( dpfloat2(a2, float2(1.0, 0.0)), dpfloat2(b2, float2(1.0, 1.0))).d().x; // Expect: 8 - outputBuffer[3] = __jvp(j)( + outputBuffer[3] = __fwd_diff(j)( dpfloat4(a4, float4(1.0)), dpfloat4(b4, float4(2.0))).d().w; // Expect: 9 } diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang index 775c0140e..fc726d067 100644 --- a/tests/autodiff/vector-swizzle-jvp.slang +++ b/tests/autodiff/vector-swizzle-jvp.slang @@ -27,16 +27,16 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) float3 a = float3(2.0, 2.0, 2.0); float3 da = float3(1.0, 0.5, 1.0); - outputBuffer[0] = __jvp(f)(dpfloat3(a, da)).d().x; // Expect: 1 - outputBuffer[1] = __jvp(f)(dpfloat3(a, da)).d().y; // Expect: 0.5 + outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d().x; // Expect: 1 + outputBuffer[1] = __fwd_diff(f)(dpfloat3(a, da)).d().y; // Expect: 0.5 float3 x = float3(0.5, 2.0, 0.5); float4 y = float4(-1.5, 1.0, 4.0, 2.0); float3 dx = float3(1.0, 0.0, -1.0); float4 dy = float4(0.0, 0.5, -0.25, 1.0); - outputBuffer[2] = __jvp(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().x; // Expect: -2.25 - outputBuffer[3] = __jvp(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().y; // Expect: 0.5 + outputBuffer[2] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().x; // Expect: -2.25 + outputBuffer[3] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().y; // Expect: 0.5 } } |
