From b2ca2d5a4efeae807d3c3f48f60235e47413b559 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 23 Aug 2024 21:45:59 -0700 Subject: Make variadic generics work with interfaces and forward autodiff. (#4905) --- source/slang/slang-ir-autodiff.cpp | 156 +++++++++++++++++++++++++++++++++++-- 1 file changed, 150 insertions(+), 6 deletions(-) (limited to 'source/slang/slang-ir-autodiff.cpp') diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 6b275179c..b7c2037e5 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -44,6 +44,13 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK return entry->getRequirementVal(); } } + else if (as(witness)) + { + // We are looking up a witness from a type pack. + // This is only allowed if we are looking up a differential type. + // We should turn this into an actual witness table for the type pack/tuple type. + SLANG_UNEXPECTED("looking up from a witness pack is invalid and should have been lowered."); + } else { return builder->emitLookupInterfaceMethodInst( @@ -434,10 +441,33 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) } else { - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + auto witness = item->getWitness(); // Also register the type's differential type with the same witness. + auto concreteType = item->getConcreteType(); IRBuilder subBuilder(item->getConcreteType()); + if (as(concreteType) || as(concreteType)) + { + // For tuple types, register the differential type for each element, but don't register for the + // tuple/typepack itself. + auto witnessPack = as(witness); + SLANG_ASSERT(witnessPack); + + for (UInt i = 0; i < concreteType->getOperandCount(); i++) + { + auto element = concreteType->getOperand(i); + auto elementWitness = witnessPack->getOperand(i); + differentiableWitnessDictionary.addIfNotExists( + (IRType*)element, + _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + } + return; + } + else + { + differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + } + if (!as(item->getConcreteType())) { differentiableWitnessDictionary.addIfNotExists( @@ -768,16 +798,18 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build SLANG_UNIMPLEMENTED_X("Impl"); } + case kIROp_TypePack: case kIROp_TupleType: { - auto tupleType = as(primalType); List diffTypeList; // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + for (UIndex ii = 0; ii < primalType->getOperandCount(); ii++) diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); + differentiateType(builder, (IRType*)primalType->getOperand(ii))); + if (primalType->getOp() == kIROp_TupleType) + return builder->getTupleType(diffTypeList); + else + return builder->getTypePack((UInt)diffTypeList.getCount(), diffTypeList.getBuffer()); } default: @@ -795,6 +827,12 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil { SLANG_RELEASE_ASSERT(witness || as(primalType)); } + if (as(witness)) + { + // If registered witness is a witness pack for a type pack, + // we should reconstruct the true witness table. + witness = nullptr; + } if (!witness) { @@ -811,6 +849,14 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil { witness = getExtractExistensialTypeWitness(builder, extractExistential); } + else if (auto typePack = as(primalType)) + { + witness = getTupleWitness(builder, typePack); + } + else if (auto tupleType = as(primalType)) + { + witness = getTupleWitness(builder, tupleType); + } } return witness; } @@ -963,6 +1009,104 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder return table; } +IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType) +{ + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType); + + if (!diffTupleType) + return nullptr; + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffTupleType, diffTupleType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); + b.emitBlock(); + auto p0 = b.emitParam(diffTupleType); + auto p1 = b.emitParam(diffTupleType); + List results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) + { + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); + auto iVal = b.getIntValue(b.getIntType(), i); + IRInst* args[2] = { + b.emitGetTupleElement(diffElementType, p0, iVal), + b.emitGetTupleElement(diffElementType, p1, iVal) }; + elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); + } + results.add(elementResult); + } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); + else + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.emitBlock(); + List results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) + { + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); + elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); + } + results.add(elementResult); + } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); + else + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); + } + + // Record this in the context for future lookups + differentiableWitnessDictionary[(IRType*)inTupleType] = table; + + return table; +} + IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( IRBuilder* builder, IRExtractExistentialType* extractExistentialType) -- cgit v1.2.3