summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-generics.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2020-07-13 15:16:09 -0700
committerGitHub <noreply@github.com>2020-07-13 15:16:09 -0700
commit48f26ef082fa3b0c2a02dc57585f7e43210bbb63 (patch)
treee3e13e8034c0f2efe1454a51b4df0290056dae9f /source/slang/slang-ir-lower-generics.cpp
parent249f48dbb5e240c713661be969a6939ec57561e5 (diff)
Dynamic code gen for functions returning generic types. (#1439)
* Dynamic code gen for functions returning generic types. * Add expected test result.
Diffstat (limited to 'source/slang/slang-ir-lower-generics.cpp')
-rw-r--r--source/slang/slang-ir-lower-generics.cpp287
1 files changed, 189 insertions, 98 deletions
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 8b003f854..260f87bde 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -60,6 +60,15 @@ namespace Slang
case kIROp_AssociatedType:
case kIROp_InterfaceType:
return true;
+ case kIROp_Specialize:
+ {
+ for (UInt i = 0; i < typeInst->getOperandCount(); i++)
+ {
+ if (isPolymorphicType(typeInst->getOperand(i)))
+ return true;
+ }
+ return false;
+ }
default:
break;
}
@@ -124,6 +133,41 @@ namespace Slang
}
}
cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, func, loweredFunc);
+
+ // If the function returns a generic typed value, we need to turn it
+ // into an `out` parameter, since only the caller can allocate space
+ // for it.
+ auto oldFuncType = cast<IRFuncType>(func->getDataType());
+ if (isPolymorphicType(oldFuncType->getResultType()))
+ {
+ builder.setInsertInto(loweredFunc->getFirstBlock());
+ // We defer creation of the returnVal parameter until we see the first
+ // `return` instruction, because we can only obtain the cloned return type
+ // of this function by checking the type of the cloned return inst.
+ IRParam* retValParam = nullptr;
+ // Translate all return insts to `store`s.
+ // Those `store`s will be processed and translated into `copy`s when we
+ // get to process them via workList.
+ for (auto bb : loweredFunc->getBlocks())
+ {
+ auto retInst = as<IRReturnVal>(bb->getTerminator());
+ if (!retInst)
+ continue;
+ if (!retValParam)
+ {
+ // Now we have the return type, emit the returnVal parameter.
+ // The type of this parameter is still not translated to RawPointer yet,
+ // and will be processed together with all the other existing parameters.
+ retValParam = builder.emitParamAtHead(
+ builder.getOutType(retInst->getVal()->getDataType()));
+ }
+ builder.setInsertBefore(retInst);
+ builder.emitStore(retValParam, retInst->getVal());
+ builder.emitReturn();
+ retInst->removeAndDeallocate();
+ }
+ }
+
auto block = as<IRBlock>(loweredFunc->getFirstChild());
for (auto param : clonedParams)
{
@@ -139,7 +183,10 @@ namespace Slang
param = param->getNextInst())
{
// Generic typed parameters have a type that is a param itself.
- if (auto rttiParam = as<IRParam>(param->getDataType()))
+ auto paramType = param->getDataType();
+ if (auto ptrType = as<IRPtrTypeBase>(paramType))
+ paramType = ptrType->getValueType();
+ if (auto rttiParam = as<IRParam>(paramType))
{
SLANG_ASSERT(isPointerOfType(rttiParam->getDataType(), kIROp_RTTIType));
// Lower into a function parameter of raw pointer type.
@@ -189,6 +236,14 @@ namespace Slang
{
auto loweredParamType = lowerParameterType(builder, paramType);
translated = translated || (loweredParamType != paramType);
+ if (translated && i == 0)
+ {
+ // We are translating the return value, this means that
+ // the return value must be passed explicitly via an `out` parameter.
+ // In this case, the new return value will be `void`, and the
+ // translated return value type will be the first parameter type;
+ newOperands.add(builder->getVoidType());
+ }
newOperands.add(loweredParamType);
}
}
@@ -382,110 +437,146 @@ namespace Slang
return result;
}
- void processInst(IRInst* inst)
+ void lowerCall(IRCall* callInst)
{
- if (auto callInst = as<IRCall>(inst))
+ // If we see a call(specialize(gFunc, Targs), args),
+ // translate it into call(gFunc, args, Targs).
+ auto funcOperand = callInst->getOperand(0);
+ IRInst* loweredFunc = nullptr;
+ auto specializeInst = as<IRSpecialize>(funcOperand);
+ if (!specializeInst)
+ return;
+
+ auto funcToSpecialize = specializeInst->getOperand(0);
+ List<IRType*> paramTypes;
+ IRFuncType* funcType = nullptr;
+ if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
{
- // If we see a call(specialize(gFunc, Targs), args),
- // translate it into call(gFunc, args, Targs).
- auto funcOperand = callInst->getOperand(0);
- IRInst* loweredFunc = nullptr;
- if (auto specializeInst = as<IRSpecialize>(funcOperand))
+ // The callee is a result of witness table lookup, we will only
+ // translate the call.
+ IRInst* callee = nullptr;
+ auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getDataType());
+ auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
{
- auto funcToSpecialize = specializeInst->getOperand(0);
- List<IRType*> paramTypes;
- if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
- {
- // The callee is a result of witness table lookup, we will only
- // translate the call.
- IRInst* callee = nullptr;
- auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getDataType());
- auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
- for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
- {
- auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
- if (entry->getRequirementKey() == interfaceLookup->getOperand(1))
- {
- callee = entry->getRequirementVal();
- break;
- }
- }
- auto funcType = cast<IRFuncType>(callee);
- for (UInt i = 0; i < funcType->getParamCount(); i++)
- paramTypes.add(funcType->getParamType(i));
- loweredFunc = funcToSpecialize;
- }
- else
+ auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
+ if (entry->getRequirementKey() == interfaceLookup->getOperand(1))
{
- loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
- if (loweredFunc == specializeInst->getOperand(0))
- {
- // This is an intrinsic function, don't transform.
- return;
- }
- for (auto param : as<IRFunc>(loweredFunc)->getParams())
- paramTypes.add(param->getDataType());
+ callee = entry->getRequirementVal();
+ break;
}
+ }
+ funcType = cast<IRFuncType>(callee);
+ loweredFunc = funcToSpecialize;
+ }
+ else
+ {
+ loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
+ if (loweredFunc == specializeInst->getOperand(0))
+ {
+ // This is an intrinsic function, don't transform.
+ return;
+ }
+ funcType = cast<IRFuncType>(loweredFunc->getDataType());
+ }
- IRBuilder builderStorage;
- auto builder = &builderStorage;
- builder->sharedBuilder = &sharedBuilderStorage;
- builder->setInsertBefore(inst);
- List<IRInst*> args;
- auto rawPtrType = builder->getRawPointerType();
- for (UInt i = 0; i < callInst->getArgCount(); i++)
- {
- auto arg = callInst->getArg(i);
- if (as<IRRawPointerType>(paramTypes[i]) &&
- !as<IRRawPointerType>(arg->getDataType()))
- {
- // We are calling a generic function that with an argument of
- // concrete type. We need to convert this argument to void*.
-
- // Ideally this should just be a GetElementAddress inst.
- // However the current code emitting logic for this instruction
- // doesn't truly respect the pointerness and does not produce
- // what we needed. For now we use another instruction here
- // to keep changes minimal.
- arg = builder->emitGetAddress(
- rawPtrType,
- arg);
- }
- args.add(arg);
- }
- for (UInt i = 0; i < specializeInst->getArgCount(); i++)
- {
- auto arg = specializeInst->getArg(i);
- // Translate Type arguments into RTTI object.
- if (as<IRType>(arg))
- {
- // We are using a simple type to specialize a callee.
- // Generate RTTI for this type.
- auto rttiObject = maybeEmitRTTIObject(arg);
- arg = builder->emitGetAddress(
- builder->getPtrType(builder->getRTTIType()),
- rttiObject);
- }
- else if (arg->op == kIROp_Specialize)
- {
- // The type argument used to specialize a callee is itself a
- // specialization of some generic type.
- // TODO: generate RTTI object for specializations of generic types.
- SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
- }
- else if (arg->op == kIROp_RTTIObject)
- {
- // We are inside a generic function and using a generic parameter
- // to specialize another callee. The generic parameter of the caller
- // has already been translated into an RTTI object, so we just need
- // to pass this object down.
- }
- args.add(arg);
- }
- auto newCall = builder->emitCallInst(callInst->getFullType(), loweredFunc, args);
- callInst->replaceUsesWith(newCall);
- callInst->removeAndDeallocate();
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ paramTypes.add(funcType->getParamType(i));
+
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedBuilderStorage;
+ builder->setInsertBefore(callInst);
+
+ List<IRInst*> args;
+
+ // Indicates whether the caller should allocate space for return value.
+ // If the lowered callee returns void and this call inst has a type that is not void,
+ // then we are calling a transformed function that expects caller allocated return value
+ // as the first argument.
+ bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType &&
+ callInst->getDataType() != funcType->getResultType());
+
+ IRVar* retVarInst = nullptr;
+ int startParamIndex = 0;
+ if (shouldCallerAllocateReturnValue)
+ {
+ // Declare a var for the return value.
+ retVarInst = builder->emitVar(callInst->getFullType());
+ args.add(retVarInst);
+ startParamIndex = 1;
+ }
+
+ for (UInt i = 0; i < callInst->getArgCount(); i++)
+ {
+ auto arg = callInst->getArg(i);
+ if (as<IRRawPointerType>(paramTypes[i] + startParamIndex) &&
+ !as<IRRawPointerType>(arg->getDataType()))
+ {
+ // We are calling a generic function that with an argument of
+ // concrete type. We need to convert this argument to void*.
+
+ // Ideally this should just be a GetElementAddress inst.
+ // However the current code emitting logic for this instruction
+ // doesn't truly respect the pointerness and does not produce
+ // what we needed. For now we use another instruction here
+ // to keep changes minimal.
+ arg = builder->emitGetAddress(
+ builder->getRawPointerType(),
+ arg);
+ }
+ args.add(arg);
+ }
+ for (UInt i = 0; i < specializeInst->getArgCount(); i++)
+ {
+ auto arg = specializeInst->getArg(i);
+ // Translate Type arguments into RTTI object.
+ if (as<IRType>(arg))
+ {
+ // We are using a simple type to specialize a callee.
+ // Generate RTTI for this type.
+ auto rttiObject = maybeEmitRTTIObject(arg);
+ arg = builder->emitGetAddress(
+ builder->getPtrType(builder->getRTTIType()),
+ rttiObject);
+ }
+ else if (arg->op == kIROp_Specialize)
+ {
+ // The type argument used to specialize a callee is itself a
+ // specialization of some generic type.
+ // TODO: generate RTTI object for specializations of generic types.
+ SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
}
+ else if (arg->op == kIROp_RTTIObject)
+ {
+ // We are inside a generic function and using a generic parameter
+ // to specialize another callee. The generic parameter of the caller
+ // has already been translated into an RTTI object, so we just need
+ // to pass this object down.
+ }
+ args.add(arg);
+ }
+ auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType();
+ auto newCall = builder->emitCallInst(callInstType, loweredFunc, args);
+ if (retVarInst)
+ {
+ auto loadInst = builder->emitLoad(retVarInst);
+ callInst->replaceUsesWith(loadInst);
+ addToWorkList(loadInst);
+ addToWorkList(retVarInst);
+ }
+ else
+ {
+ callInst->replaceUsesWith(newCall);
+ }
+ callInst->removeAndDeallocate();
+ }
+
+ void processInst(IRInst* inst)
+ {
+ if (auto callInst = as<IRCall>(inst))
+ {
+ lowerCall(callInst);
}
else if (auto witnessTable = as<IRWitnessTable>(inst))
{