diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-27 12:19:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-27 12:19:30 -0700 |
| commit | 0cbef6fd6d7924d37ef3ea5ec7c848c80947d13f (patch) | |
| tree | 173fa18c39638e7d41ae092b9012554cb867a31b /source/slang | |
| parent | 351e78f3abc54f114237d4af64f8199476ebf176 (diff) | |
Rename `__jvp`-->`__fwd_diff`. (#2471)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-expr.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 12 |
11 files changed, 37 insertions, 37 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index fca628a49..f2a72703e 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -437,12 +437,12 @@ class HigherOrderInvokeExpr : public Expr Expr* baseFunction; }; - /// An expression of the form `__jvp(fn)` to access the + /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// -class JVPDifferentiateExpr: public HigherOrderInvokeExpr +class ForwardDifferentiateExpr: public HigherOrderInvokeExpr { - SLANG_AST_CLASS(JVPDifferentiateExpr) + SLANG_AST_CLASS(ForwardDifferentiateExpr) }; /// A type expression of the form `__TaggedUnion(A, ...)`. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index d1e737720..0975de985 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,7 +1971,7 @@ namespace Slang return jvpType; } - Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 31075c3e8..af1173051 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1987,7 +1987,7 @@ namespace Slang Expr* visitPointerTypeExpr(PointerTypeExpr* expr); Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); - Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr); + Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index eadf2f63d..ef067c06c 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1548,11 +1548,11 @@ namespace Slang // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an // if-else ladder. - if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr)) + if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr)) { if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type)) { - // Case: __jvp(name-resolved-to-decl-ref) + // Case: __fwd_diff(name-resolved-to-decl-ref) auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>(); SLANG_ASSERT(baseFuncDeclRef); @@ -1567,7 +1567,7 @@ namespace Slang } else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type)) { - // Case: __jvp(name-resolved-to-multiple-decl-ref) + // Case: __fwd_diff(name-resolved-to-multiple-decl-ref) if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction)) { @@ -1595,7 +1595,7 @@ namespace Slang } else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>()) { - // Case: __jvp(name-resolved-to-generic-decl) + // Case: __fwd_diff(name-resolved-to-generic-decl) // Get inner function DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index 3f2c6c789..34e7e3de0 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -34,8 +34,8 @@ struct DerivativeCallProcessContext do { auto nextChild = child->getNextInst(); - // Look for IRJVPDifferentiate - if (auto derivOf = as<IRJVPDifferentiate>(child)) + // Look for IRForwardDifferentiate + if (auto derivOf = as<IRForwardDifferentiate>(child)) { processDifferentiate(derivOf); } @@ -50,14 +50,14 @@ struct DerivativeCallProcessContext // Perform forward-mode automatic differentiation on // the intstructions. - void processDifferentiate(IRJVPDifferentiate* derivOfInst) + void processDifferentiate(IRForwardDifferentiate* derivOfInst) { IRInst* jvpCallable = nullptr; // First get base function auto origCallable = derivOfInst->getBaseFn(); - // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc)) + // Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc)) // Check the specialize inst for a reference to the derivative fn. // if (auto origSpecialize = as<IRSpecialize>(origCallable)) @@ -68,7 +68,7 @@ struct DerivativeCallProcessContext } } - // Resolve the derivative function for an IRJVPDifferentiate(IRFunc) + // Resolve the derivative function for an IRForwardDifferentiate(IRFunc) // // Check for the 'JVPDerivativeReference' decorator on the // base function. diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 1a86506b3..7e6fd30dd 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1125,7 +1125,7 @@ struct JVPTranscriber { // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. - diffCallee = builder->emitJVPDifferentiateInst( + diffCallee = builder->emitForwardDifferentiateInst( differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), primalCallee); } @@ -1357,8 +1357,8 @@ struct JVPTranscriber // logic here. We simple clone the original lookup which points to the original function, // or the cloned version in case we're inside a generic scope. // The differentiation logic is inserted later when this is used in an IRCall. - // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table)) - // rather than have Lookup(JVPDifferentiate(Table)) + // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table)) + // rather than have Lookup(ForwardDifferentiate(Table)) // auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); return InstPair(diffLookup, diffLookup); @@ -1937,7 +1937,7 @@ struct JVPDerivativeContext IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; - // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by + // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by // generating derivative code for the referenced function. // bool modified = processReferencedFunctions(builder); @@ -1962,7 +1962,7 @@ struct JVPDerivativeContext return nullptr; } - // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate), + // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), // then check that the referenced function is marked correctly for differentiation. // bool processReferencedFunctions(IRBuilder* builder) @@ -1979,13 +1979,13 @@ struct JVPDerivativeContext { // Either the child instruction has more children (func/block etc..) // and we add it to the work list for further processing, or - // it's an ordinary inst in which case we check if it's a JVPDifferentiate + // it's an ordinary inst in which case we check if it's a ForwardDifferentiate // instruction. // if (child->getFirstChild() != nullptr) workQueue->push(child); - if (auto jvpDiffInst = as<IRJVPDifferentiate>(child)) + if (auto jvpDiffInst = as<IRForwardDifferentiate>(child)) { auto baseInst = jvpDiffInst->getBaseFn(); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c59286116..ccde80476 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -758,7 +758,7 @@ INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) INST(CastPtrToBool, CastPtrToBool, 1, 0) INST(IsType, IsType, 3, 0) -INST(JVPDifferentiate, jvpDifferentiate, 1, 0) +INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5a9c14038..95202d9d0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -579,17 +579,17 @@ struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration // An instruction that replaces the function symbol // with it's derivative function. -struct IRJVPDifferentiate : IRInst +struct IRForwardDifferentiate : IRInst { enum { - kOp = kIROp_JVPDifferentiate + kOp = kIROp_ForwardDifferentiate }; // The base function for the call. IRUse base; IRInst* getBaseFn() { return getOperand(0); } - IR_LEAF_ISA(JVPDifferentiate) + IR_LEAF_ISA(ForwardDifferentiate) }; // Dictionary item mapping a type with a corresponding @@ -2486,7 +2486,7 @@ public: IRInst* emitExtractExistentialWitnessTable( IRInst* existentialValue); - IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn); + IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2aaeb4ac3..3a59eb6c9 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3065,11 +3065,11 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitJVPDifferentiateInst(IRType* type, IRInst* baseFn) + IRInst* IRBuilder::emitForwardDifferentiateInst(IRType* type, IRInst* baseFn) { - auto inst = createInst<IRJVPDifferentiate>( + auto inst = createInst<IRForwardDifferentiate>( this, - kIROp_JVPDifferentiate, + kIROp_ForwardDifferentiate, type, baseFn); addInst(inst); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ff66caa90..3766a1a5e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3042,13 +3042,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // of the inner func-expr. This will be resolved // to a concrete function during the derivative // pass. - LoweredValInfo visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + LoweredValInfo visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); return LoweredValInfo::simple( - getBuilder()->emitJVPDifferentiateInst( + getBuilder()->emitForwardDifferentiateInst( lowerType(context, expr->type), baseVal.val)); } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index f2284a121..93f2fcdcb 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2089,11 +2089,11 @@ namespace Slang { return parseTaggedUnionType(parser); } - /// Parse an expression of the form __jvp(fn) where fn is an + /// Parse an expression of the form __fwd_diff(fn) where fn is an /// identifier pointing to a function. - static Expr* parseJVPDifferentiate(Parser* parser) + static Expr* parseForwardDifferentiate(Parser* parser) { - JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create<JVPDifferentiateExpr>(); + ForwardDifferentiateExpr* jvpExpr = parser->astBuilder->create<ForwardDifferentiateExpr>(); parser->ReadToken(TokenType::LParent); @@ -2104,9 +2104,9 @@ namespace Slang return jvpExpr; } - static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */) + static NodeBase* parseForwardDifferentiate(Parser* parser, void* /* unused */) { - return parseJVPDifferentiate(parser); + return parseForwardDifferentiate(parser); } /// Parse a `This` type expression @@ -6634,7 +6634,7 @@ namespace Slang _makeParseExpr("none", parseNoneExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), - _makeParseExpr("__jvp", parseJVPDifferentiate) + _makeParseExpr("__fwd_diff", parseForwardDifferentiate) }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() |
