summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-27 12:19:30 -0700
committerGitHub <noreply@github.com>2022-10-27 12:19:30 -0700
commit0cbef6fd6d7924d37ef3ea5ec7c848c80947d13f (patch)
tree173fa18c39638e7d41ae092b9012554cb867a31b
parent351e78f3abc54f114237d4af64f8199476ebf176 (diff)
Rename `__jvp`-->`__fwd_diff`. (#2471)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ast-expr.h6
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-overload.cpp8
-rw-r--r--source/slang/slang-ir-diff-call.cpp10
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp14
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir.cpp6
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
-rw-r--r--source/slang/slang-parser.cpp12
-rw-r--r--tests/autodiff/arithmetic-jvp.slang10
-rw-r--r--tests/autodiff/auto-differential-type.slang4
-rw-r--r--tests/autodiff/custom-intrinsic.slang4
-rw-r--r--tests/autodiff/differential-method-synthesis.slang2
-rw-r--r--tests/autodiff/dstdlib.slang6
-rw-r--r--tests/autodiff/generic-custom-jvp.slang4
-rw-r--r--tests/autodiff/generic-impl-jvp.slang6
-rw-r--r--tests/autodiff/generic-jvp.slang6
-rw-r--r--tests/autodiff/getter-setter-multi.slang6
-rw-r--r--tests/autodiff/getter-setter.slang2
-rw-r--r--tests/autodiff/imported-custom-jvp.slang2
-rw-r--r--tests/autodiff/inout-parameters-jvp.slang4
-rw-r--r--tests/autodiff/local-redecl-custom-jvp.slang4
-rw-r--r--tests/autodiff/nested-jvp.slang6
-rw-r--r--tests/autodiff/out-parameters-jvp.slang2
-rw-r--r--tests/autodiff/overloads-jvp.slang6
-rw-r--r--tests/autodiff/vector-arithmetic-jvp.slang8
-rw-r--r--tests/autodiff/vector-swizzle-jvp.slang8
29 files changed, 82 insertions, 82 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index fca628a49..f2a72703e 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -437,12 +437,12 @@ class HigherOrderInvokeExpr : public Expr
Expr* baseFunction;
};
- /// An expression of the form `__jvp(fn)` to access the
+ /// An expression of the form `__fwd_diff(fn)` to access the
/// forward-mode derivative version of the function `fn`
///
-class JVPDifferentiateExpr: public HigherOrderInvokeExpr
+class ForwardDifferentiateExpr: public HigherOrderInvokeExpr
{
- SLANG_AST_CLASS(JVPDifferentiateExpr)
+ SLANG_AST_CLASS(ForwardDifferentiateExpr)
};
/// A type expression of the form `__TaggedUnion(A, ...)`.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index d1e737720..0975de985 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1971,7 +1971,7 @@ namespace Slang
return jvpType;
}
- Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 31075c3e8..af1173051 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1987,7 +1987,7 @@ namespace Slang
Expr* visitPointerTypeExpr(PointerTypeExpr* expr);
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
- Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr);
+ Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index eadf2f63d..ef067c06c 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1548,11 +1548,11 @@ namespace Slang
// Lookup the higher order function and process types accordingly. In the future,
// if there are enough varieties, we can have dispatch logic instead of an
// if-else ladder.
- if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr))
+ if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr))
{
if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type))
{
- // Case: __jvp(name-resolved-to-decl-ref)
+ // Case: __fwd_diff(name-resolved-to-decl-ref)
auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>();
SLANG_ASSERT(baseFuncDeclRef);
@@ -1567,7 +1567,7 @@ namespace Slang
}
else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type))
{
- // Case: __jvp(name-resolved-to-multiple-decl-ref)
+ // Case: __fwd_diff(name-resolved-to-multiple-decl-ref)
if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction))
{
@@ -1595,7 +1595,7 @@ namespace Slang
}
else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>())
{
- // Case: __jvp(name-resolved-to-generic-decl)
+ // Case: __fwd_diff(name-resolved-to-generic-decl)
// Get inner function
DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 3f2c6c789..34e7e3de0 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -34,8 +34,8 @@ struct DerivativeCallProcessContext
do
{
auto nextChild = child->getNextInst();
- // Look for IRJVPDifferentiate
- if (auto derivOf = as<IRJVPDifferentiate>(child))
+ // Look for IRForwardDifferentiate
+ if (auto derivOf = as<IRForwardDifferentiate>(child))
{
processDifferentiate(derivOf);
}
@@ -50,14 +50,14 @@ struct DerivativeCallProcessContext
// Perform forward-mode automatic differentiation on
// the intstructions.
- void processDifferentiate(IRJVPDifferentiate* derivOfInst)
+ void processDifferentiate(IRForwardDifferentiate* derivOfInst)
{
IRInst* jvpCallable = nullptr;
// First get base function
auto origCallable = derivOfInst->getBaseFn();
- // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc))
+ // Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc))
// Check the specialize inst for a reference to the derivative fn.
//
if (auto origSpecialize = as<IRSpecialize>(origCallable))
@@ -68,7 +68,7 @@ struct DerivativeCallProcessContext
}
}
- // Resolve the derivative function for an IRJVPDifferentiate(IRFunc)
+ // Resolve the derivative function for an IRForwardDifferentiate(IRFunc)
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 1a86506b3..7e6fd30dd 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1125,7 +1125,7 @@ struct JVPTranscriber
{
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
- diffCallee = builder->emitJVPDifferentiateInst(
+ diffCallee = builder->emitForwardDifferentiateInst(
differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
primalCallee);
}
@@ -1357,8 +1357,8 @@ struct JVPTranscriber
// logic here. We simple clone the original lookup which points to the original function,
// or the cloned version in case we're inside a generic scope.
// The differentiation logic is inserted later when this is used in an IRCall.
- // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table))
- // rather than have Lookup(JVPDifferentiate(Table))
+ // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table))
+ // rather than have Lookup(ForwardDifferentiate(Table))
//
auto diffLookup = cloneInst(&cloneEnv, builder, origLookup);
return InstPair(diffLookup, diffLookup);
@@ -1937,7 +1937,7 @@ struct JVPDerivativeContext
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
- // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by
+ // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
// generating derivative code for the referenced function.
//
bool modified = processReferencedFunctions(builder);
@@ -1962,7 +1962,7 @@ struct JVPDerivativeContext
return nullptr;
}
- // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate),
+ // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
// then check that the referenced function is marked correctly for differentiation.
//
bool processReferencedFunctions(IRBuilder* builder)
@@ -1979,13 +1979,13 @@ struct JVPDerivativeContext
{
// Either the child instruction has more children (func/block etc..)
// and we add it to the work list for further processing, or
- // it's an ordinary inst in which case we check if it's a JVPDifferentiate
+ // it's an ordinary inst in which case we check if it's a ForwardDifferentiate
// instruction.
//
if (child->getFirstChild() != nullptr)
workQueue->push(child);
- if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
+ if (auto jvpDiffInst = as<IRForwardDifferentiate>(child))
{
auto baseInst = jvpDiffInst->getBaseFn();
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c59286116..ccde80476 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -758,7 +758,7 @@ INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
INST(CastPtrToBool, CastPtrToBool, 1, 0)
INST(IsType, IsType, 3, 0)
-INST(JVPDifferentiate, jvpDifferentiate, 1, 0)
+INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5a9c14038..95202d9d0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -579,17 +579,17 @@ struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration
// An instruction that replaces the function symbol
// with it's derivative function.
-struct IRJVPDifferentiate : IRInst
+struct IRForwardDifferentiate : IRInst
{
enum
{
- kOp = kIROp_JVPDifferentiate
+ kOp = kIROp_ForwardDifferentiate
};
// The base function for the call.
IRUse base;
IRInst* getBaseFn() { return getOperand(0); }
- IR_LEAF_ISA(JVPDifferentiate)
+ IR_LEAF_ISA(ForwardDifferentiate)
};
// Dictionary item mapping a type with a corresponding
@@ -2486,7 +2486,7 @@ public:
IRInst* emitExtractExistentialWitnessTable(
IRInst* existentialValue);
- IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn);
+ IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2aaeb4ac3..3a59eb6c9 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3065,11 +3065,11 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitJVPDifferentiateInst(IRType* type, IRInst* baseFn)
+ IRInst* IRBuilder::emitForwardDifferentiateInst(IRType* type, IRInst* baseFn)
{
- auto inst = createInst<IRJVPDifferentiate>(
+ auto inst = createInst<IRForwardDifferentiate>(
this,
- kIROp_JVPDifferentiate,
+ kIROp_ForwardDifferentiate,
type,
baseFn);
addInst(inst);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ff66caa90..3766a1a5e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3042,13 +3042,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
// pass.
- LoweredValInfo visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ LoweredValInfo visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
auto baseVal = lowerSubExpr(expr->baseFunction);
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
return LoweredValInfo::simple(
- getBuilder()->emitJVPDifferentiateInst(
+ getBuilder()->emitForwardDifferentiateInst(
lowerType(context, expr->type),
baseVal.val));
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index f2284a121..93f2fcdcb 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2089,11 +2089,11 @@ namespace Slang
{
return parseTaggedUnionType(parser);
}
- /// Parse an expression of the form __jvp(fn) where fn is an
+ /// Parse an expression of the form __fwd_diff(fn) where fn is an
/// identifier pointing to a function.
- static Expr* parseJVPDifferentiate(Parser* parser)
+ static Expr* parseForwardDifferentiate(Parser* parser)
{
- JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create<JVPDifferentiateExpr>();
+ ForwardDifferentiateExpr* jvpExpr = parser->astBuilder->create<ForwardDifferentiateExpr>();
parser->ReadToken(TokenType::LParent);
@@ -2104,9 +2104,9 @@ namespace Slang
return jvpExpr;
}
- static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */)
+ static NodeBase* parseForwardDifferentiate(Parser* parser, void* /* unused */)
{
- return parseJVPDifferentiate(parser);
+ return parseForwardDifferentiate(parser);
}
/// Parse a `This` type expression
@@ -6634,7 +6634,7 @@ namespace Slang
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
- _makeParseExpr("__jvp", parseJVPDifferentiate)
+ _makeParseExpr("__fwd_diff", parseForwardDifferentiate)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()
diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang
index ddd1a4aa9..0c7dd039d 100644
--- a/tests/autodiff/arithmetic-jvp.slang
+++ b/tests/autodiff/arithmetic-jvp.slang
@@ -43,10 +43,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat dpa = dpfloat(2.0, 1.0);
dpfloat dpb = dpfloat(1.5, 1.0);
- outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 1
- outputBuffer[1] = __jvp(f)(dpfloat(dpa.p(), 0.0)).d(); // Expect: 0
- outputBuffer[2] = __jvp(g)(dpa).d(); // Expect: 2
- outputBuffer[3] = __jvp(h)(dpa, dpb).d(); // Expect: 8
- outputBuffer[4] = __jvp(j)(dpa, dpb).d(); // Expect: 1
+ outputBuffer[0] = __fwd_diff(f)(dpa).d(); // Expect: 1
+ outputBuffer[1] = __fwd_diff(f)(dpfloat(dpa.p(), 0.0)).d(); // Expect: 0
+ outputBuffer[2] = __fwd_diff(g)(dpa).d(); // Expect: 2
+ outputBuffer[3] = __fwd_diff(h)(dpa, dpb).d(); // Expect: 8
+ outputBuffer[4] = __fwd_diff(j)(dpa, dpb).d(); // Expect: 1
}
}
diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang
index f15fb6417..b551db4ab 100644
--- a/tests/autodiff/auto-differential-type.slang
+++ b/tests/autodiff/auto-differential-type.slang
@@ -14,7 +14,7 @@ struct A : IDifferentiable
[__unsafeForceInlineEarly]
static Differential zero()
{
- Differential b = {0.0, 0.0};
+ Differential b = {0.0, float.zero()};
return b;
}
@@ -53,6 +53,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpA dpa = dpA(a, b);
- outputBuffer[0] = __jvp(f)(dpa).d().x; // Expect: 1
+ outputBuffer[0] = __fwd_diff(f)(dpa).d().x; // Expect: 1
}
}
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang
index 8ce354edc..02f6541f5 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic.slang
@@ -108,12 +108,12 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
dpfloat dpa = dpfloat(2.0, 1.0);
outputBuffer[0] = f(dpa.p()); // Expect: 7.389056
- outputBuffer[1] = __jvp(f)(dpa).d(); // Expect: 7.389056
+ outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 7.389056
// g() needs additional handling of IRMakeDifferentialPair(PtrType). This needs to
// generate a new var, load from the individual vars and store into the pair var.
//outputBuffer[2] = g(dpa.p()); // Expect: 1.381773
- //outputBuffer[3] = __jvp(g)(dpa).d(); // Expect: -0.301168
+ //outputBuffer[3] = __fwd_diff(g)(dpa).d(); // Expect: -0.301168
}
} \ No newline at end of file
diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang
index 53957fd91..3ecd636e9 100644
--- a/tests/autodiff/differential-method-synthesis.slang
+++ b/tests/autodiff/differential-method-synthesis.slang
@@ -40,7 +40,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
A a = {1.0, 2.0};
A.Differential b = {0.2};
dpA dpa = dpA(a, b);
- outputBuffer[0] = __jvp(f)(dpa).d().b.x; // Expect: 0
+ outputBuffer[0] = __fwd_diff(f)(dpa).d().b.x; // Expect: 0
outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4
outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2
}
diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang
index 6c7ecffbe..614de54f6 100644
--- a/tests/autodiff/dstdlib.slang
+++ b/tests/autodiff/dstdlib.slang
@@ -28,10 +28,10 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
dpfloat dpa = dpfloat(2.0, 1.0);
outputBuffer[0] = f(dpa.p()); // Expect: 7.389056
- outputBuffer[1] = __jvp(f)(dpa).d(); // Expect: 7.389056
+ outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 7.389056
outputBuffer[2] = g(dpa.p()); // Expect: 0.909297
- outputBuffer[3] = __jvp(g)(dpa).d(); // Expect: -0.416146
+ outputBuffer[3] = __fwd_diff(g)(dpa).d(); // Expect: -0.416146
outputBuffer[4] = h(dpa.p()); // Expect: -0.416146
- outputBuffer[5] = __jvp(h)(dpa).d(); // Expect: -0.909297
+ outputBuffer[5] = __fwd_diff(h)(dpa).d(); // Expect: -0.909297
}
} \ No newline at end of file
diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang
index 3f0d85b60..f0b8d3898 100644
--- a/tests/autodiff/generic-custom-jvp.slang
+++ b/tests/autodiff/generic-custom-jvp.slang
@@ -27,8 +27,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat dpa = dpfloat(5.0, 1.0);
dpfloat dpn = dpfloat(2, 0.0);
- outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0
- outputBuffer[1] = __jvp(_pow)(
+ outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d(); // Expect: 10.0
+ outputBuffer[1] = __fwd_diff(_pow)(
dpfloat(dpa.p(), 0.0),
dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595
}
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index fe4ffc426..e14f851ac 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -283,8 +283,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5)));
outputBuffer[0] = f(dpa.p()); // Expect: 22.0
- outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5
- outputBuffer[2] = __jvp(f)(dpf4).d().val.values[3]; // Expect: 27.5
- outputBuffer[3] = __jvp(f)(dpf3).d().val.values[1]; // Expect: 40.5
+ outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5
+ outputBuffer[2] = __fwd_diff(f)(dpf4).d().val.values[3]; // Expect: 27.5
+ outputBuffer[3] = __fwd_diff(f)(dpf3).d().val.values[1]; // Expect: 40.5
}
}
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index bcd5e764e..365be45aa 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -189,8 +189,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), myfloat3(float3(0.5, 1.5, 2.5)));
outputBuffer[0] = f(dpa.p()); // Expect: 22.0
- outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5
- outputBuffer[2] = __jvp(f)(dpf4).d().val.w; // Expect: 27.5
- outputBuffer[3] = __jvp(f)(dpf3).d().val.y; // Expect: 40.5
+ outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5
+ outputBuffer[2] = __fwd_diff(f)(dpf4).d().val.w; // Expect: 27.5
+ outputBuffer[3] = __fwd_diff(f)(dpf3).d().val.y; // Expect: 40.5
}
}
diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang
index c19a3f6bb..08816c5bc 100644
--- a/tests/autodiff/getter-setter-multi.slang
+++ b/tests/autodiff/getter-setter-multi.slang
@@ -66,8 +66,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpA dpa = dpA(a, b);
- outputBuffer[0] = __jvp(f)(dpa).d().z.z; // Expect: 0.5
- outputBuffer[1] = __jvp(f)(dpa).d().k[5]; // Expect: 1
- outputBuffer[2] = __jvp(f)(dpa).d().k[2]; // Expect: 1.5
+ outputBuffer[0] = __fwd_diff(f)(dpa).d().z.z; // Expect: 0.5
+ outputBuffer[1] = __fwd_diff(f)(dpa).d().k[5]; // Expect: 1
+ outputBuffer[2] = __fwd_diff(f)(dpa).d().k[2]; // Expect: 1.5
}
}
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
index 0e8cac13b..2f385b87f 100644
--- a/tests/autodiff/getter-setter.slang
+++ b/tests/autodiff/getter-setter.slang
@@ -59,6 +59,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpA dpa = dpA(a, b);
- outputBuffer[0] = __jvp(f)(dpa).d().z; // Expect: 1
+ outputBuffer[0] = __fwd_diff(f)(dpa).d().z; // Expect: 1
}
}
diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang
index ee8bdf51d..8adcdee25 100644
--- a/tests/autodiff/imported-custom-jvp.slang
+++ b/tests/autodiff/imported-custom-jvp.slang
@@ -20,6 +20,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat dpa = dpfloat(2.0, 1.0);
dpfloat dpb = dpfloat(1.5, 1.0);
- outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 2
+ outputBuffer[0] = __fwd_diff(f)(dpa).d(); // Expect: 2
}
}
diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang
index ba04c6b65..e53e5db7c 100644
--- a/tests/autodiff/inout-parameters-jvp.slang
+++ b/tests/autodiff/inout-parameters-jvp.slang
@@ -33,12 +33,12 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat dpz = dpfloat(z, dz);
- __jvp(h)(dpfloat(x, dx), dpfloat(y, dy), dpz);
+ __fwd_diff(h)(dpfloat(x, dx), dpfloat(y, dy), dpz);
outputBuffer[0] = dpz.d(); // Expect: 12.0
outputBuffer[1] = dpz.p(); // Expect: 6.75
- __jvp(g)(dpfloat(x, dx), dpfloat(y, dy), dpz);
+ __fwd_diff(g)(dpfloat(x, dx), dpfloat(y, dy), dpz);
outputBuffer[2] = dpz.d(); // Expect: 21.5
outputBuffer[3] = dpz.p(); // Expect: 12.5
diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang
index 6241a8bf5..79b90bd16 100644
--- a/tests/autodiff/local-redecl-custom-jvp.slang
+++ b/tests/autodiff/local-redecl-custom-jvp.slang
@@ -25,8 +25,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
dpfloat dpa = dpfloat(5.0, 1.0);
dpfloat dpn = dpfloat(2, 0.0);
- outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0
- outputBuffer[1] = __jvp(_pow)(
+ outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d(); // Expect: 10.0
+ outputBuffer[1] = __fwd_diff(_pow)(
dpfloat(dpa.p(), 0.0),
dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595
}
diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang
index baebeee56..96648d861 100644
--- a/tests/autodiff/nested-jvp.slang
+++ b/tests/autodiff/nested-jvp.slang
@@ -58,7 +58,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
float3 d_f90 = float3(0.9, 0.9, 0.9);
float d_cosTheta = 1.0;
- outputBuffer[0] = __jvp(fresnel)(
+ outputBuffer[0] = __fwd_diff(fresnel)(
dpfloat3(f0, d_f0),
dpfloat3(f90, d_f90),
dpfloat(cosTheta, d_cosTheta)).d().y; // Expect: -0.031250
@@ -71,14 +71,14 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
float db = -1.0;
float dc = 0.2;
- outputBuffer[1] = __jvp(g)(
+ outputBuffer[1] = __fwd_diff(g)(
dpfloat(a, da),
dpfloat(b, db),
dpfloat(c, dc)).d(); // Expect: -0.24375
outputBuffer[2] = g(a, b, c); // Expect: 0.95625
- outputBuffer[3] = __jvp(g)(
+ outputBuffer[3] = __fwd_diff(g)(
dpfloat(a, da),
dpfloat(b, db),
dpfloat(3.0, dc)).d(); // Expect: -0.4;
diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang
index b243d4fb5..9a311ed31 100644
--- a/tests/autodiff/out-parameters-jvp.slang
+++ b/tests/autodiff/out-parameters-jvp.slang
@@ -23,7 +23,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
float dy = 0.5;
dpfloat dresult;
- __jvp(h)(dpfloat(x, dx), dpfloat(y, dy), dresult);
+ __fwd_diff(h)(dpfloat(x, dx), dpfloat(y, dy), dresult);
outputBuffer[0] = dresult.d(); // Expect: 9.5
diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang
index 26b5c0076..95b9cadd3 100644
--- a/tests/autodiff/overloads-jvp.slang
+++ b/tests/autodiff/overloads-jvp.slang
@@ -33,8 +33,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
outputBuffer[0] = f(dpa.p()); // Expect: 6
outputBuffer[1] = f(dpf3.p()); // Expect: 8
- outputBuffer[2] = __jvp(f)(dpf3).d(); // Expect: 5.5
- outputBuffer[3] = __jvp(f)(dpa).d(); // Expect: 5
- outputBuffer[4] = __jvp(g)(dpa).d(); // Expect: 11.0
+ outputBuffer[2] = __fwd_diff(f)(dpf3).d(); // Expect: 5.5
+ outputBuffer[3] = __fwd_diff(f)(dpa).d(); // Expect: 5
+ outputBuffer[4] = __fwd_diff(g)(dpa).d(); // Expect: 11.0
}
}
diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang
index e05d94733..b79b3e764 100644
--- a/tests/autodiff/vector-arithmetic-jvp.slang
+++ b/tests/autodiff/vector-arithmetic-jvp.slang
@@ -48,17 +48,17 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
float4 a4 = float4(2.0, 1.0, 0.0, 2.0);
float4 b4 = float4(1.5, -2.0, 1.0, 1.5);
- outputBuffer[0] = __jvp(f)(dpfloat3(a, da)).d().z; // Expect: 1
+ outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d().z; // Expect: 1
- outputBuffer[1] = __jvp(g)(
+ outputBuffer[1] = __fwd_diff(g)(
dpfloat3(a, da),
dpfloat3(b, float3(2.0, 1.0, 0.0))).d().y; // Expect: 8
- outputBuffer[2] = __jvp(h)(
+ outputBuffer[2] = __fwd_diff(h)(
dpfloat2(a2, float2(1.0, 0.0)),
dpfloat2(b2, float2(1.0, 1.0))).d().x; // Expect: 8
- outputBuffer[3] = __jvp(j)(
+ outputBuffer[3] = __fwd_diff(j)(
dpfloat4(a4, float4(1.0)),
dpfloat4(b4, float4(2.0))).d().w; // Expect: 9
}
diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang
index 775c0140e..fc726d067 100644
--- a/tests/autodiff/vector-swizzle-jvp.slang
+++ b/tests/autodiff/vector-swizzle-jvp.slang
@@ -27,16 +27,16 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
float3 a = float3(2.0, 2.0, 2.0);
float3 da = float3(1.0, 0.5, 1.0);
- outputBuffer[0] = __jvp(f)(dpfloat3(a, da)).d().x; // Expect: 1
- outputBuffer[1] = __jvp(f)(dpfloat3(a, da)).d().y; // Expect: 0.5
+ outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d().x; // Expect: 1
+ outputBuffer[1] = __fwd_diff(f)(dpfloat3(a, da)).d().y; // Expect: 0.5
float3 x = float3(0.5, 2.0, 0.5);
float4 y = float4(-1.5, 1.0, 4.0, 2.0);
float3 dx = float3(1.0, 0.0, -1.0);
float4 dy = float4(0.0, 0.5, -0.25, 1.0);
- outputBuffer[2] = __jvp(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().x; // Expect: -2.25
- outputBuffer[3] = __jvp(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().y; // Expect: 0.5
+ outputBuffer[2] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().x; // Expect: -2.25
+ outputBuffer[3] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().y; // Expect: 0.5
}
}