summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-cpp.cpp148
-rw-r--r--source/slang/slang-emit-cpp.h13
-rw-r--r--source/slang/slang-ir-inst-defs.h11
-rw-r--r--source/slang/slang-ir-insts.h23
-rw-r--r--source/slang/slang-ir-lower-generics.cpp31
-rw-r--r--source/slang/slang-ir.cpp14
-rw-r--r--source/slang/slang-lower-to-ir.cpp22
-rw-r--r--tests/compute/dynamic-dispatch-1.slang38
-rw-r--r--tests/compute/dynamic-dispatch-1.slang.expected.txt4
9 files changed, 294 insertions, 10 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 032f04ff3..accb290fa 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1446,6 +1446,24 @@ UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp)
return m_slicePool.getSlice(handle);
}
+UnownedStringSlice CPPSourceEmitter::_getWitnessTableWrapperFuncName(IRFunc* func)
+{
+ StringSlicePool::Handle handle = StringSlicePool::kNullHandle;
+ if (m_witnessTableWrapperFuncNameMap.TryGetValue(func, handle))
+ {
+ return m_slicePool.getSlice(handle);
+ }
+
+ StringBuilder builder;
+ builder << getName(func) << "_wtwrapper";
+
+ handle = m_slicePool.add(builder);
+ m_witnessTableWrapperFuncNameMap.Add(func, handle);
+
+ SLANG_ASSERT(handle != StringSlicePool::kNullHandle);
+ return m_slicePool.getSlice(handle);
+}
+
SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder)
{
typedef HLSLIntrinsic::Op Op;
@@ -1591,6 +1609,83 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
pendingWitnessTableDefinitions.add(witnessTable);
}
+void CPPSourceEmitter::_emitWitnessTableWrappers()
+{
+ for (auto witnessTable : pendingWitnessTableDefinitions)
+ {
+ for (auto child : witnessTable->getChildren())
+ {
+ if (auto entry = as<IRWitnessTableEntry>(child))
+ {
+ if (auto funcVal = as<IRFunc>(entry->getSatisfyingVal()))
+ {
+ emitType(funcVal->getResultType());
+ m_writer->emit(" ");
+ m_writer->emit(_getWitnessTableWrapperFuncName(funcVal));
+ m_writer->emit("(");
+ // Emit parameter list.
+ {
+ bool isFirst = true;
+ for (auto param : funcVal->getParams())
+ {
+ if (as<IRTypeType>(param->getFullType()))
+ continue;
+
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(",");
+
+ if (param->findDecoration<IRThisPointerDecoration>())
+ {
+ m_writer->emit("void* ");
+ m_writer->emit(getName(param));
+ continue;
+ }
+ emitSimpleFuncParamImpl(param);
+ }
+ }
+ m_writer->emit(")\n{\n");
+ m_writer->indent();
+ m_writer->emit("return ");
+ m_writer->emit(getName(funcVal));
+ m_writer->emit("(");
+ // Emit argument list.
+ {
+ bool isFirst = true;
+ for (auto param : funcVal->getParams())
+ {
+ if (as<IRTypeType>(param->getFullType()))
+ continue;
+
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
+
+ if (param->findDecoration<IRThisPointerDecoration>())
+ {
+ m_writer->emit("*static_cast<");
+ emitType(param->getFullType());
+ m_writer->emit("*>(");
+ m_writer->emit(getName(param));
+ m_writer->emit(")");
+ }
+ else
+ {
+ m_writer->emit(getName(param));
+ }
+ }
+ }
+ m_writer->emit(");\n");
+ m_writer->dedent();
+ m_writer->emit("}\n");
+ }
+ }
+ }
+ }
+}
+
void CPPSourceEmitter::_emitWitnessTableDefinitions()
{
for (auto witnessTable : pendingWitnessTableDefinitions)
@@ -1612,8 +1707,9 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
m_writer->emit(",\n");
else
isFirstEntry = false;
+
m_writer->emit("&KernelContext::");
- m_writer->emit(getName(funcVal));
+ m_writer->emit(_getWitnessTableWrapperFuncName(funcVal));
}
else
{
@@ -1671,7 +1767,13 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
m_writer->emit(", ");
else
isFirstParam = false;
- emitParamType(param->getFullType(), getName(param));
+ if (param->findDecoration<IRThisPointerDecoration>())
+ {
+ m_writer->emit("void* ");
+ m_writer->emit(getName(param));
+ continue;
+ }
+ emitSimpleFuncParamImpl(param);
}
m_writer->emit(");\n");
}
@@ -1681,7 +1783,7 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
}
}
m_writer->dedent();
- m_writer->emit("\n};\n");
+ m_writer->emit("};\n");
}
bool CPPSourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType)
@@ -1877,6 +1979,31 @@ void CPPSourceEmitter::emitSimpleValueImpl(IRInst* inst)
}
}
+static bool isVoidPtrType(IRType* type)
+{
+ auto ptrType = as<IRPtrType>(type);
+ if (!ptrType) return false;
+ return ptrType->getValueType()->op == kIROp_VoidType;
+}
+
+void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
+{
+ // Polymorphic types are already translated to void* type in
+ // lower-generics pass. However, the current emitting logic will
+ // emit "void&" instead of "void*" for pointer types.
+ // In the future, we will handle pointer types more properly,
+ // and this override logic will not be necessary.
+ // For now we special-case this scenario.
+ if (param->findDecoration<IRPolymorphicDecoration>() &&
+ isVoidPtrType(param->getDataType()))
+ {
+ m_writer->emit("void* ");
+ m_writer->emit(getName(param));
+ return;
+ }
+ CLikeSourceEmitter::emitSimpleFuncParamImpl(param);
+}
+
void CPPSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
{
emitSimpleType(m_typeSet.addVectorType(elementType, int(elementCount)));
@@ -2102,6 +2229,16 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
m_writer->emit(")");
return true;
}
+ case kIROp_getAddr:
+ {
+ // Once we clean up the pointer emitting logic, we can
+ // just use GetElementAddress instruction in place of
+ // getAddr instruction, and this case can be removed.
+ m_writer->emit("(&(");
+ emitInstExpr(inst->getOperand(0), EmitOpInfo::get(EmitOp::General));
+ m_writer->emit("))");
+ return true;
+ }
}
}
@@ -2670,6 +2807,11 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
}
}
+ // Emit wrapper functions for each witness table entry.
+ // These wrapper functions takes an abstract type parameter (void*)
+ // in the place of `this` parameter.
+ _emitWitnessTableWrappers();
+
m_writer->dedent();
m_writer->emit("};\n\n");
}
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index b8afc6a76..47ba03d70 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -71,6 +71,7 @@ protected:
virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
virtual void emitPreprocessorDirectivesImpl() SLANG_OVERRIDE;
virtual void emitSimpleValueImpl(IRInst* value) SLANG_OVERRIDE;
+ virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
virtual void emitModuleImpl(IRModule* module) SLANG_OVERRIDE;
virtual void emitSimpleFuncImpl(IRFunc* func) SLANG_OVERRIDE;
virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec) SLANG_OVERRIDE;
@@ -117,6 +118,10 @@ protected:
UnownedStringSlice _getFuncName(const HLSLIntrinsic* specOp);
+ // Returns a StringSlice representing the mangled name of a witness table
+ // wrapper function.
+ UnownedStringSlice _getWitnessTableWrapperFuncName(IRFunc* func);
+
UnownedStringSlice _getTypeName(IRType* type);
SlangResult _calcCPPTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName);
@@ -134,12 +139,20 @@ protected:
// of all the witness table objects in `pendingWitnessTableDefinitions`.
void _emitWitnessTableDefinitions();
+ // Emit wrapper functions that are referenced in witness tables.
+ // Wrapper functions wraps the actual member function, and takes a `void*`
+ // as the `this` parameter instead of the actual object type, so that
+ // their signature is agnostic to the object type.
+ void _emitWitnessTableWrappers();
+
HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount);
static bool _isVariable(IROp op);
Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap;
Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap;
+ Dictionary<IRFunc*, StringSlicePool::Handle> m_witnessTableWrapperFuncNameMap;
+
IRTypeSet m_typeSet;
RefPtr<HLSLIntrinsicOpLookup> m_opLookup;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f5127d0fa..58ff1a79f 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -236,6 +236,7 @@ INST(FieldAddress, get_field_addr, 2, 0)
INST(getElement, getElement, 2, 0)
INST(getElementPtr, getElementPtr, 2, 0)
+INST(getAddr, getAddr, 1, 0)
// "Subscript" an image at a pixel coordinate to get pointer
INST(ImageSubscript, imageSubscript, 2, 0)
@@ -506,6 +507,16 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0)
+ /// A `[polymorphic]` decoration marks a function parameter that should translate to an abstract type
+ /// e.g. (void*) that are casted to actual type before use. For example, a parameter of generic type
+ /// is marked `[polymorphic]`, so that the code gen logic can emit it as a `void*` parameter,
+ /// allowing the function to be used at sites that are agnostic of the actual object type.
+ INST(PolymorphicDecoration, polymorphic, 0, 0)
+
+ /// A `[this_ptr]` decoration marks a function parameter that serves as `this` pointer.
+ INST(ThisPointerDecoration, this_ptr, 0, 0)
+
+
/// A `[format(f)]` decoration specifies that the format of an image should be `f`
INST(FormatDecoration, format, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 3aab4c323..b13d52981 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -166,6 +166,10 @@ IR_SIMPLE_DECORATION(VulkanCallablePayloadDecoration)
/// to it.
IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration)
+IR_SIMPLE_DECORATION(PolymorphicDecoration)
+IR_SIMPLE_DECORATION(ThisPointerDecoration)
+
+
struct IRRequireGLSLVersionDecoration : IRDecoration
{
enum { kOp = kIROp_RequireGLSLVersionDecoration };
@@ -1145,6 +1149,11 @@ struct IRFieldAddress : IRInst
IRInst* getField() { return field.get(); }
};
+struct IRGetAddress : IRInst
+{
+ IR_LEAF_ISA(getAddr);
+};
+
// Terminators
struct IRReturn : IRTerminatorInst
@@ -1894,6 +1903,10 @@ struct IRBuilder
IRInst* basePtr,
IRInst* index);
+ IRInst* emitGetAddress(
+ IRType* type,
+ IRInst* value);
+
IRInst* emitSwizzle(
IRType* type,
IRInst* base,
@@ -2147,6 +2160,16 @@ struct IRBuilder
addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode)));
}
+ void addPolymorphicDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_PolymorphicDecoration);
+ }
+
+ void addThisPointerDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_ThisPointerDecoration);
+ }
+
void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
{
addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index));
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 4378d396f..f6340a633 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -79,6 +79,14 @@ namespace Slang
block->addParam(as<IRParam>(param));
}
loweredGenericFunctions[genericValue] = loweredFunc;
+ // Turn generic parameters into void pointers.
+ for (auto param : cast<IRFunc>(loweredFunc)->getParams())
+ {
+ if (param->findDecoration<IRPolymorphicDecoration>())
+ {
+ param->setFullType(builder.getPtrType(builder.getVoidType()));
+ }
+ }
addToWorkList(loweredFunc);
return loweredFunc;
}
@@ -103,8 +111,29 @@ namespace Slang
builder->sharedBuilder = &sharedBuilderStorage;
builder->setInsertBefore(inst);
List<IRInst*> args;
+ auto pp = as<IRFunc>(loweredFunc)->getParams().begin();
+ auto voidPtrType = builder->getPtrType(builder->getVoidType());
for (UInt i = 0; i < callInst->getArgCount(); i++)
- args.add(callInst->getArg(i));
+ {
+ auto arg = callInst->getArg(i);
+ if ((*pp)->getDataType() == voidPtrType &&
+ arg->getDataType() != voidPtrType)
+ {
+ // We are calling a generic function that with an argument of
+ // concrete type. We need to convert this argument o void*.
+
+ // Ideally this should just be a GetElementAddress inst.
+ // However the current code emitting logic for this instruction
+ // doesn't truly respect the pointerness and does not produce
+ // what we needed. For now we use another instruction here
+ // to keep changes minimal.
+ arg = builder->emitGetAddress(
+ voidPtrType,
+ arg);
+ }
+ args.add(arg);
+ ++pp;
+ }
for (UInt i = 0; i < specializeInst->getArgCount(); i++)
args.add(specializeInst->getArg(i));
auto newCall = builder->emitCallInst(callInst->getFullType(), loweredFunc, args);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 34ea23b85..77011b569 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3083,6 +3083,20 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitGetAddress(
+ IRType* type,
+ IRInst* value)
+ {
+ auto inst = createInst<IRGetAddress>(
+ this,
+ kIROp_getAddr,
+ type,
+ value);
+
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitSwizzle(
IRType* type,
IRInst* base,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 01bd0e972..ea04ea85c 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6157,6 +6157,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo paramVal;
+ IRParam* irParam = nullptr;
+
switch( paramInfo.direction )
{
default:
@@ -6166,15 +6168,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- IRParam* irParamPtr = subBuilder->emitParam(irParamType);
+ irParam = subBuilder->emitParam(irParamType);
if(auto paramDecl = paramInfo.decl)
{
- addVarDecorations(context, irParamPtr, paramDecl);
- subBuilder->addHighLevelDeclDecoration(irParamPtr, paramDecl);
+ addVarDecorations(context, irParam, paramDecl);
+ subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
}
- addParamNameHint(irParamPtr, paramInfo);
+ addParamNameHint(irParam, paramInfo);
- paramVal = LoweredValInfo::ptr(irParamPtr);
+ paramVal = LoweredValInfo::ptr(irParam);
// TODO: We might want to copy the pointed-to value into
// a temporary at the start of the function, and then copy
@@ -6194,7 +6196,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// We start by declaring an IR parameter of the same type.
//
auto paramDecl = paramInfo.decl;
- IRParam* irParam = subBuilder->emitParam(irParamType);
+ irParam = subBuilder->emitParam(irParamType);
if( paramDecl )
{
addVarDecorations(context, irParam, paramDecl);
@@ -6249,6 +6251,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (paramInfo.isThisParam)
{
subContext->thisVal = paramVal;
+ subBuilder->addThisPointerDecoration(irParam);
+ }
+
+ // Add a [polymorphic] decoration for generic-typed parameters.
+ if (as<IRParam>(irParamType) &&
+ as<IRTypeType>(irParamType->getFullType()))
+ {
+ subBuilder->addPolymorphicDecoration(irParam);
}
}
diff --git a/tests/compute/dynamic-dispatch-1.slang b/tests/compute/dynamic-dispatch-1.slang
new file mode 100644
index 000000000..9e63ee124
--- /dev/null
+++ b/tests/compute/dynamic-dispatch-1.slang
@@ -0,0 +1,38 @@
+//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
+
+// Test dynamic dispatch code gen for non-static member functions.
+
+interface IInterface
+{
+ int Compute(int inVal);
+};
+
+int GenericCompute<T:IInterface>(T obj, int inVal)
+{
+ return obj.Compute(inVal);
+}
+
+struct Impl : IInterface
+{
+ int base;
+ int Compute(int inVal) { return base + inVal * inVal; }
+};
+
+int test(int inVal)
+{
+ Impl obj;
+ obj.base = 1;
+ return GenericCompute<Impl>(obj, inVal);
+}
+
+//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer : register(u0);
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = outputBuffer[tid];
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/compute/dynamic-dispatch-1.slang.expected.txt b/tests/compute/dynamic-dispatch-1.slang.expected.txt
new file mode 100644
index 000000000..146ab3c8c
--- /dev/null
+++ b/tests/compute/dynamic-dispatch-1.slang.expected.txt
@@ -0,0 +1,4 @@
+1
+2
+5
+A