summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-25 15:45:34 -0400
committerGitHub <noreply@github.com>2022-06-25 12:45:34 -0700
commit8da47c460df01fad6f1d0614210a770f4781edb1 (patch)
tree170a5cc100c69e387e8c6d34217588ea00daed53 /source
parent0229784b93a43e17a088881e6be32b44fc6ce713 (diff)
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h6
-rw-r--r--source/slang/slang-check-expr.cpp40
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-ir-diff-call.cpp8
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp368
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h10
-rw-r--r--source/slang/slang-ir.cpp30
-rw-r--r--source/slang/slang-lower-to-ir.cpp6
-rw-r--r--source/slang/slang-parser.cpp12
10 files changed, 428 insertions, 56 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 8f407321e..7226f365e 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -372,10 +372,10 @@ class ExtractExistentialValueExpr: public Expr
/// An expression of the form `__jvp(fn)` to access the
/// forward-mode derivative version of the function `fn`
///
-class JVPDerivativeOfExpr: public Expr
+class JVPDifferentiateExpr: public Expr
{
- SLANG_AST_CLASS(JVPDerivativeOfExpr)
- Expr* baseFn;
+ SLANG_AST_CLASS(JVPDifferentiateExpr)
+ Expr* baseFunction;
};
/// A type expression of the form `__TaggedUnion(A, ...)`.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index ff469428b..576220c02 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,16 +1509,46 @@ namespace Slang
return expr;
}
- Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
{
// Check/Resolve inner function declaration.
- expr->baseFn = CheckTerm(expr->baseFn);
+ expr->baseFunction = CheckTerm(expr->baseFunction);
- if(auto funcType = as<FuncType>(expr->baseFn->type))
+ if(auto primalType = as<FuncType>(expr->baseFunction->type))
{
// Resolve JVP type here.
- // Temporarily resolving to the same type as the original function.
- expr->type = expr->baseFn->type;
+ // Note that this type checking needs to be in sync with
+ // the auto-generation logic in slang-ir-jvp-diff.cpp
+
+ auto astBuilder = this->getASTBuilder();
+ FuncType* jvpType = astBuilder->create<FuncType>();
+
+ // Only float types can be differentiated for now.
+
+ // The JVP return type is float if primal return type is float
+ // void otherwise.
+ //
+ if (primalType->resultType->equals(astBuilder->getFloatType()))
+ jvpType->resultType = astBuilder->getFloatType();
+ else
+ jvpType->resultType = astBuilder->getVoidType();
+
+ // No support for differentiating function that throw errors, for now.
+ SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
+ jvpType->errorType = primalType->errorType;
+
+ for (UInt i = 0; i < primalType->getParamCount(); i++)
+ {
+ jvpType->paramTypes.add(primalType->getParamType(i));
+ }
+
+ for (UInt i = 0; i < primalType->getParamCount(); i++)
+ {
+ if(primalType->getParamType(i)->equals(astBuilder->getFloatType()))
+ jvpType->paramTypes.add(astBuilder->getFloatType());
+ }
+
+ expr->type = jvpType;
}
else
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 3be5ba68b..ccf9ccad3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1732,7 +1732,7 @@ namespace Slang
Expr* visitAndTypeExpr(AndTypeExpr* expr);
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
- Expr* visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr);
+ Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr);
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 76ffe3c8b..92044be3c 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -34,10 +34,10 @@ struct DerivativeCallProcessContext
do
{
auto nextChild = child->getNextInst();
- // Look for IRJVPDerivativeOf
- if (auto derivOf = as<IRJVPDerivativeOf>(child))
+ // Look for IRJVPDifferentiate
+ if (auto derivOf = as<IRJVPDifferentiate>(child))
{
- processDerivativeOf(derivOf);
+ processDifferentiate(derivOf);
}
child = nextChild;
}
@@ -50,7 +50,7 @@ struct DerivativeCallProcessContext
// Perform forward-mode automatic differentiation on
// the intstructions.
- void processDerivativeOf(IRJVPDerivativeOf* derivOfInst)
+ void processDifferentiate(IRJVPDifferentiate* derivOfInst)
{
IRFunc* jvpFunc = nullptr;
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 431c8e5b2..2a42a7b6e 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -3,10 +3,303 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir-clone.h"
namespace Slang
{
+struct JVPTranscriber
+{
+
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary<IRInst*, IRInst*> instMapD;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ void mapDifferentialInst(IRInst* instP, IRInst* instD)
+ {
+ instMapD.Add(instP, instD);
+ }
+
+ IRInst* getDifferentialInst(IRInst* instP)
+ {
+ return instMapD[instP];
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List<IRType*> parameterTypesD;
+ IRType* returnTypeD;
+
+ // Add all primal parameters to the list.
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ parameterTypesD.add(funcType->getParamType(i));
+ }
+
+ // Add differential versions for the types we support.
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ if (auto typeD = differentiateType(builder, funcType->getParamType(i)))
+ parameterTypesD.add(typeD);
+ }
+
+ // Transcribe return type.
+ // This will be void if the primal return type is non-differentiable.
+ //
+ returnTypeD = differentiateType(builder, funcType->getResultType());
+ if (!returnTypeD)
+ returnTypeD = builder->getVoidType();
+
+ return builder->getFuncType(parameterTypesD, returnTypeD);
+ }
+
+ IRType* differentiateType(IRBuilder* builder, IRType* typeP)
+ {
+ switch (typeP->getOp())
+ {
+ case kIROp_HalfType:
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ return builder->getType(typeP->getOp());
+
+ default:
+ return nullptr;
+ }
+ }
+
+ IRInst* differentiateParam(IRBuilder* builder, IRParam* paramP)
+ {
+ if (IRType* typeD = differentiateType(builder, paramP->getFullType()))
+ {
+ IRParam* paramD = builder->emitParam(typeD);
+ SLANG_ASSERT(paramD);
+ return paramD;
+ }
+ return nullptr;
+ }
+
+ List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP)
+ {
+ // Clone (and emit) all the primal parameters.
+ List<IRParam*> newParamListP;
+ for (auto paramP : paramListP)
+ {
+ newParamListP.add(as<IRParam>(cloneInst(&cloneEnv, builder, paramP)));
+ }
+
+ // Now emit differentials.
+ List<IRParam*> newParamListD;
+ for (auto paramP : newParamListP)
+ {
+ IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP));
+ mapDifferentialInst(paramP, paramD);
+ newParamListD.add(paramD);
+ }
+
+ return newParamListD;
+ }
+
+ IRInst* differentiateVar(IRBuilder* builder, IRVar* varP)
+ {
+ if (IRType* typeD = differentiateType(builder, varP->getDataType()->getValueType()))
+ {
+ IRVar* varD = builder->emitVar(typeD);
+ SLANG_ASSERT(varD);
+ return varD;
+ }
+ return nullptr;
+ }
+
+ IRInst* differentiateBinaryArith(IRBuilder* builder, IRInst* arith)
+ {
+ SLANG_ASSERT(arith->getOperandCount() == 2);
+
+ auto leftP = arith->getOperand(0);
+ auto rightP = arith->getOperand(1);
+
+ auto leftD = getDifferentialInst(leftP);
+ auto rightD = getDifferentialInst(rightP);
+
+ auto leftZero = builder->getFloatValue(leftP->getDataType(), 0.0);
+ auto rightZero = builder->getFloatValue(rightP->getDataType(), 0.0);
+
+ if (leftD || rightD)
+ {
+ leftD = leftD ? leftD : leftZero;
+ rightD = rightD ? rightD : rightZero;
+
+ // Might have to do special-case handling for non-scalar types,
+ // like float3 or float3x3
+ //
+ auto resultType = arith->getDataType();
+ switch(arith->getOp())
+ {
+ case kIROp_Add:
+ return builder->emitAdd(resultType, leftD, rightD);
+ case kIROp_Mul:
+ return builder->emitAdd(resultType,
+ builder->emitMul(resultType, leftD, rightP),
+ builder->emitMul(resultType, leftP, rightD));
+ case kIROp_Sub:
+ return builder->emitSub(resultType, leftD, rightD);
+ case kIROp_Div:
+ return builder->emitDiv(resultType,
+ builder->emitSub(
+ resultType,
+ builder->emitMul(resultType, leftD, rightP),
+ builder->emitMul(resultType, leftP, rightD)),
+ builder->emitMul(
+ rightP->getDataType(), rightP, rightP
+ ));
+ default:
+ SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic");
+ }
+ }
+
+ return nullptr;
+ }
+
+ IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP)
+ {
+ if (auto varP = as<IRVar>(loadP->getPtr()))
+ {
+ // If the loaded parameter has a differential version,
+ // emit a load instruction for the differential parameter.
+ // Otherwise, emit nothing since there's nothing to load.
+ //
+ if (auto varD = as<IRVar>(getDifferentialInst(varP)))
+ {
+ IRLoad* loadD = as<IRLoad>(builder->emitLoad(varD));
+ SLANG_ASSERT(loadD);
+ return loadD;
+ }
+ return nullptr;
+ }
+
+ SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction");
+ }
+
+ IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP)
+ {
+ IRInst* storeLocation = storeP->getPtr();
+ IRInst* storeVal = storeP->getVal();
+ if (auto destParam = as<IRVar>(storeLocation))
+ {
+ // If the stored value has a differential version,
+ // emit a store instruction for the differential parameter.
+ // Otherwise, emit nothing since there's nothing to load.
+ //
+ IRInst* storeValD = getDifferentialInst(storeVal);
+ IRVar* storeLocationD = as<IRVar>(getDifferentialInst(destParam));
+ if (storeValD && storeLocationD)
+ {
+ IRStore* storeD = as<IRStore>(
+ builder->emitStore(storeLocationD, storeValD));
+ SLANG_ASSERT(storeD);
+ return storeD;
+ }
+ return nullptr;
+ }
+
+ SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction");
+ }
+
+ IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP)
+ {
+ IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal());
+ if (auto returnValD = getDifferentialInst(returnVal))
+ {
+ IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD));
+ SLANG_ASSERT(returnD);
+ return returnD;
+ }
+ return nullptr;
+ }
+
+ // Since int/float literals are sometimes nested inside an IRConstructor
+ // instruction, we check to make sure that the nested instr is a constant
+ // and then return nullptr. Literals do not need to be differentiated.
+ //
+ IRInst* differentiateConstruct(IRBuilder*, IRInst* consP)
+ {
+
+ if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1)
+ {
+ return nullptr;
+ }
+ SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor");
+ }
+
+ // Logic for whether a primal instruction needs to be replicated
+ // in the differential function. For puerly functional blocks with
+ // no side-effects, it's safe to replicate everything except the
+ // return instruction.
+ //
+ bool requiresPrimalClone(IRBuilder*, IRInst* instP)
+ {
+ if (as<IRReturn>(instP))
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP)
+ {
+ IRInst* instP = oldInstP;
+
+ // Clone the old instruction, but only if it's safe to do so.
+ // For instance, instructions that handle control flow
+ // (return statements) shouldn't be replicated.
+ //
+ if (requiresPrimalClone(builder, oldInstP))
+ instP = cloneInst(&cloneEnv, builder, oldInstP);
+ SLANG_ASSERT(instP);
+
+ IRInst* instD = differentiateInst(builder, instP);
+
+ mapDifferentialInst(instP, instD);
+
+ return instD;
+ }
+
+ IRInst* differentiateInst(IRBuilder* builder, IRInst* instP)
+ {
+ switch (instP->getOp())
+ {
+ case kIROp_Var:
+ return differentiateVar(builder, as<IRVar>(instP));
+
+ case kIROp_Load:
+ return differentiateLoad(builder, as<IRLoad>(instP));
+
+ case kIROp_Store:
+ return differentiateStore(builder, as<IRStore>(instP));
+
+ case kIROp_Return:
+ return differentiateReturn(builder, as<IRReturn>(instP));
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ return differentiateBinaryArith(builder, instP);
+
+ case kIROp_Construct:
+ return differentiateConstruct(builder, instP);
+
+ default:
+ SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction");
+ }
+ }
+};
+
struct JVPDerivativeContext
{
// This type passes over the module and generates
@@ -16,7 +309,12 @@ struct JVPDerivativeContext
IRModule* module;
// Shared builder state for our derivative passes.
- SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder sharedBuilderStorage;
+
+ // A transcriber object that handles the main job of
+ // processing instructions while maintaining state.
+ //
+ JVPTranscriber transcriberStorage;
bool processModule()
{
@@ -31,6 +329,7 @@ struct JVPDerivativeContext
// looking for callables.
// Note: We're only processing global callables (IRGlobalValueWithCode)
// for now.
+ //
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
for (auto inst : module->getGlobalInsts())
@@ -43,6 +342,8 @@ struct JVPDerivativeContext
SLANG_ASSERT(as<IRFunc>(callable));
IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable));
builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
+
+ unmarkForJVP(callable);
}
}
}
@@ -62,12 +363,27 @@ struct JVPDerivativeContext
{
return true;
}
- // TODO: Need to remove this decoration or check for
- // JVPDerivativeReferenceDecoration to avoid re-generating code.
}
return false;
}
+ // Removes the JVPDerivativeMarkerDecoration from the provided callable,
+ // if it exists.
+ //
+ void unmarkForJVP(IRGlobalValueWithCode* callable)
+ {
+ for(auto decoration = callable->getFirstDecoration();
+ decoration;
+ decoration = decoration->getNextDecoration())
+ {
+ if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
+ {
+ decoration->removeAndDeallocate();
+ return;
+ }
+ }
+ }
+
List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
{
List<IRParam*> params;
@@ -81,15 +397,21 @@ struct JVPDerivativeContext
// Perform forward-mode automatic differentiation on
// the intstructions.
+ //
IRFunc* emitJVPFunction(IRBuilder* builder,
IRFunc* primalFn)
{
- // Note (sai): Is this safe? Should we use setInsertInto?
+
builder->setInsertBefore(primalFn->getNextInst());
auto jvpFn = builder->createFunc();
- IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType());
+
+ SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType()));
+ IRType* jvpFuncType = transcriberStorage.differentiateFunctionType(
+ builder,
+ as<IRFuncType>(primalFn->getFullType()));
jvpFn->setFullType(jvpFuncType);
+
if (auto jvpName = getJVPFuncName(builder, primalFn))
builder->addNameHintDecoration(jvpFn, jvpName);
@@ -100,13 +422,7 @@ struct JVPDerivativeContext
for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
{
- IRBlock* newJVPBlock = nullptr;
- if (block == primalFn->getFirstBlock())
- {
- newJVPBlock = builder->emitBlock();
- emitFuncParameters(builder, as<IRFuncType>(jvpFuncType));
- }
- newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock);
+ emitJVPBlock(builder, primalFn->getFirstBlock());
}
return jvpFn;
@@ -138,31 +454,31 @@ struct JVPDerivativeContext
IRBlock* primalBlock,
IRBlock* jvpBlock = nullptr)
{
- // Create if not already provided, and insert into new block.
+ JVPTranscriber* transcriber = &(transcriberStorage);
+
+ // Create if not already created, and then insert into new block.
if (!jvpBlock)
jvpBlock = builder->emitBlock();
else
builder->setInsertInto(jvpBlock);
- // Temporarily, we're going to just emit a single return 0 instruction.
- for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst())
+ // First transcribe the parameter list. This is done separately because we
+ // want all the derivative parameters emitted after the primal parameters
+ // rather than interleaved with one another.
+ //
+ transcriber->transcribeParams(builder, primalBlock->getParams());
+
+ // Run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for(auto child = primalBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
{
- if (auto returnOp = as<IRReturn>(child))
- {
- auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0);
- builder->emitReturn(zeroVal);
- }
+ transcriber->transcribe(builder, child);
}
return jvpBlock;
}
- IRType* primalTypeToJVPType(IRType* primalType)
- {
- // Temporarily, we're going to implement the identity transform.
- // The return type is the same as the primal type.
- return primalType;
- }
};
// Set up context and call main process method.
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 793f1f78f..6eae710c0 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -717,7 +717,7 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
-INST(JVPDerivativeOf, jvpDerivativeOf, 1, 0)
+INST(JVPDifferentiate, jvpDifferentiate, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 82d0d5a0e..521570b22 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -529,17 +529,17 @@ struct IRJVPDerivativeReferenceDecoration : IRDecoration
// An instruction that replaces the function symbol
// with it's derivative function.
-struct IRJVPDerivativeOf : IRInst
+struct IRJVPDifferentiate : IRInst
{
enum
{
- kOp = kIROp_JVPDerivativeOf
+ kOp = kIROp_JVPDifferentiate
};
// The base function for the call.
IRUse base;
IRInst* getBaseFn() { return getOperand(0); }
- IR_LEAF_ISA(JVPDerivativeOf)
+ IR_LEAF_ISA(JVPDifferentiate)
};
// An instruction that specializes another IR value
@@ -2346,7 +2346,7 @@ public:
IRInst* emitExtractExistentialWitnessTable(
IRInst* existentialValue);
- IRInst* emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn);
+ IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn);
IRInst* emitSpecializeInst(
IRType* type,
@@ -2820,7 +2820,9 @@ public:
IRInst* emitBitNot(IRType* type, IRInst* value);
IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right);
+ IRInst* emitSub(IRType* type, IRInst* left, IRInst* right);
IRInst* emitMul(IRType* type, IRInst* left, IRInst* right);
+ IRInst* emitDiv(IRType* type, IRInst* numerator, IRInst* denominator);
IRInst* emitEql(IRInst* left, IRInst* right);
IRInst* emitNeq(IRInst* left, IRInst* right);
IRInst* emitLess(IRInst* left, IRInst* right);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 950061d4f..dd71ae782 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3027,11 +3027,11 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn)
+ IRInst* IRBuilder::emitJVPDifferentiateInst(IRType* type, IRInst* baseFn)
{
- auto inst = createInst<IRJVPDerivativeOf>(
+ auto inst = createInst<IRJVPDifferentiate>(
this,
- kIROp_JVPDerivativeOf,
+ kIROp_JVPDifferentiate,
type,
baseFn);
addInst(inst);
@@ -4370,6 +4370,18 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitSub(IRType* type, IRInst* left, IRInst* right)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_Sub,
+ type,
+ left,
+ right);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitEql(IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(this, kIROp_Eql, getBoolType(), left, right);
@@ -4403,6 +4415,18 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitDiv(IRType* type, IRInst* numerator, IRInst* denominator)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_Div,
+ type,
+ numerator,
+ denominator);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitShr(IRType* type, IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(this, kIROp_Rsh, type, left, right);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d845342f0..b7c5155b5 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2943,13 +2943,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
// pass.
- LoweredValInfo visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ LoweredValInfo visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
{
- auto baseVal = lowerSubExpr(expr->baseFn);
+ auto baseVal = lowerSubExpr(expr->baseFunction);
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
return LoweredValInfo::simple(
- getBuilder()->emitJVPDerivativeOfInst(
+ getBuilder()->emitJVPDifferentiateInst(
lowerType(context, expr->type),
baseVal.val));
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index ee34eac6f..d168bf55c 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2058,22 +2058,22 @@ namespace Slang
}
/// Parse an expression of the form __jvp(fn) where fn is an
/// identifier pointing to a function.
- static Expr* parseJVPDerivativeOf(Parser* parser)
+ static Expr* parseJVPDifferentiate(Parser* parser)
{
- JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create<JVPDerivativeOfExpr>();
+ JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create<JVPDifferentiateExpr>();
parser->ReadToken(TokenType::LParent);
- jvpExpr->baseFn = parser->ParseExpression();
+ jvpExpr->baseFunction = parser->ParseExpression();
parser->ReadToken(TokenType::RParent);
return jvpExpr;
}
- static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */)
+ static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */)
{
- return parseJVPDerivativeOf(parser);
+ return parseJVPDifferentiate(parser);
}
/// Parse a `This` type expression
@@ -6492,7 +6492,7 @@ namespace Slang
_makeParseExpr("nullptr", parseNullPtrExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
- _makeParseExpr("__jvp", parseJVPDerivativeOf)
+ _makeParseExpr("__jvp", parseJVPDifferentiate)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()