summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-27 12:19:30 -0700
committerGitHub <noreply@github.com>2022-10-27 12:19:30 -0700
commit0cbef6fd6d7924d37ef3ea5ec7c848c80947d13f (patch)
tree173fa18c39638e7d41ae092b9012554cb867a31b /source/slang
parent351e78f3abc54f114237d4af64f8199476ebf176 (diff)
Rename `__jvp`-->`__fwd_diff`. (#2471)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-expr.h6
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-overload.cpp8
-rw-r--r--source/slang/slang-ir-diff-call.cpp10
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp14
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir.cpp6
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
-rw-r--r--source/slang/slang-parser.cpp12
11 files changed, 37 insertions, 37 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index fca628a49..f2a72703e 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -437,12 +437,12 @@ class HigherOrderInvokeExpr : public Expr
Expr* baseFunction;
};
- /// An expression of the form `__jvp(fn)` to access the
+ /// An expression of the form `__fwd_diff(fn)` to access the
/// forward-mode derivative version of the function `fn`
///
-class JVPDifferentiateExpr: public HigherOrderInvokeExpr
+class ForwardDifferentiateExpr: public HigherOrderInvokeExpr
{
- SLANG_AST_CLASS(JVPDifferentiateExpr)
+ SLANG_AST_CLASS(ForwardDifferentiateExpr)
};
/// A type expression of the form `__TaggedUnion(A, ...)`.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index d1e737720..0975de985 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1971,7 +1971,7 @@ namespace Slang
return jvpType;
}
- Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 31075c3e8..af1173051 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1987,7 +1987,7 @@ namespace Slang
Expr* visitPointerTypeExpr(PointerTypeExpr* expr);
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
- Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr);
+ Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index eadf2f63d..ef067c06c 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1548,11 +1548,11 @@ namespace Slang
// Lookup the higher order function and process types accordingly. In the future,
// if there are enough varieties, we can have dispatch logic instead of an
// if-else ladder.
- if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr))
+ if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr))
{
if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type))
{
- // Case: __jvp(name-resolved-to-decl-ref)
+ // Case: __fwd_diff(name-resolved-to-decl-ref)
auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>();
SLANG_ASSERT(baseFuncDeclRef);
@@ -1567,7 +1567,7 @@ namespace Slang
}
else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type))
{
- // Case: __jvp(name-resolved-to-multiple-decl-ref)
+ // Case: __fwd_diff(name-resolved-to-multiple-decl-ref)
if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction))
{
@@ -1595,7 +1595,7 @@ namespace Slang
}
else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>())
{
- // Case: __jvp(name-resolved-to-generic-decl)
+ // Case: __fwd_diff(name-resolved-to-generic-decl)
// Get inner function
DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 3f2c6c789..34e7e3de0 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -34,8 +34,8 @@ struct DerivativeCallProcessContext
do
{
auto nextChild = child->getNextInst();
- // Look for IRJVPDifferentiate
- if (auto derivOf = as<IRJVPDifferentiate>(child))
+ // Look for IRForwardDifferentiate
+ if (auto derivOf = as<IRForwardDifferentiate>(child))
{
processDifferentiate(derivOf);
}
@@ -50,14 +50,14 @@ struct DerivativeCallProcessContext
// Perform forward-mode automatic differentiation on
// the intstructions.
- void processDifferentiate(IRJVPDifferentiate* derivOfInst)
+ void processDifferentiate(IRForwardDifferentiate* derivOfInst)
{
IRInst* jvpCallable = nullptr;
// First get base function
auto origCallable = derivOfInst->getBaseFn();
- // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc))
+ // Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc))
// Check the specialize inst for a reference to the derivative fn.
//
if (auto origSpecialize = as<IRSpecialize>(origCallable))
@@ -68,7 +68,7 @@ struct DerivativeCallProcessContext
}
}
- // Resolve the derivative function for an IRJVPDifferentiate(IRFunc)
+ // Resolve the derivative function for an IRForwardDifferentiate(IRFunc)
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 1a86506b3..7e6fd30dd 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1125,7 +1125,7 @@ struct JVPTranscriber
{
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
- diffCallee = builder->emitJVPDifferentiateInst(
+ diffCallee = builder->emitForwardDifferentiateInst(
differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
primalCallee);
}
@@ -1357,8 +1357,8 @@ struct JVPTranscriber
// logic here. We simple clone the original lookup which points to the original function,
// or the cloned version in case we're inside a generic scope.
// The differentiation logic is inserted later when this is used in an IRCall.
- // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table))
- // rather than have Lookup(JVPDifferentiate(Table))
+ // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table))
+ // rather than have Lookup(ForwardDifferentiate(Table))
//
auto diffLookup = cloneInst(&cloneEnv, builder, origLookup);
return InstPair(diffLookup, diffLookup);
@@ -1937,7 +1937,7 @@ struct JVPDerivativeContext
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
- // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by
+ // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
// generating derivative code for the referenced function.
//
bool modified = processReferencedFunctions(builder);
@@ -1962,7 +1962,7 @@ struct JVPDerivativeContext
return nullptr;
}
- // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate),
+ // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
// then check that the referenced function is marked correctly for differentiation.
//
bool processReferencedFunctions(IRBuilder* builder)
@@ -1979,13 +1979,13 @@ struct JVPDerivativeContext
{
// Either the child instruction has more children (func/block etc..)
// and we add it to the work list for further processing, or
- // it's an ordinary inst in which case we check if it's a JVPDifferentiate
+ // it's an ordinary inst in which case we check if it's a ForwardDifferentiate
// instruction.
//
if (child->getFirstChild() != nullptr)
workQueue->push(child);
- if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
+ if (auto jvpDiffInst = as<IRForwardDifferentiate>(child))
{
auto baseInst = jvpDiffInst->getBaseFn();
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c59286116..ccde80476 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -758,7 +758,7 @@ INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
INST(CastPtrToBool, CastPtrToBool, 1, 0)
INST(IsType, IsType, 3, 0)
-INST(JVPDifferentiate, jvpDifferentiate, 1, 0)
+INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5a9c14038..95202d9d0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -579,17 +579,17 @@ struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration
// An instruction that replaces the function symbol
// with it's derivative function.
-struct IRJVPDifferentiate : IRInst
+struct IRForwardDifferentiate : IRInst
{
enum
{
- kOp = kIROp_JVPDifferentiate
+ kOp = kIROp_ForwardDifferentiate
};
// The base function for the call.
IRUse base;
IRInst* getBaseFn() { return getOperand(0); }
- IR_LEAF_ISA(JVPDifferentiate)
+ IR_LEAF_ISA(ForwardDifferentiate)
};
// Dictionary item mapping a type with a corresponding
@@ -2486,7 +2486,7 @@ public:
IRInst* emitExtractExistentialWitnessTable(
IRInst* existentialValue);
- IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn);
+ IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2aaeb4ac3..3a59eb6c9 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3065,11 +3065,11 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitJVPDifferentiateInst(IRType* type, IRInst* baseFn)
+ IRInst* IRBuilder::emitForwardDifferentiateInst(IRType* type, IRInst* baseFn)
{
- auto inst = createInst<IRJVPDifferentiate>(
+ auto inst = createInst<IRForwardDifferentiate>(
this,
- kIROp_JVPDifferentiate,
+ kIROp_ForwardDifferentiate,
type,
baseFn);
addInst(inst);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ff66caa90..3766a1a5e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3042,13 +3042,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
// pass.
- LoweredValInfo visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ LoweredValInfo visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
auto baseVal = lowerSubExpr(expr->baseFunction);
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
return LoweredValInfo::simple(
- getBuilder()->emitJVPDifferentiateInst(
+ getBuilder()->emitForwardDifferentiateInst(
lowerType(context, expr->type),
baseVal.val));
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index f2284a121..93f2fcdcb 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2089,11 +2089,11 @@ namespace Slang
{
return parseTaggedUnionType(parser);
}
- /// Parse an expression of the form __jvp(fn) where fn is an
+ /// Parse an expression of the form __fwd_diff(fn) where fn is an
/// identifier pointing to a function.
- static Expr* parseJVPDifferentiate(Parser* parser)
+ static Expr* parseForwardDifferentiate(Parser* parser)
{
- JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create<JVPDifferentiateExpr>();
+ ForwardDifferentiateExpr* jvpExpr = parser->astBuilder->create<ForwardDifferentiateExpr>();
parser->ReadToken(TokenType::LParent);
@@ -2104,9 +2104,9 @@ namespace Slang
return jvpExpr;
}
- static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */)
+ static NodeBase* parseForwardDifferentiate(Parser* parser, void* /* unused */)
{
- return parseJVPDifferentiate(parser);
+ return parseForwardDifferentiate(parser);
}
/// Parse a `This` type expression
@@ -6634,7 +6634,7 @@ namespace Slang
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
- _makeParseExpr("__jvp", parseJVPDifferentiate)
+ _makeParseExpr("__fwd_diff", parseForwardDifferentiate)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()