diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 229 |
1 files changed, 115 insertions, 114 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 843428c01..b97556ab1 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -115,7 +115,7 @@ struct DifferentiableTypeConformanceContext IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) { - if (auto conformance = lookUpConformanceForType(builder, origType)) + if (auto conformance = lookUpConformanceForType(builder, origType)) { if (auto witnessTable = as<IRWitnessTable>(conformance)) { @@ -144,6 +144,14 @@ struct DifferentiableTypeConformanceContext // IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) { + switch (origType->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; + } return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey); } @@ -1083,8 +1091,7 @@ struct JVPTranscriber // in the current transcription context. // InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) - { - + { if (as<IRFunc>(origCall->getCallee())) { auto origCallee = origCall->getCallee(); @@ -1094,12 +1101,28 @@ struct JVPTranscriber // auto primalCallee = origCallee; - // TODO: If inner is not differentiable, treat as non-differentiable call. - // Build the differential callee - IRInst* diffCall = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), - primalCallee); - + IRInst* diffCallee = nullptr; + + if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + // If the user has already provided an differentiated implementation, use that. + diffCallee = derivativeReferenceDecor->getJVPFunc(); + } + else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + { + // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass + // to generate the implementation. + diffCallee = builder->emitJVPDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); + } + else + { + // The callee is non differentiable, just return primal value with null diff value. + IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + return InstPair(primalCall, nullptr); + } + List<IRInst*> args; // Go over the parameter list and create pairs for each input (if required) for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) @@ -1109,18 +1132,16 @@ struct JVPTranscriber SLANG_ASSERT(primalArg); auto primalType = primalArg->getDataType(); + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + if (auto pairType = tryGetDiffPairType(builder, primalType)) { - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - args.add(diffPair); } else @@ -1130,17 +1151,19 @@ struct JVPTranscriber } } - auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + IRType* diffReturnType = nullptr; + diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); SLANG_ASSERT(diffReturnType); auto callInst = builder->emitCallInst( diffReturnType, - diffCall, + diffCallee, args); + + IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst); + IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst); - return InstPair( - pairBuilder->emitPrimalFieldAccess(builder, callInst), - pairBuilder->emitDiffFieldAccess(builder, callInst)); + return InstPair(primalResultValue, diffResultValue); } else if(as<IRSpecialize>(origCall->getCallee()) || as<IRLookupWitnessMethod>(origCall->getCallee())) @@ -1396,89 +1419,45 @@ struct JVPTranscriber return InstPair(diffBlock, diffBlock); } - InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract) + InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) { - IRInst* origBase = origExtract->getBase(); + SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst)); + + IRInst* origBase = originalInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto diffBase = findOrTranscribeDiffInst(builder, origBase); + auto field = originalInst->getOperand(1); + auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>(); + auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); - auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType()); - - IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField()); - IRInst* diffExtract = nullptr; + IRInst* primalOperands[] = { primalBase, field }; + IRInst* primalFieldExtract = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 2, + primalOperands); - if (auto diffExtractType = differentiateType(builder, primalExtractType)) + if (!derivativeRefDecor) { - // Check if we have a getter. - if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>()) - { - - IRInst* getterFunc = getterDecoration->getGetterFunc(); - - // Must be a method with a single parameter. - SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1); - - // Our getter func accepts a _pointer_ to the target type - // So we have to create a variable and store our type into memory - // here. This will eventually get optimized out in later passes. - // - auto diffTempVar = builder->emitVar( - diffBase->getDataType()); - - builder->emitStore(diffTempVar, diffBase); - - List<IRInst*> args; - args.add(diffTempVar); - - // Emit a call to the getter. The getter will return a reference type. - // We need to load from this to go to a non-ptr 'solid' type. - // - auto diffGetterCall = builder->emitCallInst( - as<IRFuncType>(getterFunc->getDataType())->getResultType(), - getterFunc, - args); - - diffExtract = builder->emitLoad(diffGetterCall); - } + return InstPair(primalFieldExtract, nullptr); } - return InstPair(primalExtract, diffExtract); - } - - InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress) - { - IRInst* origBase = origAddress->getBase(); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto diffBase = findOrTranscribeDiffInst(builder, origBase); - - auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType()); + IRInst* diffFieldExtract = nullptr; - IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField()); - IRInst* diffAddress = nullptr; - - if (auto diffAddressType = differentiateType(builder, primalAddressType)) + if (auto diffType = differentiateType(builder, primalType)) { - // If we have a getter associated with this field, we want to use that. - if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>()) + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { - auto getterFunc = getterDecoration->getGetterFunc(); - - // Add the base differential inst as the argument. - List<IRInst*> args; - args.add(diffBase); - - diffAddress = builder->emitCallInst( - as<IRFuncType>(getterFunc->getDataType())->getResultType(), - getterFunc, - args); + IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() }; + diffFieldExtract = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 2, + diffOperands); } - } - - return InstPair(primalAddress, diffAddress); + return InstPair(primalFieldExtract, diffFieldExtract); } - InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) { SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); @@ -1514,7 +1493,6 @@ struct JVPTranscriber return InstPair(primalGetElementPtr, diffGetElementPtr); } - InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) { // The loop comes with three blocks.. we just need to transcribe each one @@ -1640,9 +1618,13 @@ struct JVPTranscriber as<IRFuncType>(origFunc->getFullType())); diffFunc->setFullType(diffFuncType); - // TODO(sai): Replace naming scheme - // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) - // builder->addNameHintDecoration(diffFunc, jvpName); + if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_jvp_" << originalName; + builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } // Transcribe children from origFunc into diffFunc builder->setInsertInto(diffFunc); @@ -1719,9 +1701,18 @@ struct JVPTranscriber { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + if (pair.differential) + { + // Generate name hint for the inst. + if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sb; + sb << "s_diff_" << primalNameHint->getName(); + builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); + } + } return pair.differential; } - instsInProgress.Remove(origInst); getSink()->diagnose(origInst->sourceLoc, @@ -1789,16 +1780,14 @@ struct JVPTranscriber getSink()->diagnose(origInst->sourceLoc, Diagnostics::unexpected, "should not be attempting to differentiate anything specialized here."); + return InstPair(nullptr, nullptr); case kIROp_lookup_interface_method: return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); case kIROp_FieldExtract: - return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst)); - case kIROp_FieldAddress: - return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst)); - + return transcribeFieldExtract(builder, origInst); case kIROp_getElement: case kIROp_getElementPtr: return transcribeGetElement(builder, origInst); @@ -1942,11 +1931,6 @@ struct JVPDerivativeContext // Temporary fix: Move generated types, if any, to before their use locations. (&pairBuilderStorage)->relocateNewTypes(builder); - // Remove all kIROp_DifferentiableTypeDictionary instructions and - // kIROp_DifferentialGetterDecoration decorations - // - modified |= stripDiffTypeInformation(builder, module->getModuleInst()); - return modified; } @@ -1954,7 +1938,6 @@ struct JVPDerivativeContext { if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>()) return jvpDefinition->getJVPFunc(); - return nullptr; } @@ -2166,7 +2149,7 @@ struct JVPDerivativeContext return modified; } - bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent) + bool stripDiffTypeInformation(IRInst* parent) { bool modified = false; @@ -2175,22 +2158,18 @@ struct JVPDerivativeContext { auto nextChild = child->getNextInst(); - if (child->getOp() == kIROp_DifferentiableTypeDictionary) + switch (child->getOp()) { + case kIROp_DifferentiableTypeDictionary: child->removeAndDeallocate(); child = nextChild; modified = true; continue; } - if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>()) - { - getterDecoration->removeAndDeallocate(); - } - if (child->getFirstChild() != nullptr) { - modified |= stripDiffTypeInformation(builder, child); + modified |= stripDiffTypeInformation(child); } child = nextChild; @@ -2311,8 +2290,30 @@ bool processJVPDerivativeMarkers( eliminateDeadCode(module, options); JVPDerivativeContext context(module, sink); + bool changed = context.processModule(); + changed |= context.stripDiffTypeInformation(module->getModuleInst()); + return changed; +} - return context.processModule(); +void stripAutoDiffDecorations(IRModule* module) +{ + for (auto inst : module->getGlobalInsts()) + { + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_JVPDerivativeReferenceDecoration: + case kIROp_JVPDerivativeMemberReferenceDecoration: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } + } } } |
