diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 368 |
1 files changed, 342 insertions, 26 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 431c8e5b2..2a42a7b6e 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -3,10 +3,303 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-clone.h" namespace Slang { +struct JVPTranscriber +{ + + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary<IRInst*, IRInst*> instMapD; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + void mapDifferentialInst(IRInst* instP, IRInst* instD) + { + instMapD.Add(instP, instD); + } + + IRInst* getDifferentialInst(IRInst* instP) + { + return instMapD[instP]; + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List<IRType*> parameterTypesD; + IRType* returnTypeD; + + // Add all primal parameters to the list. + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + parameterTypesD.add(funcType->getParamType(i)); + } + + // Add differential versions for the types we support. + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + if (auto typeD = differentiateType(builder, funcType->getParamType(i))) + parameterTypesD.add(typeD); + } + + // Transcribe return type. + // This will be void if the primal return type is non-differentiable. + // + returnTypeD = differentiateType(builder, funcType->getResultType()); + if (!returnTypeD) + returnTypeD = builder->getVoidType(); + + return builder->getFuncType(parameterTypesD, returnTypeD); + } + + IRType* differentiateType(IRBuilder* builder, IRType* typeP) + { + switch (typeP->getOp()) + { + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + return builder->getType(typeP->getOp()); + + default: + return nullptr; + } + } + + IRInst* differentiateParam(IRBuilder* builder, IRParam* paramP) + { + if (IRType* typeD = differentiateType(builder, paramP->getFullType())) + { + IRParam* paramD = builder->emitParam(typeD); + SLANG_ASSERT(paramD); + return paramD; + } + return nullptr; + } + + List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP) + { + // Clone (and emit) all the primal parameters. + List<IRParam*> newParamListP; + for (auto paramP : paramListP) + { + newParamListP.add(as<IRParam>(cloneInst(&cloneEnv, builder, paramP))); + } + + // Now emit differentials. + List<IRParam*> newParamListD; + for (auto paramP : newParamListP) + { + IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP)); + mapDifferentialInst(paramP, paramD); + newParamListD.add(paramD); + } + + return newParamListD; + } + + IRInst* differentiateVar(IRBuilder* builder, IRVar* varP) + { + if (IRType* typeD = differentiateType(builder, varP->getDataType()->getValueType())) + { + IRVar* varD = builder->emitVar(typeD); + SLANG_ASSERT(varD); + return varD; + } + return nullptr; + } + + IRInst* differentiateBinaryArith(IRBuilder* builder, IRInst* arith) + { + SLANG_ASSERT(arith->getOperandCount() == 2); + + auto leftP = arith->getOperand(0); + auto rightP = arith->getOperand(1); + + auto leftD = getDifferentialInst(leftP); + auto rightD = getDifferentialInst(rightP); + + auto leftZero = builder->getFloatValue(leftP->getDataType(), 0.0); + auto rightZero = builder->getFloatValue(rightP->getDataType(), 0.0); + + if (leftD || rightD) + { + leftD = leftD ? leftD : leftZero; + rightD = rightD ? rightD : rightZero; + + // Might have to do special-case handling for non-scalar types, + // like float3 or float3x3 + // + auto resultType = arith->getDataType(); + switch(arith->getOp()) + { + case kIROp_Add: + return builder->emitAdd(resultType, leftD, rightD); + case kIROp_Mul: + return builder->emitAdd(resultType, + builder->emitMul(resultType, leftD, rightP), + builder->emitMul(resultType, leftP, rightD)); + case kIROp_Sub: + return builder->emitSub(resultType, leftD, rightD); + case kIROp_Div: + return builder->emitDiv(resultType, + builder->emitSub( + resultType, + builder->emitMul(resultType, leftD, rightP), + builder->emitMul(resultType, leftP, rightD)), + builder->emitMul( + rightP->getDataType(), rightP, rightP + )); + default: + SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic"); + } + } + + return nullptr; + } + + IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP) + { + if (auto varP = as<IRVar>(loadP->getPtr())) + { + // If the loaded parameter has a differential version, + // emit a load instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. + // + if (auto varD = as<IRVar>(getDifferentialInst(varP))) + { + IRLoad* loadD = as<IRLoad>(builder->emitLoad(varD)); + SLANG_ASSERT(loadD); + return loadD; + } + return nullptr; + } + + SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction"); + } + + IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP) + { + IRInst* storeLocation = storeP->getPtr(); + IRInst* storeVal = storeP->getVal(); + if (auto destParam = as<IRVar>(storeLocation)) + { + // If the stored value has a differential version, + // emit a store instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. + // + IRInst* storeValD = getDifferentialInst(storeVal); + IRVar* storeLocationD = as<IRVar>(getDifferentialInst(destParam)); + if (storeValD && storeLocationD) + { + IRStore* storeD = as<IRStore>( + builder->emitStore(storeLocationD, storeValD)); + SLANG_ASSERT(storeD); + return storeD; + } + return nullptr; + } + + SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction"); + } + + IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) + { + IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal()); + if (auto returnValD = getDifferentialInst(returnVal)) + { + IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD)); + SLANG_ASSERT(returnD); + return returnD; + } + return nullptr; + } + + // Since int/float literals are sometimes nested inside an IRConstructor + // instruction, we check to make sure that the nested instr is a constant + // and then return nullptr. Literals do not need to be differentiated. + // + IRInst* differentiateConstruct(IRBuilder*, IRInst* consP) + { + + if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1) + { + return nullptr; + } + SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor"); + } + + // Logic for whether a primal instruction needs to be replicated + // in the differential function. For puerly functional blocks with + // no side-effects, it's safe to replicate everything except the + // return instruction. + // + bool requiresPrimalClone(IRBuilder*, IRInst* instP) + { + if (as<IRReturn>(instP)) + { + return false; + } + else + { + return true; + } + } + + IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) + { + IRInst* instP = oldInstP; + + // Clone the old instruction, but only if it's safe to do so. + // For instance, instructions that handle control flow + // (return statements) shouldn't be replicated. + // + if (requiresPrimalClone(builder, oldInstP)) + instP = cloneInst(&cloneEnv, builder, oldInstP); + SLANG_ASSERT(instP); + + IRInst* instD = differentiateInst(builder, instP); + + mapDifferentialInst(instP, instD); + + return instD; + } + + IRInst* differentiateInst(IRBuilder* builder, IRInst* instP) + { + switch (instP->getOp()) + { + case kIROp_Var: + return differentiateVar(builder, as<IRVar>(instP)); + + case kIROp_Load: + return differentiateLoad(builder, as<IRLoad>(instP)); + + case kIROp_Store: + return differentiateStore(builder, as<IRStore>(instP)); + + case kIROp_Return: + return differentiateReturn(builder, as<IRReturn>(instP)); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + return differentiateBinaryArith(builder, instP); + + case kIROp_Construct: + return differentiateConstruct(builder, instP); + + default: + SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction"); + } + } +}; + struct JVPDerivativeContext { // This type passes over the module and generates @@ -16,7 +309,12 @@ struct JVPDerivativeContext IRModule* module; // Shared builder state for our derivative passes. - SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder sharedBuilderStorage; + + // A transcriber object that handles the main job of + // processing instructions while maintaining state. + // + JVPTranscriber transcriberStorage; bool processModule() { @@ -31,6 +329,7 @@ struct JVPDerivativeContext // looking for callables. // Note: We're only processing global callables (IRGlobalValueWithCode) // for now. + // IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; for (auto inst : module->getGlobalInsts()) @@ -43,6 +342,8 @@ struct JVPDerivativeContext SLANG_ASSERT(as<IRFunc>(callable)); IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable)); builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); + + unmarkForJVP(callable); } } } @@ -62,12 +363,27 @@ struct JVPDerivativeContext { return true; } - // TODO: Need to remove this decoration or check for - // JVPDerivativeReferenceDecoration to avoid re-generating code. } return false; } + // Removes the JVPDerivativeMarkerDecoration from the provided callable, + // if it exists. + // + void unmarkForJVP(IRGlobalValueWithCode* callable) + { + for(auto decoration = callable->getFirstDecoration(); + decoration; + decoration = decoration->getNextDecoration()) + { + if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) + { + decoration->removeAndDeallocate(); + return; + } + } + } + List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType) { List<IRParam*> params; @@ -81,15 +397,21 @@ struct JVPDerivativeContext // Perform forward-mode automatic differentiation on // the intstructions. + // IRFunc* emitJVPFunction(IRBuilder* builder, IRFunc* primalFn) { - // Note (sai): Is this safe? Should we use setInsertInto? + builder->setInsertBefore(primalFn->getNextInst()); auto jvpFn = builder->createFunc(); - IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType()); + + SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType())); + IRType* jvpFuncType = transcriberStorage.differentiateFunctionType( + builder, + as<IRFuncType>(primalFn->getFullType())); jvpFn->setFullType(jvpFuncType); + if (auto jvpName = getJVPFuncName(builder, primalFn)) builder->addNameHintDecoration(jvpFn, jvpName); @@ -100,13 +422,7 @@ struct JVPDerivativeContext for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) { - IRBlock* newJVPBlock = nullptr; - if (block == primalFn->getFirstBlock()) - { - newJVPBlock = builder->emitBlock(); - emitFuncParameters(builder, as<IRFuncType>(jvpFuncType)); - } - newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock); + emitJVPBlock(builder, primalFn->getFirstBlock()); } return jvpFn; @@ -138,31 +454,31 @@ struct JVPDerivativeContext IRBlock* primalBlock, IRBlock* jvpBlock = nullptr) { - // Create if not already provided, and insert into new block. + JVPTranscriber* transcriber = &(transcriberStorage); + + // Create if not already created, and then insert into new block. if (!jvpBlock) jvpBlock = builder->emitBlock(); else builder->setInsertInto(jvpBlock); - // Temporarily, we're going to just emit a single return 0 instruction. - for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst()) + // First transcribe the parameter list. This is done separately because we + // want all the derivative parameters emitted after the primal parameters + // rather than interleaved with one another. + // + transcriber->transcribeParams(builder, primalBlock->getParams()); + + // Run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for(auto child = primalBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) { - if (auto returnOp = as<IRReturn>(child)) - { - auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0); - builder->emitReturn(zeroVal); - } + transcriber->transcribe(builder, child); } return jvpBlock; } - IRType* primalTypeToJVPType(IRType* primalType) - { - // Temporarily, we're going to implement the identity transform. - // The return type is the same as the primal type. - return primalType; - } }; // Set up context and call main process method. |
