diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-07-13 15:55:30 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-07-13 15:55:30 -0400 |
| commit | 4af61e2296a49876c2d9e7cf192ae825302a83de (patch) | |
| tree | d067b944b9794fe5061bbf51e8ef6a39d5fcefbf /source/slang/slang-ir-diff-jvp.cpp | |
| parent | 564f0d84a9c5276c05e8288955a7685f96278d1b (diff) | |
Added support for differentiating out and inout parameters. (#2323)
* Added out/inout tests
* Added support for out and inout parameters. Still untested
* Fixed and tested support for out and inout types
* Removed some comments
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 106 |
1 files changed, 89 insertions, 17 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 5b77d483d..f5afccd0c 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -56,7 +56,9 @@ struct JVPTranscriber // Add all primal parameters to the list. for (UIndex i = 0; i < funcType->getParamCount(); i++) { - parameterTypesD.add(funcType->getParamType(i)); + // TODO(sai): Move this check to a separate function. + if (!as<IROutType>(funcType->getParamType(i))) + parameterTypesD.add(funcType->getParamType(i)); } // Add differential versions for the types we support. @@ -85,7 +87,12 @@ struct JVPTranscriber case kIROp_DoubleType: return builder->getType(typeP->getOp()); case kIROp_VectorType: + // TODO(sai): Call differentiateType() on typeP. return as<IRVectorType>(typeP); + case kIROp_OutType: + return builder->getOutType(differentiateType(builder, as<IROutType>(typeP)->getValueType())); + case kIROp_InOutType: + return builder->getInOutType(differentiateType(builder, as<IRInOutType>(typeP)->getValueType())); default: return nullptr; } @@ -102,21 +109,45 @@ struct JVPTranscriber return nullptr; } + IRInst* emitInputParam(IRBuilder* builder, IRParam* paramP) + { + // Convert primal 'inout' types into pure input types, because a + // JVP transformed function must never have primal side-effects. + // + if (auto inoutTypeP = as<IRInOutType>(paramP->getDataType())) + { + auto newParamP = builder->emitParam(inoutTypeP->getValueType()); + cloneEnv.mapOldValToNew.Add(paramP, newParamP); + + return newParamP; + } + else if (as<IROutType>(paramP->getDataType())) + { + getSink()->diagnose(paramP->sourceLoc, + Diagnostics::unexpected, + "encountered unexpected output parameter"); + return nullptr; + } + else + return as<IRParam>(cloneInst(&cloneEnv, builder, paramP)); + } + List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP) { // Clone (and emit) all the primal parameters. List<IRParam*> newParamListP; for (auto paramP : paramListP) { - newParamListP.add(as<IRParam>(cloneInst(&cloneEnv, builder, paramP))); + if(requiresPrimalClone(builder, paramP)) + newParamListP.add(as<IRParam>(emitInputParam(builder, paramP))); } // Now emit differentials. List<IRParam*> newParamListD; - for (auto paramP : newParamListP) + for (auto paramP : paramListP) { IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP)); - mapDifferentialInst(paramP, paramD); + mapDifferentialInst(findCloneForOperand(&cloneEnv, paramP), paramD); newParamListD.add(paramD); } @@ -187,15 +218,16 @@ struct JVPTranscriber IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP) { - if (auto varP = as<IRVar>(loadP->getPtr())) + auto ptrP = loadP->getPtr(); + if (as<IRVar>(ptrP) || as<IRParam>(ptrP)) { // If the loaded parameter has a differential version, // emit a load instruction for the differential parameter. // Otherwise, emit nothing since there's nothing to load. // - if (auto varD = as<IRVar>(getDifferentialInst(varP))) + if (auto ptrD = getDifferentialInst(ptrP, nullptr)) { - IRLoad* loadD = as<IRLoad>(builder->emitLoad(varD)); + IRLoad* loadD = as<IRLoad>(builder->emitLoad(ptrD)); SLANG_ASSERT(loadD); return loadD; } @@ -212,14 +244,14 @@ struct JVPTranscriber { IRInst* storeLocation = storeP->getPtr(); IRInst* storeVal = storeP->getVal(); - if (auto destParam = as<IRVar>(storeLocation)) + if (as<IRVar>(storeLocation) || as<IRParam>(storeLocation)) { // If the stored value has a differential version, // emit a store instruction for the differential parameter. // Otherwise, emit nothing since there's nothing to load. // IRInst* storeValD = getDifferentialInst(storeVal); - IRVar* storeLocationD = as<IRVar>(getDifferentialInst(destParam)); + IRInst* storeLocationD = getDifferentialInst(storeLocation); if (storeValD && storeLocationD) { IRStore* storeD = as<IRStore>( @@ -239,13 +271,18 @@ struct JVPTranscriber IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) { IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal()); - if (auto returnValD = getDifferentialInst(returnVal)) + if (auto returnValD = getDifferentialInst(returnVal, nullptr)) { IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD)); SLANG_ASSERT(returnD); return returnD; } - return nullptr; + else + { + // If the differential return value is not available, emit a + // void return. + return builder->emitReturn(); + } } // Since int/float literals are sometimes nested inside an IRConstructor @@ -352,16 +389,38 @@ struct JVPTranscriber } // Logic for whether a primal instruction needs to be replicated - // in the differential function. For puerly functional blocks with - // no side-effects, it's safe to replicate everything except the - // return instruction. - // + // in the differential function. We detect and avoid replicating + // side-effect instructions. + // bool requiresPrimalClone(IRBuilder*, IRInst* instP) { if (as<IRReturn>(instP)) return false; - else - return true; + else if (auto paramP = as<IRParam>(instP)) + { + // Out-type parameters are discarded from the parameter list, + // since pure JVP functions to not write to primal outputs. + // + if (as<IROutType>(paramP->getDataType())) + return false; + } + else if (auto storeP = as<IRStore>(instP)) + { + IRInst* storeLocation = storeP->getPtr(); + + // Writing to a parameter is a side-effect that should be avoided. + if(as<IRParam>(storeLocation)) + return false; + + // If attempting to store to a location without a clone, + // then this instruction likely has side-effects external to the + // current function. + // + if(!lookUp(&cloneEnv, storeLocation)) + return false; + } + + return true; } IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) @@ -374,6 +433,19 @@ struct JVPTranscriber // if (requiresPrimalClone(builder, oldInstP)) instP = cloneInst(&cloneEnv, builder, oldInstP); + else + { + // We replace the operands of the old instruction with their clones, + // if available. + // + for(UInt ii = 0; ii < oldInstP->getOperandCount(); ++ii) + { + auto oldOperand = oldInstP->getOperand(ii); + auto newOperand = findCloneForOperand(&cloneEnv, oldOperand); + + instP->getOperands()[ii].init(instP, newOperand); + } + } SLANG_ASSERT(instP); IRInst* instD = differentiateInst(builder, instP); |
