diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 1478 |
1 files changed, 1167 insertions, 311 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 5eee13d5e..843428c01 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -7,6 +7,10 @@ #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" +// origX, primalX, diffX +// origX -> primalX (cloneEnv) +// origX -> diffX (instMapD) + namespace Slang { @@ -24,7 +28,7 @@ typedef Pair<IRInst*, IRInst*> InstPair; struct DifferentiableTypeConformanceContext { - Dictionary<IRInst*, IRInst*> witnessTableMap; + Dictionary<IRInst*, IRInst*> witnessTableMap; IRInst* inst = nullptr; @@ -39,6 +43,18 @@ struct DifferentiableTypeConformanceContext // type in the conformance table associated with the concrete type. // IRStructKey* differentialAssocTypeStructKey = nullptr; + + // The struct key for the 'zero()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of zero() for a given type. + // + IRStructKey* zeroMethodStructKey = nullptr; + + // The struct key for the 'add()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of add() for a given type. + // + IRStructKey* addMethodStructKey = nullptr; // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. @@ -56,6 +72,9 @@ struct DifferentiableTypeConformanceContext { differentiableInterfaceType = parent->differentiableInterfaceType; differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey; + zeroMethodStructKey = parent->zeroMethodStructKey; + addMethodStructKey = parent->addMethodStructKey; + isInterfaceAvailable = parent->isInterfaceAvailable; } else @@ -64,17 +83,13 @@ struct DifferentiableTypeConformanceContext if (differentiableInterfaceType) { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); if (differentialAssocTypeStructKey) isInterfaceAvailable = true; } } - - if (isInterfaceAvailable) - { - // Load all witness tables corresponding to the IDifferentiable interface. - loadWitnessTablesForInterface(differentiableInterfaceType); - } } DifferentiableTypeConformanceContext(IRInst* inst) : @@ -84,35 +99,30 @@ struct DifferentiableTypeConformanceContext // Lookup a witness table for the concreteType. One should exist if concreteType // inherits (successfully) from IDifferentiable. // - IRInst* lookUpConformanceForType(IRInst* type) + IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type) { SLANG_ASSERT(isInterfaceAvailable); + // TODO: Cache the returned value to avoid repeatedly scanning through + // blocks looking for the type entries. + // + if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent())) + { + return irWitness; + } - if (witnessTableMap.ContainsKey(type)) - return witnessTableMap[type]; - else if (parent) - return parent->lookUpConformanceForType(type); - else - return nullptr; + return nullptr; } - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - SLANG_ASSERT(isInterfaceAvailable); - if (auto conformance = lookUpConformanceForType(origType)) + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) + { + if (auto conformance = lookUpConformanceForType(builder, origType)) { if (auto witnessTable = as<IRWitnessTable>(conformance)) { for (auto entry : witnessTable->getEntries()) { - if (entry->getRequirementKey() == differentialAssocTypeStructKey) - return as<IRType>(entry->getSatisfyingVal()); + if (entry->getRequirementKey() == key) + return entry->getSatisfyingVal(); } } else if (auto witnessTableParam = as<IRParam>(conformance)) @@ -120,12 +130,32 @@ struct DifferentiableTypeConformanceContext return builder->emitLookupInterfaceMethodInst( builder->getTypeKind(), witnessTableParam, - differentialAssocTypeStructKey); + key); } } return nullptr; } + + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey); + } + + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, addMethodStructKey); + } private: @@ -150,11 +180,26 @@ struct DifferentiableTypeConformanceContext IRStructKey* findDifferentialTypeStructKey() { + return getIDifferentiableStructKeyAtIndex(0); + } + + IRStructKey* findZeroMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(1); + } + + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(2); + } + + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) + { if (as<IRModuleInst>(inst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0))) + // Assume for now that IDifferentiable has exactly three fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); + if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) return as<IRStructKey>(entry->getRequirementKey()); else { @@ -200,12 +245,18 @@ struct DifferentiableTypeConformanceContext genericParam = genericParam->getNextParam(); } - UCount tableIndex = 0; + Count tableIndex = 0; while (genericParam) { SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType())); + + if (tableIndex >= typeParams.getCount()) + break; + if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType())) { + // TODO(sai): Heavily flawed way to find the right witness table. + // Rewrite this part if (witnessTableType->getConformanceType() == differentiableInterfaceType) witnessTableMap.Add(typeParams[tableIndex], genericParam); } @@ -222,6 +273,40 @@ struct DifferentiableTypeConformanceContext }; + +IRInst* findGlobal(IRInst* inst) +{ + if (inst->getParent() != inst->getModule()->getModuleInst()) + { + return findGlobal(inst->getParent()); + } + + return inst; +} + +void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst) +{ + HashSet<IRInst*> globalsOfUses; + for (auto use = globalInst->firstUse; use; use = use->nextUse) + { + globalsOfUses.Add(findGlobal(use->getUser())); + } + + IRInst* earliestUse = nullptr; + for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst()) + { + if (globalsOfUses.Contains(cursor)) + { + earliestUse = cursor; + } + } + + if (earliestUse) + { + globalInst->insertBefore(earliestUse); + } +} + struct DifferentialPairTypeBuilder { @@ -229,95 +314,246 @@ struct DifferentialPairTypeBuilder diffConformanceContext(diffConformanceContext) {} - IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) + IRStructField* findField(IRInst* type, IRStructKey* key) { - if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) + if (auto irStructType = as<IRStructType>(type)) { - auto primalField = as<IRStructField>(basePairStructType->getFirstChild()); - SLANG_ASSERT(primalField); - - return as<IRFieldExtract>(builder->emitFieldExtract( - primalField->getFieldType(), - baseInst, - primalField->getKey() - )); + for (auto field : irStructType->getFields()) + { + if (field->getKey() == key) + { + return field; + } + } } - else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) + else if (auto irSpecialize = as<IRSpecialize>(type)) { - if (auto pairStructType = as<IRStructType>(ptrType->getValueType())) + if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase())) { - auto primalField = as<IRStructField>(pairStructType->getFirstChild()); - SLANG_ASSERT(primalField); - - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType(primalField->getFieldType()), - baseInst, - primalField->getKey() - )); + if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric))) + { + return findField(irGenericStructType, key); + } } } - else + + return nullptr; + } + + IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam) + { + // Get base generic that's being specialized. + auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase()); + SLANG_ASSERT(genericType); + + // Find the index of genericParam in the base generic. + int paramIndex = -1; + int currentIndex = 0; + for (auto param : genericType->getParams()) { - SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>"); + if (param == genericParam) + paramIndex = currentIndex; + currentIndex ++; } - return nullptr; + + SLANG_ASSERT(paramIndex >= 0); + + // Return the corresponding operand in the specialization inst. + return specializeInst->getOperand(1 + paramIndex); } - IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) + IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) { if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) { - auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst()); - SLANG_ASSERT(diffField); - return as<IRFieldExtract>(builder->emitFieldExtract( - diffField->getFieldType(), + findField(basePairStructType, key)->getFieldType(), baseInst, - diffField->getKey() + key )); } else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) { - if (auto pairStructType = as<IRStructType>(ptrType->getValueType())) + if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) { - auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst()); - SLANG_ASSERT(diffField); - - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType(diffField->getFieldType()), + auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase())); + if (auto genericBasePairStructType = as<IRStructType>(genericType)) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + ptrInnerSpecializedType, + findField(ptrInnerSpecializedType, key)->getFieldType())), baseInst, - diffField->getKey() + key )); + } + } + else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType())) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findField(ptrBaseStructType, key)->getFieldType()), + baseInst, + key)); + } + } + else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType())) + { + // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's + // type, emit the specialization type. + // + auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase())); + if (auto genericBasePairStructType = as<IRStructType>(genericType)) + { + return as<IRFieldExtract>(builder->emitFieldExtract( + (IRType*)findSpecializationForParam( + specializedType, + findField(genericBasePairStructType, key)->getFieldType()), + baseInst, + key + )); + } + else if (auto genericPtrType = as<IRPtrTypeBase>(genericType)) + { + if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType())) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + specializedType, + findField(genericPairStructType, key)->getFieldType())), + baseInst, + key + )); + } } } else { - SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>"); + SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor"); } return nullptr; } + + IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) + { + return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); + } + + IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) + { + return emitFieldAccessor(builder, baseInst, this->globalDiffKey); + } + + void relocateNewTypes(IRBuilder* builder) + { + for (auto typeInst : generatedTypeList) + { + moveGlobalToBeforeUses(builder, typeInst); + } + } + + void _createGenericDiffPairType(IRBuilder* builder) + { + // Insert directly at top level (skip any generic scopes etc.) + auto insertLoc = builder->getInsertLoc(); + builder->setInsertInto(builder->getModule()->getModuleInst()); + + // Make a generic version of the pair struct. + auto irGeneric = builder->emitGeneric(); + irGeneric->setFullType(builder->getTypeKind()); + builder->setInsertInto(irGeneric); + + generatedTypeList.add(irGeneric); + + auto irBlock = builder->emitBlock(); + builder->setInsertInto(irBlock); + + auto pTypeParam = builder->emitParam(builder->getTypeType()); + builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT")); + + auto dTypeParam = builder->emitParam(builder->getTypeType()); + builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT")); + + auto irStructType = builder->createStructType(); + builder->emitReturn(irStructType); + + auto primalKey = _getOrCreatePrimalStructKey(builder); + builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal")); + builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam); + + auto diffKey = _getOrCreateDiffStructKey(builder); + builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); + builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam); + + // Reset cursor when done. + builder->setInsertLoc(insertLoc); + + this->genericDiffPairType = irGeneric; + } + + IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) + { + if (!this->globalDiffKey) + { + // Insert directly at top level (skip any generic scopes etc.) + auto insertLoc = builder->getInsertLoc(); + builder->setInsertInto(builder->getModule()->getModuleInst()); + + this->globalDiffKey = builder->createStructKey(); + builder->addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); + + builder->setInsertLoc(insertLoc); + } + + return this->globalDiffKey; + } + + IRStructKey* _getOrCreatePrimalStructKey(IRBuilder* builder) + { + if (!this->globalPrimalKey) + { + // Insert directly at top level (skip any generic scopes etc.) + auto insertLoc = builder->getInsertLoc(); + builder->setInsertInto(builder->getModule()->getModuleInst()); + + this->globalPrimalKey = builder->createStructKey(); + builder->addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); + + builder->setInsertLoc(insertLoc); + } + + return this->globalPrimalKey; + } + + IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder) + { + if (!this->genericDiffPairType) + { + _createGenericDiffPairType(builder); + } + + SLANG_ASSERT(this->genericDiffPairType); + return this->genericDiffPairType; + } - IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) { - auto diffPairType = builder->createStructType(); - - // Create a keys for the primal and differential fields. - IRStructKey* origKey = builder->createStructKey(); - builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal")); - builder->createStructField(diffPairType, origKey, origBaseType); + SLANG_ASSERT(!as<IRParam>(origBaseType)); - IRStructKey* diffKey = builder->createStructKey(); - builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); - builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType)); + auto pairStructType = builder->createStructType(); + builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); + builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType); - return diffPairType; + return pairStructType; } return nullptr; } - IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (pairTypeCache.ContainsKey(origBaseType)) return pairTypeCache[origBaseType]; @@ -328,10 +564,17 @@ struct DifferentialPairTypeBuilder return pairType; } - Dictionary<IRType*, IRStructType*> pairTypeCache; + Dictionary<IRInst*, IRInst*> pairTypeCache; DifferentiableTypeConformanceContext* diffConformanceContext; + + IRStructKey* globalPrimalKey = nullptr; + + IRStructKey* globalDiffKey = nullptr; + IRInst* genericDiffPairType = nullptr; + + List<IRInst*> generatedTypeList; }; struct JVPTranscriber @@ -341,6 +584,9 @@ struct JVPTranscriber // their differential values. Dictionary<IRInst*, IRInst*> instMapD; + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet<IRInst*> instsInProgress; + // Cloning environment to hold mapping from old to new copies for the primal // instructions. IRCloneEnv cloneEnv; @@ -362,7 +608,17 @@ struct JVPTranscriber void mapDifferentialInst(IRInst* origInst, IRInst* diffInst) { - instMapD.Add(origInst, diffInst); + if (hasDifferentialInst(origInst)) + { + if (lookupDiffInst(origInst) != diffInst) + { + SLANG_UNEXPECTED("Inconsistent differential mappings"); + } + } + else + { + instMapD.Add(origInst, diffInst); + } } void mapPrimalInst(IRInst* origInst, IRInst* primalInst) @@ -439,6 +695,7 @@ struct JVPTranscriber for (UIndex i = 0; i < funcType->getParamCount(); i++) { auto origType = funcType->getParamType(i); + origType = (IRType*) lookupPrimalInst(origType, origType); if (auto diffPairType = tryGetDiffPairType(builder, origType)) newParameterTypes.add(diffPairType); else @@ -448,7 +705,8 @@ struct JVPTranscriber // Transcribe return type to a pair. // This will be void if the primal return type is non-differentiable. // - if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType())) + auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); + if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) diffReturnType = returnPairType; else diffReturnType = builder->getVoidType(); @@ -458,41 +716,101 @@ struct JVPTranscriber IRType* differentiateType(IRBuilder* builder, IRType* origType) { - switch (origType->getOp()) - { - case kIROp_HalfType: - case kIROp_FloatType: - case kIROp_DoubleType: - case kIROp_VectorType: - return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType)); - case kIROp_OutType: - return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType())); - case kIROp_InOutType: - return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType())); - default: + if (auto ptrType = as<IRPtrTypeBase>(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = lookupPrimalInst(origType, origType); + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return (IRType*)(diffConformanceContext->getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else return nullptr; + + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType)); } } - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType) + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) { // If this is a PtrType (out, inout, etc..), then create diff pair from // value type and re-apply the appropropriate PtrType wrapper. // - if (auto origPtrType = as<IRPtrTypeBase>(origType)) + if (auto origPtrType = as<IRPtrTypeBase>(primalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(origType->getOp(), diffPairValueType); + return builder->getPtrType(primalType->getOp(), diffPairValueType); else return nullptr; } - return pairBuilder->getOrCreateDiffPairType(builder, origType); + return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType); } InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) { - if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType())) + auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); + // Do not differentiate generic type (and witness table) parameters + if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) { IRParam* diffPairParam = builder->emitParam(diffPairType); @@ -507,6 +825,7 @@ struct JVPTranscriber pairBuilder->emitDiffFieldAccess(builder, diffPairParam)); } + return InstPair( cloneInst(&cloneEnv, builder, origParam), nullptr); @@ -570,15 +889,13 @@ struct JVPTranscriber auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); auto diffRight = findOrTranscribeDiffInst(builder, origRight); - auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0); - auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0); if (diffLeft || diffRight) { - diffLeft = diffLeft ? diffLeft : leftZero; - diffRight = diffRight ? diffRight : rightZero; + diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); + diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); - auto resultType = origArith->getDataType(); + auto resultType = primalArith->getDataType(); switch(origArith->getOp()) { case kIROp_Add: @@ -608,17 +925,36 @@ struct JVPTranscriber return InstPair(primalArith, nullptr); } + + InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) + { + SLANG_ASSERT(origLogic->getOperandCount() == 2); + + // TODO: Check other boolean cases. + if (as<IRBoolType>(origLogic->getDataType())) + { + // Boolean operations are not differentiable. For the linearization + // pass, we do not need to do anything but copy them over to the ne + // function. + auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); + return InstPair(primalLogic, nullptr); + } + + SLANG_UNEXPECTED("Logical operation with non-boolean result"); + } + InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + IRInst* diffLoad = nullptr; + if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) { - IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); - SLANG_ASSERT(diffLoad); - + // Default case, we're loading from a known differential inst. + diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); return InstPair(primalLoad, diffLoad); } return InstPair(primalLoad, nullptr); @@ -634,15 +970,17 @@ struct JVPTranscriber auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + IRInst* diffStore = nullptr; + // If the stored value has a differential version, // emit a store instruction for the differential parameter. // Otherwise, emit nothing since there's nothing to load. // if (diffStoreLocation && diffStoreVal) { - IRStore* diffStore = as<IRStore>( - builder->emitStore(diffStoreLocation, diffStoreVal)); - SLANG_ASSERT(diffStore); + // Default case, storing the entire type (and not a member) + diffStore = as<IRStore>( + builder->emitStore(diffStoreLocation, diffStoreVal)); return InstPair(primalStore, diffStore); } @@ -653,14 +991,31 @@ struct JVPTranscriber InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn) { IRInst* origReturnVal = origReturn->getVal(); - - if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType())) + + auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); + if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal)) + { + // If the return value is itself a function, generic or a struct then this + // is likely to be a generic scope. In this case, we lookup the differential + // and return that. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + + // Neither of these should be nullptr. + SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); + IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); + + return InstPair(diffReturn, diffReturn); + } + else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) { IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); if(!diffReturnVal) - diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType()); + diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); + + // If the pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffReturnVal); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); @@ -668,10 +1023,12 @@ struct JVPTranscriber } else { - // If the differential return value is not available, emit a - // void return. - IRInst* voidReturn = builder->emitReturn(); - return InstPair(voidReturn, voidReturn); + // If the return type is not differentiable, emit the primal value only. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + + IRInst* primalReturn = builder->emitReturn(primalReturnVal); + return InstPair(primalReturn, nullptr); + } } @@ -682,15 +1039,43 @@ struct JVPTranscriber InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) { IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + + // Check if the output type can be differentiated. If it cannot be + // differentiated, don't differentiate the inst + // + auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); + if (auto diffConstructType = differentiateType(builder, primalConstructType)) + { + UCount operandCount = origConstruct->getOperandCount(); - if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1) - return InstPair(primalConstruct, nullptr); + List<IRInst*> diffOperands; + for (UIndex ii = 0; ii < operandCount; ii++) + { + // If the operand has a differential version, replace the original with + // the differential. Otherwise, use a zero. + // + if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr)) + diffOperands.add(diffInst); + else + { + auto operandDataType = origConstruct->getOperand(ii)->getDataType(); + operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } + } + + return InstPair( + primalConstruct, + builder->emitIntrinsicInst( + diffConstructType, + origConstruct->getOp(), + operandCount, + diffOperands.getBuffer())); + } else - getSink()->diagnose(origConstruct->sourceLoc, - Diagnostics::unimplemented, - "this construct instruction cannot be differentiated"); - - return InstPair(primalConstruct, nullptr); + { + return InstPair(primalConstruct, nullptr); + } } // Differentiating a call instruction here is primarily about generating @@ -699,13 +1084,21 @@ struct JVPTranscriber // InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) { - if (auto origCallee = as<IRFunc>(origCall->getCallee())) + + if (as<IRFunc>(origCall->getCallee())) { - + auto origCallee = origCall->getCallee(); + + // Since concrete functions are globals, the primal callee is the same + // as the original callee. + // + auto primalCallee = origCallee; + + // TODO: If inner is not differentiable, treat as non-differentiable call. // Build the differential callee IRInst* diffCall = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())), - origCallee); + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); List<IRInst*> args; // Go over the parameter list and create pairs for each input (if required) @@ -715,17 +1108,17 @@ struct JVPTranscriber auto primalArg = findOrTranscribePrimalInst(builder, origArg); SLANG_ASSERT(primalArg); - auto origType = origArg->getDataType(); - if (auto pairType = tryGetDiffPairType(builder, origType)) + auto primalType = primalArg->getDataType(); + if (auto pairType = tryGetDiffPairType(builder, primalType)) { - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - // TODO(sai): This part is flawed. Replace with a call to the - // 'zero()' interface method. if (!diffArg) - diffArg = getZeroOfType(builder, origType); + diffArg = getDifferentialZeroOfType(builder, primalType); + // If a pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffArg); + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); args.add(diffPair); @@ -737,8 +1130,11 @@ struct JVPTranscriber } } + auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + SLANG_ASSERT(diffReturnType); + auto callInst = builder->emitCallInst( - tryGetDiffPairType(builder, origCall->getFullType()), + diffReturnType, diffCall, args); @@ -746,6 +1142,13 @@ struct JVPTranscriber pairBuilder->emitPrimalFieldAccess(builder, callInst), pairBuilder->emitDiffFieldAccess(builder, callInst)); } + else if(as<IRSpecialize>(origCall->getCallee()) || + as<IRLookupWitnessMethod>(origCall->getCallee())) + { + getSink()->diagnose(origCall->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unspecialized callee or an interface method"); + } else { // Note that this can only happen if the callee is a result @@ -774,7 +1177,7 @@ struct JVPTranscriber return InstPair( primalSwizzle, builder->emitSwizzle( - differentiateType(builder, origSwizzle->getDataType()), + differentiateType(builder, primalSwizzle->getDataType()), diffBase, origSwizzle->getElementCount(), swizzleIndices.getBuffer())); @@ -806,7 +1209,7 @@ struct JVPTranscriber return InstPair( primalInst, builder->emitIntrinsicInst( - differentiateType(builder, origInst->getDataType()), + differentiateType(builder, primalInst->getDataType()), origInst->getOp(), operandCount, diffOperands.getBuffer())); @@ -819,17 +1222,44 @@ struct JVPTranscriber case kIROp_unconditionalBranch: auto origBranch = as<IRUnconditionalBranch>(origInst); - // Branches with extra operands not handled currently. - if (origBranch->getOperandCount() > 1) - break; + // Grab the differentials for any phi nodes. + List<IRInst*> pairArgs; + for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) + { + auto origArg = origBranch->getArg(ii); - IRInst* diffBranch = nullptr; + IRInst* pairArg = nullptr; + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType())) + { + auto diffArg = lookupDiffInst(origArg, nullptr); + if (!diffArg) + { + diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType()); + } + + pairArg = builder->emitMakeDifferentialPair( + diffPairType, + lookupPrimalInst(origArg), + diffArg); + } + else + { + pairArg = lookupPrimalInst(origArg); + } + pairArgs.add(pairArg); + } - if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr)) - diffBranch = builder->emitBranch(as<IRBlock>(diffBlock)); + IRInst* diffBranch = nullptr; + if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) + { + diffBranch = builder->emitBranch( + as<IRBlock>(diffBlock), + pairArgs.getCount(), + pairArgs.getBuffer()); + } // For now, every block in the original fn must have a corresponding - // block to compute both primals and derivatives. + // block to compute *both* primals and derivatives (i.e linearized block) SLANG_ASSERT(diffBranch); return InstPair(diffBranch, diffBranch); @@ -843,12 +1273,13 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeConst(IRBuilder*, IRInst* origInst) { switch(origInst->getOp()) { case kIROp_FloatLit: + case kIROp_VoidLit: + case kIROp_IntLit: return InstPair(origInst, nullptr); } @@ -860,49 +1291,439 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } + InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) + { + // This is slightly counter-intuitive, but we don't perform any differentiation + // logic here. We simple clone the original specialize which points to the original function, + // or the cloned version in case we're inside a generic scope. + // The differentiation logic is inserted later when this is used in an IRCall. + // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn)) + // rather than have Specialize(JVPDifferentiate(Fn)) + // + auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize); + return InstPair(diffSpecialize, diffSpecialize); + } + + InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) + { + // This is slightly counter-intuitive, but we don't perform any differentiation + // logic here. We simple clone the original lookup which points to the original function, + // or the cloned version in case we're inside a generic scope. + // The differentiation logic is inserted later when this is used in an IRCall. + // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table)) + // rather than have Lookup(JVPDifferentiate(Table)) + // + auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); + return InstPair(diffLookup, diffLookup); + } + // In differential computation, the 'default' differential value is always zero. // This is a consequence of differential computing being inherently linear. As a // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. // - IRInst* getZeroOfType(IRBuilder* builder, IRType* type) + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + { + if (auto diffType = differentiateType(builder, primalType)) + { + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). + auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List<IRInst*>(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + // We special case a few non-differentiable types that sometimes appear in places + // where we're forced to provide a differential zero value. For instance, + // float3(float, float, int) is accepted by the compiler, but is tricky in the context + // of differentiation since int is non-differentiable, and should be cast to float first. + // In the absence of such casts, this piece of code generates appropriate zero values. + // + switch (primalType->getOp()) + { + case kIROp_IntType: + return builder->getIntValue(primalType, 0); + default: + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + } + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + auto oldLoc = builder->getInsertLoc(); + + IRInst* diffBlock = builder->emitBlock(); + + // Note: for blocks, we setup the mapping _before_ + // processing the children since we could encounter + // a lookup while processing the children. + // + mapPrimalInst(origBlock, diffBlock); + mapDifferentialInst(origBlock, diffBlock); + + builder->setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->transcribe(builder, param); + + // Look for the differentiable type dictionary and clone it (and anything else we might need). + // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement) + // that are operands. + // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks. + if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock)) + { + auto clonedDict = cloneInst(&cloneEnv, builder, origDict); + mapPrimalInst(origDict, clonedDict); + mapDifferentialInst(origDict, clonedDict); + } + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->transcribe(builder, child); + + builder->setInsertLoc(oldLoc); + + return InstPair(diffBlock, diffBlock); + } + + InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract) { - switch (type->getOp()) + IRInst* origBase = origExtract->getBase(); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto diffBase = findOrTranscribeDiffInst(builder, origBase); + + auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType()); + + IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField()); + IRInst* diffExtract = nullptr; + + if (auto diffExtractType = differentiateType(builder, primalExtractType)) { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - return builder->getFloatValue(type, 0.0); - case kIROp_IntType: - return builder->getIntValue(type, 0); - case kIROp_VectorType: + // Check if we have a getter. + if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>()) { - IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())}; - return builder->emitIntrinsicInst( - type, - kIROp_constructVectorFromScalar, - 1, + + IRInst* getterFunc = getterDecoration->getGetterFunc(); + + // Must be a method with a single parameter. + SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1); + + // Our getter func accepts a _pointer_ to the target type + // So we have to create a variable and store our type into memory + // here. This will eventually get optimized out in later passes. + // + auto diffTempVar = builder->emitVar( + diffBase->getDataType()); + + builder->emitStore(diffTempVar, diffBase); + + List<IRInst*> args; + args.add(diffTempVar); + + // Emit a call to the getter. The getter will return a reference type. + // We need to load from this to go to a non-ptr 'solid' type. + // + auto diffGetterCall = builder->emitCallInst( + as<IRFuncType>(getterFunc->getDataType())->getResultType(), + getterFunc, args); + + diffExtract = builder->emitLoad(diffGetterCall); } - default: - getSink()->diagnose(type->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; } + + return InstPair(primalExtract, diffExtract); + } + + InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress) + { + IRInst* origBase = origAddress->getBase(); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto diffBase = findOrTranscribeDiffInst(builder, origBase); + + auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType()); + + IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField()); + IRInst* diffAddress = nullptr; + + if (auto diffAddressType = differentiateType(builder, primalAddressType)) + { + // If we have a getter associated with this field, we want to use that. + if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>()) + { + auto getterFunc = getterDecoration->getGetterFunc(); + + // Add the base differential inst as the argument. + List<IRInst*> args; + args.add(diffBase); + + diffAddress = builder->emitCallInst( + as<IRFuncType>(getterFunc->getDataType())->getResultType(), + getterFunc, + args); + } + + } + + return InstPair(primalAddress, diffAddress); + } + + + InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) + { + SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); + + IRInst* origBase = origGetElementPtr->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); + + auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); + + IRInst* primalOperands[] = {primalBase, primalIndex}; + IRInst* primalGetElementPtr = builder->emitIntrinsicInst( + primalType, + origGetElementPtr->getOp(), + 2, + primalOperands); + + IRInst* diffGetElementPtr = nullptr; + + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + IRInst* diffOperands[] = {diffBase, primalIndex}; + diffGetElementPtr = builder->emitIntrinsicInst( + diffType, + origGetElementPtr->getOp(), + 2, + diffOperands); + } + } + + return InstPair(primalGetElementPtr, diffGetElementPtr); + } + + + InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) + { + // The loop comes with three blocks.. we just need to transcribe each one + // and assemble the new loop instruction. + + // Transcribe the target block (this is the 'condition' part of the loop, which + // will branch into the loop body) + auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); + + // Transcribe the break block (this is the block after the exiting the loop) + auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + + + List<IRInst*> diffLoopOperands; + diffLoopOperands.add(diffTargetBlock); + diffLoopOperands.add(diffBreakBlock); + diffLoopOperands.add(diffContinueBlock); + + // If there are any other operands, use their primal versions. + for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) + { + auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); + diffLoopOperands.add(primalOperand); + } + + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + diffLoopOperands.getCount(), + diffLoopOperands.getBuffer()); + + return InstPair(diffLoop, diffLoop); + } + + InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) + { + // The loop comes with three blocks.. we just need to transcribe each one + // and assemble the new loop instruction. + + // Transcribe the target block (this is the 'condition' part of the loop, which + // will branch into the loop body). + // Note that for the condition we use the primal inst (condition values should not have a + // differential) + auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); + SLANG_ASSERT(primalConditionBlock); + + // Transcribe the break block (this is the block after the exiting the loop) + auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); + SLANG_ASSERT(diffTrueBlock); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); + SLANG_ASSERT(diffFalseBlock); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); + SLANG_ASSERT(diffAfterBlock); + + + List<IRInst*> diffIfElseArgs; + diffIfElseArgs.add(primalConditionBlock); + diffIfElseArgs.add(diffTrueBlock); + diffIfElseArgs.add(diffFalseBlock); + diffIfElseArgs.add(diffAfterBlock); + + // If there are any other operands, use their primal versions. + for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++) + { + auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii)); + diffIfElseArgs.add(primalOperand); + } + + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_ifElse, + diffIfElseArgs.getCount(), + diffIfElseArgs.getBuffer()); + + return InstPair(diffLoop, diffLoop); + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc) + { + IRFunc* primalFunc = nullptr; + + auto oldLoc = builder->getInsertLoc(); + + // If this is a top-level function, there is no need to clone it + // since it is visible in all the scopes. + // Otherwise, we need to clone it in case of generic scopes. + // + // TODO(sai): Is this the correct thing to do? Can a function cloned inside a + // generic scope but is not the return value of that generic, be used within + // that scope? Or do we have to call out to the original generic specialized with + // the current generic params? + // + bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr); + if (isTopLevelFunc) + { + builder->setInsertBefore(origFunc); + primalFunc = origFunc; + } + else + { + // TODO(sai): this might never be called, and it might never make sense + // to call it either. Potentially remove this. + primalFunc = as<IRFunc>( + cloneInst(&cloneEnv, builder, origFunc)); + } + + auto diffFunc = builder->createFunc(); + + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + builder, + as<IRFuncType>(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); + + // TODO(sai): Replace naming scheme + // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) + // builder->addNameHintDecoration(diffFunc, jvpName); + + // Transcribe children from origFunc into diffFunc + builder->setInsertInto(diffFunc); + for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(builder, block); + + // Reset builder position + builder->setInsertLoc(oldLoc); + + return InstPair(primalFunc, diffFunc); + } + + // Transcribe a generic definition + InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric) + { + // For now, we assume there's only one generic layer. So this inst must be top level + bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr); + SLANG_RELEASE_ASSERT(isTopLevel); + + IRGeneric* primalGeneric = origGeneric; + + auto oldLoc = builder->getInsertLoc(); + builder->setInsertBefore(origGeneric); + + auto diffGeneric = builder->emitGeneric(); + + // Process type of generic. If the generic is a function, then it's type will also be a + // generic and this logic will transcribe that generic first before continuing with the + // function itself. + // + auto primalType = primalGeneric->getFullType(); + + IRType* diffType = nullptr; + if (primalType) + { + diffType = (IRType*) findOrTranscribeDiffInst(builder, primalType); + } + + diffGeneric->setFullType(diffType); + + // TODO(sai): Replace naming scheme + // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) + // builder->addNameHintDecoration(diffFunc, jvpName); + + // Transcribe children from origFunc into diffFunc. + builder->setInsertInto(diffGeneric); + for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(builder, block); + + // Reset builder position. + builder->setInsertLoc(oldLoc); + + return InstPair(primalGeneric, diffGeneric); } IRInst* transcribe(IRBuilder* builder, IRInst* origInst) { + // If a differential intstruction is already mapped for + // this original inst, return that. + // + if (auto diffInst = lookupDiffInst(origInst, nullptr)) + { + SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. + return diffInst; + } + + // Otherwise, dispatch to the appropriate method + // depending on the op-code. + // + instsInProgress.Add(origInst); InstPair pair = transcribeInst(builder, origInst); if (auto primalInst = pair.primal) { mapPrimalInst(origInst, pair.primal); - mapDifferentialInst(origInst, pair.differential); return pair.differential; } + instsInProgress.Remove(origInst); + getSink()->diagnose(origInst->sourceLoc, Diagnostics::internalCompilerError, "failed to transcibe instruction"); @@ -911,7 +1732,7 @@ struct JVPTranscriber InstPair transcribeInst(IRBuilder* builder, IRInst* origInst) { - // Handle common operations + // Handle common SSA-style operations switch (origInst->getOp()) { case kIROp_Param: @@ -934,6 +1755,14 @@ struct JVPTranscriber case kIROp_Sub: case kIROp_Div: return transcribeBinaryArith(builder, origInst); + + case kIROp_Less: + case kIROp_Greater: + case kIROp_And: + case kIROp_Or: + case kIROp_Geq: + case kIROp_Leq: + return transcribeBinaryLogic(builder, origInst); case kIROp_Construct: return transcribeConstruct(builder, origInst); @@ -945,24 +1774,91 @@ struct JVPTranscriber return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); case kIROp_constructVectorFromScalar: + case kIROp_MakeTuple: return transcribeByPassthrough(builder, origInst); case kIROp_unconditionalBranch: - case kIROp_conditionalBranch: return transcribeControlFlow(builder, origInst); case kIROp_FloatLit: + case kIROp_IntLit: + case kIROp_VoidLit: return transcribeConst(builder, origInst); + case kIROp_Specialize: + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); + + case kIROp_lookup_interface_method: + return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); + + case kIROp_FieldExtract: + return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst)); + + case kIROp_FieldAddress: + return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst)); + + case kIROp_getElement: + case kIROp_getElementPtr: + return transcribeGetElement(builder, origInst); + + case kIROp_loop: + return transcribeLoop(builder, as<IRLoop>(origInst)); + + case kIROp_ifElse: + return transcribeIfElse(builder, as<IRIfElse>(origInst)); + + case kIROp_DifferentiableTypeDictionary: + // Ignore dictionary insts. + return InstPair(nullptr, nullptr); + } // If none of the cases have been hit, check if the instruction is a - // type. - // For now we don't have logic to differentiate types that appear in blocks. - // So, we clone and avoid differentiating them. - // + // type. Only need to explicitly differentiate types if they appear inside a block. + // if (auto origType = as<IRType>(origInst)) - return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr); + { + // If this is a generic type, transcibe the parent + // generic and derive the type from the transcribed generic's + // return value. + // + if (as<IRGeneric>(origType->getParent()->getParent()) && + findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType && + !instsInProgress.Contains(origType->getParent()->getParent())) + { + auto origGenericType = origType->getParent()->getParent(); + auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); + auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType)); + return InstPair( + origGenericType, + innerDiffGenericType + ); + } + else if (as<IRBlock>(origType->getParent())) + return InstPair( + cloneInst(&cloneEnv, builder, origType), + differentiateType(builder, origType)); + else + return InstPair( + cloneInst(&cloneEnv, builder, origType), + nullptr); + } + + // Handle instructions with children + switch (origInst->getOp()) + { + case kIROp_Func: + return transcribeFunc(builder, as<IRFunc>(origInst)); + + case kIROp_Block: + return transcribeBlock(builder, as<IRBlock>(origInst)); + + case kIROp_Generic: + return transcribeGeneric(builder, as<IRGeneric>(origInst)); + } + // If we reach this statement, the instruction type is likely unhandled. getSink()->diagnose(origInst->sourceLoc, @@ -1042,6 +1938,14 @@ struct JVPDerivativeContext // IRMakeDifferentialPair with an IRMakeStruct. // modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage)); + + // Temporary fix: Move generated types, if any, to before their use locations. + (&pairBuilderStorage)->relocateNewTypes(builder); + + // Remove all kIROp_DifferentiableTypeDictionary instructions and + // kIROp_DifferentialGetterDecoration decorations + // + modified |= stripDiffTypeInformation(builder, module->getModuleInst()); return modified; } @@ -1079,19 +1983,45 @@ struct JVPDerivativeContext if (auto jvpDiffInst = as<IRJVPDifferentiate>(child)) { - auto baseFunction = jvpDiffInst->getBaseFn(); + auto baseInst = jvpDiffInst->getBaseFn(); + + IRGlobalValueWithCode* baseFunction = nullptr; + + if (auto specializeInst = as<IRSpecialize>(baseInst)) + { + baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase()); + } + else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst)) + { + baseFunction = globalValWithCode; + } + + SLANG_ASSERT(baseFunction); + // If the JVP Reference already exists, no need to // differentiate again. // - if(lookupJVPReference(baseFunction)) continue; + if (lookupJVPReference(baseFunction)) continue; - if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction))) + if (isMarkedForJVP(baseFunction)) { - IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction)); - builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction); - workQueue->push(jvpFunction); + if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) + { + IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction); + SLANG_ASSERT(diffFunc); + builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc); + workQueue->push(diffFunc); + } + else + { + // TODO(Sai): This would probably be better with a more specific + // error code. + getSink()->diagnose(jvpDiffInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } } - else + else { // TODO(Sai): This would probably be better with a more specific // error code. @@ -1106,55 +2036,33 @@ struct JVPDerivativeContext return true; } - // Run through all the global-level instructions, - // looking for callables. - // Note: We're only processing global callables (IRGlobalValueWithCode) - // for now. - // - bool processMarkedGlobalFunctions(IRBuilder* builder) + IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*) { - for (auto inst : module->getGlobalInsts()) + + if (auto pairType = as<IRDifferentialPairType>(type)) { - // If the instr is a callable, get all the basic blocks - if (auto callable = as<IRGlobalValueWithCode>(inst)) - { - if (isFunctionMarkedForJVP(callable)) - { - SLANG_ASSERT(as<IRFunc>(callable)); + builder->setInsertBefore(pairType); - IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable)); - builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); + auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( + builder, + pairType->getValueType()); - unmarkForJVP(callable); - } - } - } - return true; - } + pairType->replaceUsesWith(diffPairStructType); + pairType->removeAndDeallocate(); - IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext) - { - if (diffContext->isInterfaceAvailable) + return diffPairStructType; + } + else if (auto loweredStructType = as<IRStructType>(type)) { - if (auto pairType = as<IRDifferentialPairType>(type)) - { - builder->setInsertBefore(pairType); - - auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( - builder, - pairType->getValueType()); - - pairType->replaceUsesWith(diffPairStructType); - pairType->removeAndDeallocate(); - - return diffPairStructType; - } - else if (auto loweredStructType = as<IRStructType>(type)) - { - // Already lowered to struct. - return loweredStructType; - } + // Already lowered to struct. + return loweredStructType; } + else if (auto specializedStructType = as<IRSpecialize>(type)) + { + // Already lowered to specialized struct. + return specializedStructType; + } + return nullptr; } @@ -1171,7 +2079,7 @@ struct JVPDerivativeContext operands.add(makePairInst->getPrimalValue()); operands.add(makePairInst->getDifferentialValue()); - auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands); + auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); makePairInst->replaceUsesWith(makeStructInst); makePairInst->removeAndDeallocate(); @@ -1258,10 +2166,43 @@ struct JVPDerivativeContext return modified; } + bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent) + { + bool modified = false; + + auto child = parent->getFirstChild(); + while (child) + { + auto nextChild = child->getNextInst(); + + if (child->getOp() == kIROp_DifferentiableTypeDictionary) + { + child->removeAndDeallocate(); + child = nextChild; + modified = true; + continue; + } + + if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>()) + { + getterDecoration->removeAndDeallocate(); + } + + if (child->getFirstChild() != nullptr) + { + modified |= stripDiffTypeInformation(builder, child); + } + + child = nextChild; + } + + return modified; + } + // Checks decorators to see if the function should // be differentiated (kIROp_JVPDerivativeMarkerDecoration) // - bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable) + bool isMarkedForJVP(IRGlobalValueWithCode* callable) { for(auto decoration = callable->getFirstDecoration(); decoration; @@ -1292,63 +2233,8 @@ struct JVPDerivativeContext } } - List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType) - { - List<IRParam*> params; - for(UIndex i = 0; i < dataType->getParamCount(); i++) - { - params.add( - builder->emitParam(dataType->getParamType(i))); - } - return params; - } - - // Perform forward-mode automatic differentiation on - // the intstructions. - // - IRFunc* emitJVPFunction(IRBuilder* builder, - IRFunc* primalFn) - { - eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn); - - builder->setInsertBefore(primalFn->getNextInst()); - - auto jvpFn = builder->createFunc(); - - SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType())); - IRType* jvpFuncType = transcriberStorage.differentiateFunctionType( - builder, - as<IRFuncType>(primalFn->getFullType())); - jvpFn->setFullType(jvpFuncType); - - if (auto jvpName = getJVPFuncName(builder, primalFn)) - builder->addNameHintDecoration(jvpFn, jvpName); - - builder->setInsertInto(jvpFn); - - // Emit a block instruction for every block in the function, and map it as the - // corresponding differential. - // - for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) - { - auto jvpBlock = builder->emitBlock(); - transcriberStorage.mapDifferentialInst(block, jvpBlock); - transcriberStorage.mapPrimalInst(block, jvpBlock); - } - - // Go back over the blocks, and process the children of each block. - for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) - { - auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block)); - SLANG_ASSERT(jvpBlock); - emitJVPBlock(builder, block, jvpBlock); - } - - return jvpFn; - } - IRStringLit* getJVPFuncName(IRBuilder* builder, - IRFunc* func) + IRInst* func) { auto oldLoc = builder->getInsertLoc(); builder->setInsertBefore(func); @@ -1368,36 +2254,6 @@ struct JVPDerivativeContext return name; } - IRBlock* emitJVPBlock(IRBuilder* builder, - IRBlock* origBlock, - IRBlock* jvpBlock = nullptr) - { - JVPTranscriber* transcriber = &(transcriberStorage); - - // Create if not already created, and then insert into new block. - if (!jvpBlock) - jvpBlock = builder->emitBlock(); - else - builder->setInsertInto(jvpBlock); - - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - { - transcriber->transcribe(builder, param); - } - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - { - transcriber->transcribe(builder, child); - } - - return jvpBlock; - } - JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : module(module), sink(sink), diffConformanceContextStorage(module->getModuleInst()), |
