From 8da47c460df01fad6f1d0614210a770f4781edb1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Sat, 25 Jun 2022 15:45:34 -0400 Subject: Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303) * Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He --- source/slang/slang-ast-expr.h | 6 +- source/slang/slang-check-expr.cpp | 40 +++- source/slang/slang-check-impl.h | 2 +- source/slang/slang-ir-diff-call.cpp | 8 +- source/slang/slang-ir-diff-jvp.cpp | 368 +++++++++++++++++++++++++++++++++--- source/slang/slang-ir-inst-defs.h | 2 +- source/slang/slang-ir-insts.h | 10 +- source/slang/slang-ir.cpp | 30 ++- source/slang/slang-lower-to-ir.cpp | 6 +- source/slang/slang-parser.cpp | 12 +- 10 files changed, 428 insertions(+), 56 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 8f407321e..7226f365e 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -372,10 +372,10 @@ class ExtractExistentialValueExpr: public Expr /// An expression of the form `__jvp(fn)` to access the /// forward-mode derivative version of the function `fn` /// -class JVPDerivativeOfExpr: public Expr +class JVPDifferentiateExpr: public Expr { - SLANG_AST_CLASS(JVPDerivativeOfExpr) - Expr* baseFn; + SLANG_AST_CLASS(JVPDifferentiateExpr) + Expr* baseFunction; }; /// A type expression of the form `__TaggedUnion(A, ...)`. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ff469428b..576220c02 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,16 +1509,46 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. - expr->baseFn = CheckTerm(expr->baseFn); + expr->baseFunction = CheckTerm(expr->baseFunction); - if(auto funcType = as(expr->baseFn->type)) + if(auto primalType = as(expr->baseFunction->type)) { // Resolve JVP type here. - // Temporarily resolving to the same type as the original function. - expr->type = expr->baseFn->type; + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = astBuilder->create(); + + // Only float types can be differentiated for now. + + // The JVP return type is float if primal return type is float + // void otherwise. + // + if (primalType->resultType->equals(astBuilder->getFloatType())) + jvpType->resultType = astBuilder->getFloatType(); + else + jvpType->resultType = astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); + jvpType->errorType = primalType->errorType; + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + jvpType->paramTypes.add(primalType->getParamType(i)); + } + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) + jvpType->paramTypes.add(astBuilder->getFloatType()); + } + + expr->type = jvpType; } else { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 3be5ba68b..ccf9ccad3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1732,7 +1732,7 @@ namespace Slang Expr* visitAndTypeExpr(AndTypeExpr* expr); Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); - Expr* visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr); + Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr); /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index 76ffe3c8b..92044be3c 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -34,10 +34,10 @@ struct DerivativeCallProcessContext do { auto nextChild = child->getNextInst(); - // Look for IRJVPDerivativeOf - if (auto derivOf = as(child)) + // Look for IRJVPDifferentiate + if (auto derivOf = as(child)) { - processDerivativeOf(derivOf); + processDifferentiate(derivOf); } child = nextChild; } @@ -50,7 +50,7 @@ struct DerivativeCallProcessContext // Perform forward-mode automatic differentiation on // the intstructions. - void processDerivativeOf(IRJVPDerivativeOf* derivOfInst) + void processDifferentiate(IRJVPDifferentiate* derivOfInst) { IRFunc* jvpFunc = nullptr; diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 431c8e5b2..2a42a7b6e 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -3,10 +3,303 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-clone.h" namespace Slang { +struct JVPTranscriber +{ + + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary instMapD; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + void mapDifferentialInst(IRInst* instP, IRInst* instD) + { + instMapD.Add(instP, instD); + } + + IRInst* getDifferentialInst(IRInst* instP) + { + return instMapD[instP]; + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List parameterTypesD; + IRType* returnTypeD; + + // Add all primal parameters to the list. + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + parameterTypesD.add(funcType->getParamType(i)); + } + + // Add differential versions for the types we support. + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + if (auto typeD = differentiateType(builder, funcType->getParamType(i))) + parameterTypesD.add(typeD); + } + + // Transcribe return type. + // This will be void if the primal return type is non-differentiable. + // + returnTypeD = differentiateType(builder, funcType->getResultType()); + if (!returnTypeD) + returnTypeD = builder->getVoidType(); + + return builder->getFuncType(parameterTypesD, returnTypeD); + } + + IRType* differentiateType(IRBuilder* builder, IRType* typeP) + { + switch (typeP->getOp()) + { + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + return builder->getType(typeP->getOp()); + + default: + return nullptr; + } + } + + IRInst* differentiateParam(IRBuilder* builder, IRParam* paramP) + { + if (IRType* typeD = differentiateType(builder, paramP->getFullType())) + { + IRParam* paramD = builder->emitParam(typeD); + SLANG_ASSERT(paramD); + return paramD; + } + return nullptr; + } + + List transcribeParams(IRBuilder* builder, IRInstList paramListP) + { + // Clone (and emit) all the primal parameters. + List newParamListP; + for (auto paramP : paramListP) + { + newParamListP.add(as(cloneInst(&cloneEnv, builder, paramP))); + } + + // Now emit differentials. + List newParamListD; + for (auto paramP : newParamListP) + { + IRParam* paramD = as(differentiateParam(builder, paramP)); + mapDifferentialInst(paramP, paramD); + newParamListD.add(paramD); + } + + return newParamListD; + } + + IRInst* differentiateVar(IRBuilder* builder, IRVar* varP) + { + if (IRType* typeD = differentiateType(builder, varP->getDataType()->getValueType())) + { + IRVar* varD = builder->emitVar(typeD); + SLANG_ASSERT(varD); + return varD; + } + return nullptr; + } + + IRInst* differentiateBinaryArith(IRBuilder* builder, IRInst* arith) + { + SLANG_ASSERT(arith->getOperandCount() == 2); + + auto leftP = arith->getOperand(0); + auto rightP = arith->getOperand(1); + + auto leftD = getDifferentialInst(leftP); + auto rightD = getDifferentialInst(rightP); + + auto leftZero = builder->getFloatValue(leftP->getDataType(), 0.0); + auto rightZero = builder->getFloatValue(rightP->getDataType(), 0.0); + + if (leftD || rightD) + { + leftD = leftD ? leftD : leftZero; + rightD = rightD ? rightD : rightZero; + + // Might have to do special-case handling for non-scalar types, + // like float3 or float3x3 + // + auto resultType = arith->getDataType(); + switch(arith->getOp()) + { + case kIROp_Add: + return builder->emitAdd(resultType, leftD, rightD); + case kIROp_Mul: + return builder->emitAdd(resultType, + builder->emitMul(resultType, leftD, rightP), + builder->emitMul(resultType, leftP, rightD)); + case kIROp_Sub: + return builder->emitSub(resultType, leftD, rightD); + case kIROp_Div: + return builder->emitDiv(resultType, + builder->emitSub( + resultType, + builder->emitMul(resultType, leftD, rightP), + builder->emitMul(resultType, leftP, rightD)), + builder->emitMul( + rightP->getDataType(), rightP, rightP + )); + default: + SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic"); + } + } + + return nullptr; + } + + IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP) + { + if (auto varP = as(loadP->getPtr())) + { + // If the loaded parameter has a differential version, + // emit a load instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. + // + if (auto varD = as(getDifferentialInst(varP))) + { + IRLoad* loadD = as(builder->emitLoad(varD)); + SLANG_ASSERT(loadD); + return loadD; + } + return nullptr; + } + + SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction"); + } + + IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP) + { + IRInst* storeLocation = storeP->getPtr(); + IRInst* storeVal = storeP->getVal(); + if (auto destParam = as(storeLocation)) + { + // If the stored value has a differential version, + // emit a store instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. + // + IRInst* storeValD = getDifferentialInst(storeVal); + IRVar* storeLocationD = as(getDifferentialInst(destParam)); + if (storeValD && storeLocationD) + { + IRStore* storeD = as( + builder->emitStore(storeLocationD, storeValD)); + SLANG_ASSERT(storeD); + return storeD; + } + return nullptr; + } + + SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction"); + } + + IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) + { + IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal()); + if (auto returnValD = getDifferentialInst(returnVal)) + { + IRReturn* returnD = as(builder->emitReturn(returnValD)); + SLANG_ASSERT(returnD); + return returnD; + } + return nullptr; + } + + // Since int/float literals are sometimes nested inside an IRConstructor + // instruction, we check to make sure that the nested instr is a constant + // and then return nullptr. Literals do not need to be differentiated. + // + IRInst* differentiateConstruct(IRBuilder*, IRInst* consP) + { + + if (as(consP->getOperand(0)) && consP->getOperandCount() == 1) + { + return nullptr; + } + SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor"); + } + + // Logic for whether a primal instruction needs to be replicated + // in the differential function. For puerly functional blocks with + // no side-effects, it's safe to replicate everything except the + // return instruction. + // + bool requiresPrimalClone(IRBuilder*, IRInst* instP) + { + if (as(instP)) + { + return false; + } + else + { + return true; + } + } + + IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) + { + IRInst* instP = oldInstP; + + // Clone the old instruction, but only if it's safe to do so. + // For instance, instructions that handle control flow + // (return statements) shouldn't be replicated. + // + if (requiresPrimalClone(builder, oldInstP)) + instP = cloneInst(&cloneEnv, builder, oldInstP); + SLANG_ASSERT(instP); + + IRInst* instD = differentiateInst(builder, instP); + + mapDifferentialInst(instP, instD); + + return instD; + } + + IRInst* differentiateInst(IRBuilder* builder, IRInst* instP) + { + switch (instP->getOp()) + { + case kIROp_Var: + return differentiateVar(builder, as(instP)); + + case kIROp_Load: + return differentiateLoad(builder, as(instP)); + + case kIROp_Store: + return differentiateStore(builder, as(instP)); + + case kIROp_Return: + return differentiateReturn(builder, as(instP)); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + return differentiateBinaryArith(builder, instP); + + case kIROp_Construct: + return differentiateConstruct(builder, instP); + + default: + SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction"); + } + } +}; + struct JVPDerivativeContext { // This type passes over the module and generates @@ -16,7 +309,12 @@ struct JVPDerivativeContext IRModule* module; // Shared builder state for our derivative passes. - SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder sharedBuilderStorage; + + // A transcriber object that handles the main job of + // processing instructions while maintaining state. + // + JVPTranscriber transcriberStorage; bool processModule() { @@ -31,6 +329,7 @@ struct JVPDerivativeContext // looking for callables. // Note: We're only processing global callables (IRGlobalValueWithCode) // for now. + // IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; for (auto inst : module->getGlobalInsts()) @@ -43,6 +342,8 @@ struct JVPDerivativeContext SLANG_ASSERT(as(callable)); IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as(callable)); builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); + + unmarkForJVP(callable); } } } @@ -62,12 +363,27 @@ struct JVPDerivativeContext { return true; } - // TODO: Need to remove this decoration or check for - // JVPDerivativeReferenceDecoration to avoid re-generating code. } return false; } + // Removes the JVPDerivativeMarkerDecoration from the provided callable, + // if it exists. + // + void unmarkForJVP(IRGlobalValueWithCode* callable) + { + for(auto decoration = callable->getFirstDecoration(); + decoration; + decoration = decoration->getNextDecoration()) + { + if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) + { + decoration->removeAndDeallocate(); + return; + } + } + } + List emitFuncParameters(IRBuilder* builder, IRFuncType* dataType) { List params; @@ -81,15 +397,21 @@ struct JVPDerivativeContext // Perform forward-mode automatic differentiation on // the intstructions. + // IRFunc* emitJVPFunction(IRBuilder* builder, IRFunc* primalFn) { - // Note (sai): Is this safe? Should we use setInsertInto? + builder->setInsertBefore(primalFn->getNextInst()); auto jvpFn = builder->createFunc(); - IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType()); + + SLANG_ASSERT(as(primalFn->getFullType())); + IRType* jvpFuncType = transcriberStorage.differentiateFunctionType( + builder, + as(primalFn->getFullType())); jvpFn->setFullType(jvpFuncType); + if (auto jvpName = getJVPFuncName(builder, primalFn)) builder->addNameHintDecoration(jvpFn, jvpName); @@ -100,13 +422,7 @@ struct JVPDerivativeContext for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) { - IRBlock* newJVPBlock = nullptr; - if (block == primalFn->getFirstBlock()) - { - newJVPBlock = builder->emitBlock(); - emitFuncParameters(builder, as(jvpFuncType)); - } - newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock); + emitJVPBlock(builder, primalFn->getFirstBlock()); } return jvpFn; @@ -138,31 +454,31 @@ struct JVPDerivativeContext IRBlock* primalBlock, IRBlock* jvpBlock = nullptr) { - // Create if not already provided, and insert into new block. + JVPTranscriber* transcriber = &(transcriberStorage); + + // Create if not already created, and then insert into new block. if (!jvpBlock) jvpBlock = builder->emitBlock(); else builder->setInsertInto(jvpBlock); - // Temporarily, we're going to just emit a single return 0 instruction. - for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst()) + // First transcribe the parameter list. This is done separately because we + // want all the derivative parameters emitted after the primal parameters + // rather than interleaved with one another. + // + transcriber->transcribeParams(builder, primalBlock->getParams()); + + // Run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for(auto child = primalBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) { - if (auto returnOp = as(child)) - { - auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0); - builder->emitReturn(zeroVal); - } + transcriber->transcribe(builder, child); } return jvpBlock; } - IRType* primalTypeToJVPType(IRType* primalType) - { - // Temporarily, we're going to implement the identity transform. - // The return type is the same as the primal type. - return primalType; - } }; // Set up context and call main process method. diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 793f1f78f..6eae710c0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -717,7 +717,7 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) -INST(JVPDerivativeOf, jvpDerivativeOf, 1, 0) +INST(JVPDifferentiate, jvpDifferentiate, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 82d0d5a0e..521570b22 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -529,17 +529,17 @@ struct IRJVPDerivativeReferenceDecoration : IRDecoration // An instruction that replaces the function symbol // with it's derivative function. -struct IRJVPDerivativeOf : IRInst +struct IRJVPDifferentiate : IRInst { enum { - kOp = kIROp_JVPDerivativeOf + kOp = kIROp_JVPDifferentiate }; // The base function for the call. IRUse base; IRInst* getBaseFn() { return getOperand(0); } - IR_LEAF_ISA(JVPDerivativeOf) + IR_LEAF_ISA(JVPDifferentiate) }; // An instruction that specializes another IR value @@ -2346,7 +2346,7 @@ public: IRInst* emitExtractExistentialWitnessTable( IRInst* existentialValue); - IRInst* emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn); + IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitSpecializeInst( IRType* type, @@ -2820,7 +2820,9 @@ public: IRInst* emitBitNot(IRType* type, IRInst* value); IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right); + IRInst* emitSub(IRType* type, IRInst* left, IRInst* right); IRInst* emitMul(IRType* type, IRInst* left, IRInst* right); + IRInst* emitDiv(IRType* type, IRInst* numerator, IRInst* denominator); IRInst* emitEql(IRInst* left, IRInst* right); IRInst* emitNeq(IRInst* left, IRInst* right); IRInst* emitLess(IRInst* left, IRInst* right); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 950061d4f..dd71ae782 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3027,11 +3027,11 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn) + IRInst* IRBuilder::emitJVPDifferentiateInst(IRType* type, IRInst* baseFn) { - auto inst = createInst( + auto inst = createInst( this, - kIROp_JVPDerivativeOf, + kIROp_JVPDifferentiate, type, baseFn); addInst(inst); @@ -4370,6 +4370,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitSub(IRType* type, IRInst* left, IRInst* right) + { + auto inst = createInst( + this, + kIROp_Sub, + type, + left, + right); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitEql(IRInst* left, IRInst* right) { auto inst = createInst(this, kIROp_Eql, getBoolType(), left, right); @@ -4403,6 +4415,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitDiv(IRType* type, IRInst* numerator, IRInst* denominator) + { + auto inst = createInst( + this, + kIROp_Div, + type, + numerator, + denominator); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitShr(IRType* type, IRInst* left, IRInst* right) { auto inst = createInst(this, kIROp_Rsh, type, left, right); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d845342f0..b7c5155b5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2943,13 +2943,13 @@ struct ExprLoweringVisitorBase : ExprVisitor // of the inner func-expr. This will be resolved // to a concrete function during the derivative // pass. - LoweredValInfo visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + LoweredValInfo visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { - auto baseVal = lowerSubExpr(expr->baseFn); + auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); return LoweredValInfo::simple( - getBuilder()->emitJVPDerivativeOfInst( + getBuilder()->emitJVPDifferentiateInst( lowerType(context, expr->type), baseVal.val)); } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index ee34eac6f..d168bf55c 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2058,22 +2058,22 @@ namespace Slang } /// Parse an expression of the form __jvp(fn) where fn is an /// identifier pointing to a function. - static Expr* parseJVPDerivativeOf(Parser* parser) + static Expr* parseJVPDifferentiate(Parser* parser) { - JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create(); + JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create(); parser->ReadToken(TokenType::LParent); - jvpExpr->baseFn = parser->ParseExpression(); + jvpExpr->baseFunction = parser->ParseExpression(); parser->ReadToken(TokenType::RParent); return jvpExpr; } - static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */) + static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */) { - return parseJVPDerivativeOf(parser); + return parseJVPDifferentiate(parser); } /// Parse a `This` type expression @@ -6492,7 +6492,7 @@ namespace Slang _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), - _makeParseExpr("__jvp", parseJVPDerivativeOf) + _makeParseExpr("__jvp", parseJVPDifferentiate) }; ConstArrayView getSyntaxParseInfos() -- cgit v1.2.3