summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-10-18 11:08:47 -0700
committerGitHub <noreply@github.com>2017-10-18 11:08:47 -0700
commita12480fe49d5ba7c0a9c2ac63363dc76b599ddbd (patch)
tree6344dec2a432f2f5cb1cdd981d0aafe21ea6a443
parentf12c2552b3f494cbc8245edb90b32b93ca8a1539 (diff)
Work on IR-based cross-compilation (#222)
There are two big changes here: - Add logic during the initial IR cloning pass for an entry point + target that tries to pick the best possible version of any target-overloaded function. This allows us to pick the intrinsic version of `saturate()` when compiling for HLSL output, but then pick the non-intrinsic version (that is implemented in terms of `clamp()`) when targetting GLSL. - Add an initial specialization pass that tries to deal with generics. This required some fixing work to IR generation, so that we correctly generate explicit operations to specialize a generic for specific types (this is currently implemented as a `specialize` instruction that takes the generic to specialize plus a declaration-reference that represents the specialized form). With that work in place, we can scan for `specialize` instructions inside of non-generic functions, and use them to trigger generation of specialized code. We rely on the name-mangling scheme to help us find pre-existing specializations when possible. There are also a bunch of cleanups encountered along the way: - Don't use the explicit `layout(offset=...)` for uniforms, because it isn't supported by all current drivers. For now we will just assume that our layout rules compute the same values that the driver would for un-marked-up code. We can come back later and try to implement a workaround in the cases where this doesn't apply (e.g., by re-running the layout logic as part of emission, and dropping layout modifiers from variables that don't need explicit layout). - Fix some issues in IR dump printing so that we print function declarations more nicely. - Testing: print out failing pixel when image-diff fails
-rw-r--r--source/slang/emit.cpp72
-rw-r--r--source/slang/hlsl.meta.slang6
-rw-r--r--source/slang/hlsl.meta.slang.h6
-rw-r--r--source/slang/ir-inst-defs.h2
-rw-r--r--source/slang/ir-insts.h35
-rw-r--r--source/slang/ir.cpp761
-rw-r--r--source/slang/ir.h13
-rw-r--r--source/slang/lower-to-ir.cpp184
-rw-r--r--source/slang/lower.cpp1
-rw-r--r--source/slang/mangle.cpp231
-rw-r--r--source/slang/mangle.h5
-rw-r--r--source/slang/syntax.cpp46
-rw-r--r--source/slang/type-defs.h1
-rw-r--r--tests/ir/loop.slang.expected2
-rw-r--r--tools/slang-test/main.cpp10
15 files changed, 1134 insertions, 241 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index f9c0b5f09..d8d9e13ed 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -3574,17 +3574,34 @@ struct EmitVisitor
switch(info.kind)
{
case LayoutResourceKind::Uniform:
- // Explicit offsets require a GLSL extension.
- //
- // TODO: We really need to fix this so that we
- // only output an explicit offset for things
- // that are layed out differently than they
- // would normally be...
- requireGLSLExtension("GL_ARB_enhanced_layouts");
+ {
+ // Explicit offsets require a GLSL extension (which
+ // is not universally supported, it seems) or a new
+ // enough GLSL version (which we don't want to
+ // universall require), so for right now we
+ // won't actually output explicit offsets for uniform
+ // shader parameters.
+ //
+ // TODO: We should fix this so that we skip any
+ // extra work for parameters that are laid out as
+ // expected by the default rules, but do *something*
+ // for parameters that need non-default layout.
+ //
+ // Using the `GL_ARB_enhanced_layouts` feature is one
+ // option, but we should also be able to do some
+ // things by introducing padding into the declaration
+ // (padding insertion would probably be best done at
+ // the IR level).
+ bool useExplicitOffsets = false;
+ if (useExplicitOffsets)
+ {
+ requireGLSLExtension("GL_ARB_enhanced_layouts");
- Emit("layout(offset = ");
- Emit(info.index);
- Emit(")\n");
+ Emit("layout(offset = ");
+ Emit(info.index);
+ Emit(")\n");
+ }
+ }
break;
case LayoutResourceKind::VertexInput:
@@ -4073,7 +4090,11 @@ emitDeclImpl(decl, nullptr);
{
case kIROp_global_var:
case kIROp_Func:
- return ((IRGlobalValue*)inst)->mangledName;
+ {
+ auto& mangledName = ((IRGlobalValue*)inst)->mangledName;
+ if(mangledName.Length() != 0)
+ return mangledName;
+ }
break;
default:
@@ -4396,6 +4417,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_boolConst:
case kIROp_FieldAddress:
case kIROp_getElementPtr:
+ case kIROp_specialize:
return true;
}
@@ -4937,6 +4959,12 @@ emitDeclImpl(decl, nullptr);
}
break;
+ case kIROp_specialize:
+ {
+ emitIROperand(context, inst->getArg(0));
+ }
+ break;
+
default:
emit("/* unhandled */");
break;
@@ -5579,6 +5607,11 @@ emitDeclImpl(decl, nullptr);
if(!value)
return nullptr;
+ if(value->op == kIROp_specialize)
+ {
+ value = ((IRSpecialize*) value)->genericVal.usedValue;
+ }
+
if(value->op != kIROp_Func)
return nullptr;
@@ -5608,6 +5641,14 @@ emitDeclImpl(decl, nullptr);
}
else
#endif
+ if(func->genericDecl)
+ {
+ Emit("/* ");
+ emitIRFuncDecl(context, func);
+ Emit(" */");
+ return;
+ }
+
if(!isDefinition(func))
{
// This is just a function declaration,
@@ -6339,6 +6380,13 @@ String emitEntryPoint(
// TODO: we should apply some guaranteed transformations here,
// to eliminate constructs that aren't legal downstream (e.g. generics).
+
+ specializeGenerics(lowered);
+
+// fprintf(stderr, "###\n");
+// dumpIR(lowered);
+// fprintf(stderr, "###\n");
+
//
// TODO: Need to decide whether to do these before or after
// target-specific legalization steps. Currently I've folded
@@ -6348,6 +6396,8 @@ String emitEntryPoint(
// IR back into AST for emission?
visitor.emitIRModule(&context, lowered);
+
+ // TODO: need to clean up the IR module here
}
else if(!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING ) ||
translationUnit->compileRequest->loadedModulesList.Count() != 0)
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 81e9931e8..dc1d4d8e8 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -326,9 +326,9 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic_
__intrinsic_op bool CheckAccessFullyMapped(uint status);
// Clamp (HLSL SM 1.0)
-__generic<T : __BuiltinArithmeticType> __intrinsic_op T clamp(T x, T min, T max);
-__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic_op vector<T,N> clamp(vector<T,N> x, vector<T,N> min, vector<T,N> max);
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic_op matrix<T,N,M> clamp(matrix<T,N,M> x, matrix<T,N,M> min, matrix<T,N,M> max);
+__generic<T : __BuiltinArithmeticType> T clamp(T x, T min, T max);
+__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> clamp(vector<T,N> x, vector<T,N> min, vector<T,N> max);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> clamp(matrix<T,N,M> x, matrix<T,N,M> min, matrix<T,N,M> max);
// Clip (discard) fragment conditionally
__generic<T : __BuiltinFloatingPointType> __intrinsic_op void clip(T x);
diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h
index eccb12f8d..dfbdbe57b 100644
--- a/source/slang/hlsl.meta.slang.h
+++ b/source/slang/hlsl.meta.slang.h
@@ -328,9 +328,9 @@ sb << "// Check access status to tiled resource\n";
sb << "__intrinsic_op bool CheckAccessFullyMapped(uint status);\n";
sb << "\n";
sb << "// Clamp (HLSL SM 1.0)\n";
-sb << "__generic<T : __BuiltinArithmeticType> __intrinsic_op T clamp(T x, T min, T max);\n";
-sb << "__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic_op vector<T,N> clamp(vector<T,N> x, vector<T,N> min, vector<T,N> max);\n";
-sb << "__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic_op matrix<T,N,M> clamp(matrix<T,N,M> x, matrix<T,N,M> min, matrix<T,N,M> max);\n";
+sb << "__generic<T : __BuiltinArithmeticType> T clamp(T x, T min, T max);\n";
+sb << "__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> clamp(vector<T,N> x, vector<T,N> min, vector<T,N> max);\n";
+sb << "__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> clamp(matrix<T,N,M> x, matrix<T,N,M> min, matrix<T,N,M> max);\n";
sb << "\n";
sb << "// Clip (discard) fragment conditionally\n";
sb << "__generic<T : __BuiltinFloatingPointType> __intrinsic_op void clip(T x);\n";
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index c11d66571..636eeec16 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -96,6 +96,8 @@ INST(IntLit, integer_constant, 0, 0)
INST(FloatLit, float_constant, 0, 0)
INST(decl_ref, decl_ref, 0, 0)
+INST(specialize, specialize, 2, 0)
+
INST(Construct, construct, 0, 0)
INST(Call, call, 1, 0)
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 9ac79413f..50577b2a3 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -47,13 +47,31 @@ struct IRLoopControlDecoration : IRDecoration
IRLoopControl mode;
};
+struct IRTargetDecoration : IRDecoration
+{
+ enum { kDecorationOp = kIRDecorationOp_Target };
+
+ // TODO: have a more structured representation of target specifiers
+ String targetName;
+};
+
//
// An IR node to represent a reference to an AST-level
// declaration.
struct IRDeclRef : IRValue
{
- DeclRefBase declRef;
+ DeclRef<Decl> declRef;
+};
+
+// An instruction that specializes another IR value
+// (representing a generic) to a particular set of
+// generic arguments (encoded via an `IRDeclRef`)
+//
+struct IRSpecialize : IRInst
+{
+ IRUse genericVal;
+ IRUse specDeclRefVal;
};
//
@@ -304,6 +322,16 @@ struct IRBuilder
IRValue* getDeclRefVal(
DeclRefBase const& declRef);
+ IRValue* emitSpecializeInst(
+ IRType* type,
+ IRValue* genericVal,
+ IRValue* specDeclRef);
+
+ IRValue* emitSpecializeInst(
+ IRType* type,
+ IRValue* genericVal,
+ DeclRef<Decl> specDeclRef);
+
IRInst* emitCallInst(
IRType* type,
IRValue* func,
@@ -452,12 +480,15 @@ struct IRBuilder
// Generate a clone of an IR module that is specialized for
// a particular entry point, target, etc.
-
IRModule* specializeIRForEntryPoint(
EntryPointRequest* entryPointRequest,
ProgramLayout* programLayout,
CodeGenTarget target);
+// Find suitable uses of the `specialize` instruction that
+// can be replaced with references to specialized functions.
+void specializeGenerics(
+ IRModule* module);
}
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 31a35cd08..fb6013bc3 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -531,10 +531,41 @@ namespace Slang
this,
kIROp_decl_ref,
nullptr);
- irValue->declRef = declRef;
+ irValue->declRef = DeclRef<Decl>(declRef.decl, declRef.substitutions);
return irValue;
}
+ IRValue* IRBuilder::emitSpecializeInst(
+ Type* type,
+ IRValue* genericVal,
+ IRValue* specDeclRef)
+ {
+ auto inst = createInst<IRSpecialize>(
+ this,
+ kIROp_specialize,
+ type,
+ genericVal,
+ specDeclRef);
+ addInst(inst);
+ return inst;
+ }
+
+ IRValue* IRBuilder::emitSpecializeInst(
+ Type* type,
+ IRValue* genericVal,
+ DeclRef<Decl> specDeclRef)
+ {
+ auto specDeclRefVal = getDeclRefVal(specDeclRef);
+ auto inst = createInst<IRSpecialize>(
+ this,
+ kIROp_specialize,
+ type,
+ genericVal,
+ specDeclRefVal);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitCallInst(
IRType* type,
IRValue* func,
@@ -1352,15 +1383,26 @@ namespace Slang
auto parentDeclRef = declRef.GetParent();
auto genericParentDeclRef = parentDeclRef.As<GenericDecl>();
- if(genericParentDeclRef)
+ if (genericParentDeclRef)
{
- parentDeclRef = genericParentDeclRef.GetParent();
+ if (genericParentDeclRef.getDecl()->inner.Ptr() == decl)
+ {
+ parentDeclRef = genericParentDeclRef.GetParent();
+ }
+ else
+ {
+ genericParentDeclRef = DeclRef<GenericDecl>();
+ }
}
if(parentDeclRef.As<ModuleDecl>())
{
parentDeclRef = DeclRef<ContainerDecl>();
}
+ else if(parentDeclRef.As<GenericDecl>())
+ {
+ parentDeclRef = DeclRef<ContainerDecl>();
+ }
if(parentDeclRef)
{
@@ -1709,15 +1751,97 @@ namespace Slang
dump(context, "\n");
}
+ void dumpGenericSignature(
+ IRDumpContext* context,
+ GenericDecl* genericDecl)
+ {
+ for( auto pp = genericDecl->ParentDecl; pp; pp = pp->ParentDecl )
+ {
+ if( auto genericAncestor = dynamic_cast<GenericDecl*>(pp) )
+ {
+ dumpGenericSignature(context, genericAncestor);
+ break;
+ }
+ }
+
+ dump(context, " <");
+ bool first = true;
+ for (auto mm : genericDecl->Members)
+ {
+ if (!first) dump(context, ", ");
+
+ if( auto typeParamDecl = mm.As<GenericTypeParamDecl>() )
+ {
+ dumpDeclRef(context, makeDeclRef(typeParamDecl.Ptr()));
+ first = false;
+ }
+ else if( auto valueParamDecl = mm.As<GenericTypeParamDecl>() )
+ {
+ dumpDeclRef(context, makeDeclRef(valueParamDecl.Ptr()));
+ first = false;
+ }
+ }
+ first = true;
+ for (auto mm : genericDecl->Members)
+ {
+ if (!first) dump(context, ", ");
+ else dump(context, " where ");
+
+ if( auto constraintDecl = mm.As<GenericTypeConstraintDecl>() )
+ {
+ dumpType(context, constraintDecl->sub);
+ dump(context, " : ");
+ dumpType(context, constraintDecl->sup);
+ first = false;
+ }
+ }
+ dump(context, ">");
+ }
+
void dumpIRFunc(
IRDumpContext* context,
IRFunc* func)
{
+
+ for( auto dd = func->firstDecoration; dd; dd = dd->next )
+ {
+ switch( dd->op )
+ {
+ case kIRDecorationOp_Target:
+ {
+ auto decoration = (IRTargetDecoration*) dd;
+
+ dump(context, "\n");
+ dumpIndent(context);
+ dump(context, "[target(");
+ dump(context, decoration->targetName.Buffer());
+ dump(context, ")]");
+ }
+ break;
+
+ }
+ }
+
dump(context, "\n");
dumpIndent(context);
dump(context, "ir_func ");
dumpID(context, func);
+
+ if (func->genericDecl)
+ {
+ dump(context, " ");
+ dumpGenericSignature(context, func->genericDecl);
+ }
+
dumpInstTypeClause(context, func->getType());
+
+ if (!func->getFirstBlock())
+ {
+ // Just a declaration.
+ dump(context, ";\n");
+ return;
+ }
+
dump(context, "\n");
dumpIndent(context);
@@ -1941,11 +2065,30 @@ namespace Slang
parentBlock = nullptr;
}
+ void IRInst::removeArguments()
+ {
+ UInt argCount = this->argCount;
+ for( UInt aa = 0; aa < argCount; ++aa )
+ {
+ IRUse& use = getArgs()[aa];
+
+ if(!use.usedValue)
+ continue;
+
+ // Need to unlink this use from the appropriate linked list.
+ use.usedValue = nullptr;
+ *use.prevLink = use.nextUse;
+ use.prevLink = nullptr;
+ use.nextUse = nullptr;
+ }
+ }
+
// Remove this instruction from its parent block,
// and then destroy it (it had better have no uses!)
void IRInst::removeAndDeallocate()
{
removeFromParent();
+ removeArguments();
deallocate();
}
@@ -2211,6 +2354,7 @@ namespace Slang
// because it is no longer accurate.
auto voidFuncType = new FuncType();
+ voidFuncType->setSession(session);
voidFuncType->resultType = session->getVoidType();
func->type = voidFuncType;
@@ -2233,7 +2377,7 @@ namespace Slang
RefPtr<IRSpecSymbol> nextWithSameName;
};
- struct IRSpecContext
+ struct IRSharedSpecContext
{
// The specialized module we are building
IRModule* module;
@@ -2241,33 +2385,63 @@ namespace Slang
// The original, unspecialized module we are copying
IRModule* originalModule;
- // The IR builder to use for creating nodes
- IRBuilder* builder;
-
// A map from mangled symbol names to zero or
// more global IR values that have that name,
// in the *original* module.
- Dictionary<String, RefPtr<IRSpecSymbol>> symbols;
-
- // A map from the mangled name of a global variable
- // to the layout to use for it.
- Dictionary<String, VarLayout*> globalVarLayouts;
+ typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary;
+ SymbolDictionary symbols;
// A map from values in the original IR module
// to their equivalent in the cloned module.
- Dictionary<IRValue*, IRValue*> clonedValues;
+ typedef Dictionary<IRValue*, IRValue*> ClonedValueDictionary;
+ ClonedValueDictionary clonedValues;
+
+ SharedIRBuilder sharedBuilderStorage;
+ IRBuilder builderStorage;
+ };
+
+ struct IRSpecContextBase
+ {
+ IRSharedSpecContext* shared;
+
+ IRSharedSpecContext* getShared() { return shared; }
+
+ IRModule* getModule() { return getShared()->module; }
+
+ IRModule* getOriginalModule() { return getShared()->originalModule; }
+
+ IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; }
+
+ IRSharedSpecContext::ClonedValueDictionary& getClonedValues() { return getShared()->clonedValues; }
+
+ // The IR builder to use for creating nodes
+ IRBuilder* builder;
+
+ // A callback to be used when a value that is not registerd in `clonedValues`
+ // is needed during cloning. This gives the subtype a chance to intercept
+ // the operation and clone (or not) as needed.
+ virtual IRValue* maybeCloneValue(IRValue* originalVal)
+ {
+ return originalVal;
+ }
+
+ // A callback used to clone (or not) types.
+ virtual RefPtr<Type> maybeCloneType(Type* originalType)
+ {
+ return originalType;
+ }
};
void registerClonedValue(
- IRSpecContext* context,
+ IRSpecContextBase* context,
IRValue* clonedValue,
IRValue* originalValue)
{
- context->clonedValues.Add(originalValue, clonedValue);
+ context->getClonedValues().Add(originalValue, clonedValue);
}
void cloneDecorations(
- IRSpecContext* context,
+ IRSpecContextBase* context,
IRValue* clonedValue,
IRValue* originalValue)
{
@@ -2292,31 +2466,39 @@ namespace Slang
// TODO: implement this
}
+ struct IRSpecContext : IRSpecContextBase
+ {
+ // The code-generation target in use
+ CodeGenTarget target;
+
+ // A map from the mangled name of a global variable
+ // to the layout to use for it.
+ Dictionary<String, VarLayout*> globalVarLayouts;
+
+ // Override the "maybe clone" logic so that we always clone
+ virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
+ };
+
+
IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar);
IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc);
- IRValue* cloneValue(
- IRSpecContext* context,
- IRValue* originalValue)
+ IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
{
- IRValue* clonedValue = nullptr;
- if (context->clonedValues.TryGetValue(originalValue, clonedValue))
- return clonedValue;
-
switch (originalValue->op)
{
case kIROp_global_var:
- return cloneGlobalVar(context, (IRGlobalVar*)originalValue);
+ return cloneGlobalVar(this, (IRGlobalVar*)originalValue);
break;
case kIROp_Func:
- return cloneFunc(context, (IRFunc*)originalValue);
+ return cloneFunc(this, (IRFunc*)originalValue);
break;
case kIROp_boolConst:
{
IRConstant* c = (IRConstant*)originalValue;
- return context->builder->getBoolValue(c->u.intVal != 0);
+ return builder->getBoolValue(c->u.intVal != 0);
}
break;
@@ -2324,21 +2506,21 @@ namespace Slang
case kIROp_IntLit:
{
IRConstant* c = (IRConstant*)originalValue;
- return context->builder->getIntValue(c->type, c->u.intVal);
+ return builder->getIntValue(c->type, c->u.intVal);
}
break;
case kIROp_FloatLit:
{
IRConstant* c = (IRConstant*)originalValue;
- return context->builder->getFloatValue(c->type, c->u.floatVal);
+ return builder->getFloatValue(c->type, c->u.floatVal);
}
break;
case kIROp_decl_ref:
{
IRDeclRef* od = (IRDeclRef*)originalValue;
- return context->builder->getDeclRefVal(od->declRef);
+ return builder->getDeclRefVal(od->declRef);
}
break;
@@ -2348,8 +2530,19 @@ namespace Slang
}
}
+ IRValue* cloneValue(
+ IRSpecContextBase* context,
+ IRValue* originalValue)
+ {
+ IRValue* clonedValue = nullptr;
+ if (context->getClonedValues().TryGetValue(originalValue, clonedValue))
+ return clonedValue;
+
+ return context->maybeCloneValue(originalValue);
+ }
+
void cloneInst(
- IRSpecContext* context,
+ IRSpecContextBase* context,
IRBuilder* builder,
IRInst* originalInst)
{
@@ -2366,7 +2559,8 @@ namespace Slang
// it, and then add it to the sequence.
UInt argCount = originalInst->getArgCount();
IRInst* clonedInst = createInstWithTrailingArgs<IRInst>(
- builder, originalInst->op, originalInst->type,
+ builder, originalInst->op,
+ context->maybeCloneType(originalInst->type),
0, nullptr,
argCount, nullptr);
builder->addInst(clonedInst);
@@ -2410,14 +2604,14 @@ namespace Slang
}
void cloneFunctionCommon(
- IRSpecContext* context,
+ IRSpecContextBase* context,
IRFunc* clonedFunc,
IRFunc* originalFunc)
{
// First clone all the simple properties.
clonedFunc->mangledName = originalFunc->mangledName;
- clonedFunc->genericParams = originalFunc->genericParams;
- clonedFunc->type = originalFunc->type;
+ clonedFunc->genericDecl = originalFunc->genericDecl;
+ clonedFunc->type = context->maybeCloneType(originalFunc->type);
cloneDecorations(context, clonedFunc, originalFunc);
@@ -2445,7 +2639,9 @@ namespace Slang
originalParam;
originalParam = originalParam->getNextParam())
{
- IRParam* clonedParam = builder->emitParam(originalParam->getType());
+ IRParam* clonedParam = builder->emitParam(
+ context->maybeCloneType(
+ originalParam->getType()));
registerClonedValue(context, clonedParam, originalParam);
}
}
@@ -2475,7 +2671,7 @@ namespace Slang
//
// TODO: This isn't really a good requirement to place on the IR...
clonedFunc->removeFromParent();
- clonedFunc->insertAtEnd(context->module);
+ clonedFunc->insertAtEnd(context->getModule());
}
IRFunc* specializeIRForEntryPoint(
@@ -2486,7 +2682,7 @@ namespace Slang
// Look up the IR symbol by name
String mangledName = getMangledName(entryPointRequest->decl);
RefPtr<IRSpecSymbol> sym;
- if (!context->symbols.TryGetValue(mangledName, sym))
+ if (!context->getSymbols().TryGetValue(mangledName, sym))
{
SLANG_UNEXPECTED("no matching IR symbol");
return nullptr;
@@ -2534,23 +2730,224 @@ namespace Slang
return clonedFunc;
}
- // The case for functions that are not the entry point is
- // strictly simpler, so that is nice.
- IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc)
+ IRFunc* cloneSimpleFunc(IRSpecContextBase* context, IRFunc* originalFunc)
{
- // TODO: We really need to scan through all the various
- // global function symbols that have the same mangled name,
- // and pick the correct one to lower for the target.
-
auto clonedFunc = context->builder->createFunc();
registerClonedValue(context, clonedFunc, originalFunc);
cloneFunctionCommon(context, clonedFunc, originalFunc);
return clonedFunc;
}
+ // Get a string form of the target so that we can
+ // use it to match against target-specialization modifiers
+ //
+ // TODO: We shouldn't be using strings for this.
+ String getTargetName(IRSpecContext* context)
+ {
+ switch( context->target )
+ {
+ case CodeGenTarget::HLSL:
+ return "hlsl";
+
+ case CodeGenTarget::GLSL:
+ return "glsl";
+
+ default:
+ SLANG_UNEXPECTED("unhandled case");
+ return "unknown";
+ }
+ }
+
+ // How specialized is a given declaration for the chosen target?
+ enum class TargetSpecializationLevel
+ {
+ specializedForOtherTarget = 0,
+ notSpecialized,
+ specializedForTarget,
+ };
+
+ TargetSpecializationLevel getTargetSpecialiationLevel(
+ IRGlobalValue* val,
+ String const& targetName)
+ {
+ TargetSpecializationLevel result = TargetSpecializationLevel::notSpecialized;
+ for( auto dd = val->firstDecoration; dd; dd = dd->next )
+ {
+ if(dd->op != kIRDecorationOp_Target)
+ continue;
+
+ auto decoration = (IRTargetDecoration*) dd;
+ if(decoration->targetName == targetName)
+ return TargetSpecializationLevel::specializedForTarget;
+
+ result = TargetSpecializationLevel::specializedForOtherTarget;
+ }
+
+ return result;
+ }
+
+ // Is `newVal` marked as being a better match for our
+ // chosen code-generation target?
+ //
+ // TODO: there is a missing step here where we need
+ // to check if things are even available in the first place...
+ bool isBetterForTarget(
+ IRSpecContext* context,
+ IRGlobalValue* newVal,
+ IRGlobalValue* oldVal)
+ {
+ String targetName = getTargetName(context);
+
+ // For right now every declaration might have zero or more
+ // modifiers, representing the targets for which it is specialized.
+ // Each modifier has a single string "tag" to represent a target.
+ // We thus decide that a declaration is "more specialized" by:
+ //
+ // - Does it have a modifier with a tag with the string for the current target?
+ // If yes, it is the most specialized it can be.
+ //
+ // - Does it have a no tags? Then it is "unspecialized" and that is okay.
+ //
+ // - Does it have a modifier with a tag for a *different* target?
+ // If yes, then it shouldn't even be usable on this target.
+ //
+ // Longer term a better approach is to think of this in terms
+ // of a "disjunction of conjunctions" that is:
+ //
+ // (A and B and C) or (A and D) or (E) or (F and G) ...
+ //
+ // A code generation target would then consist of a
+ // conjunction of invidual tags:
+ //
+ // (HLSL and SM_4_0 and Vertex and ...)
+ //
+ // A declaration is *applicable* on a target if one of
+ // its conjunctions of tags is a subset of the target's.
+ //
+ // One declaration is *better* than another on a target
+ // if it is applicable and its tags are a superset
+ // of the other's.
+
+ auto newLevel = getTargetSpecialiationLevel(newVal, targetName);
+ auto oldLevel = getTargetSpecialiationLevel(oldVal, targetName);
+ return UInt(newLevel) > UInt(oldLevel);
+ }
+
+ IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc)
+ {
+ // We are being asked to clone a particular function, but in
+ // the IR that comes out of the front-end there could still
+ // be multiple, target-specific, declarations of any given
+ // function, all of which share the same mangled name.
+ auto mangledName = originalFunc->mangledName;
+
+ if(mangledName.Length() == 0)
+ {
+ return cloneSimpleFunc(context, originalFunc);
+ }
+
+ //
+ // We will scan through all of the available function declarations
+ // with the same mangled name as `originalFunc` and try
+ // to pick the "best" one for our target.
+
+ RefPtr<IRSpecSymbol> sym;
+ if( !context->getSymbols().TryGetValue(originalFunc->mangledName, sym) )
+ {
+ // This shouldn't happen!
+ SLANG_UNEXPECTED("no matching function registered");
+ return cloneSimpleFunc(context, originalFunc);
+ }
+
+ // We will try to track the "best" definition we can find.
+ IRFunc* bestFunc = (IRFunc*) sym->irGlobalValue;
+
+ for( auto ss = sym->nextWithSameName; ss; ss = ss->nextWithSameName )
+ {
+ IRFunc* newFunc = (IRFunc*) ss->irGlobalValue;
+ if(isBetterForTarget(context, newFunc, bestFunc))
+ bestFunc = newFunc;
+ }
+
+ // All right, we are now in a position to clone the "best"
+ // definition that was found.
+ auto clonedFunc = context->builder->createFunc();
+
+ // The resulting function will be used as the cloned version
+ // of every declaration/definition in the original IR.
+ for( auto ss = sym; ss; ss = ss->nextWithSameName )
+ {
+ registerClonedValue(context, clonedFunc, ss->irGlobalValue);
+ }
+
+ // Clone the "best" definition into our context
+ cloneFunctionCommon(context, clonedFunc, bestFunc);
+
+ return clonedFunc;
+ }
+
StructTypeLayout* getGlobalStructLayout(
ProgramLayout* programLayout);
+ void insertGlobalValueSymbol(
+ IRSharedSpecContext* sharedContext,
+ IRGlobalValue* gv)
+ {
+ String mangledName = gv->mangledName;
+
+ // Don't try to register a symbol for global values
+ // with no mangled name, since these represent symbols
+ // that shouldn't get "linkage"
+ if (mangledName == "")
+ return;
+
+ RefPtr<IRSpecSymbol> sym = new IRSpecSymbol();
+ sym->irGlobalValue = gv;
+
+ RefPtr<IRSpecSymbol> prev;
+ if (sharedContext->symbols.TryGetValue(mangledName, prev))
+ {
+ sym->nextWithSameName = prev->nextWithSameName;
+ prev->nextWithSameName = sym;
+ }
+ else
+ {
+ sharedContext->symbols.Add(mangledName, sym);
+ }
+ }
+
+ void initializeSharedSpecContext(
+ IRSharedSpecContext* sharedContext,
+ Session* session,
+ IRModule* module,
+ IRModule* originalModule)
+ {
+
+ SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = session;
+
+ IRBuilder* builder = &sharedContext->builderStorage;
+ builder->shared = sharedBuilder;
+
+ if( !module )
+ {
+ module = builder->createModule();
+ sharedBuilder->module = module;
+ }
+
+ sharedContext->module = module;
+ sharedContext->originalModule = originalModule;
+
+ // First, we will populate a map with all of the IR values
+ // that use the same mangled name, to make lookup easier
+ // in other steps.
+ for (auto gv = originalModule->firstGlobalValue; gv; gv = gv->nextGlobalValue)
+ {
+ insertGlobalValueSymbol(sharedContext, gv);
+ }
+ }
+
IRModule* specializeIRForEntryPoint(
EntryPointRequest* entryPointRequest,
ProgramLayout* programLayout,
@@ -2580,52 +2977,24 @@ namespace Slang
// we need to pick the "best" one for the chosen code generation target.
//
- SharedIRBuilder sharedBuilderStorage;
- SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
- sharedBuilder->module = nullptr;
- sharedBuilder->session = compileRequest->mSession;
-
- IRBuilder builderStorage;
- IRBuilder* builder = &builderStorage;
- builder->shared = sharedBuilder;
-
- IRModule* module = builder->createModule();
- sharedBuilder->module = module;
+ IRSharedSpecContext sharedContextStorage;
- //
+ initializeSharedSpecContext(
+ &sharedContextStorage,
+ compileRequest->mSession,
+ nullptr,
+ originalIRModule);
IRSpecContext contextStorage;
IRSpecContext* context = &contextStorage;
+ context->shared = &sharedContextStorage;
+ context->builder = &sharedContextStorage.builderStorage;
+ context->target = target;
- context->builder = builder;
- context->module = module;
- context->originalModule = originalIRModule;
-
- // First, we will populate a map with all of the IR values
- // that use the same mangled name, to make lookup easier
- // in other steps.
- for (auto gv = originalIRModule->firstGlobalValue; gv; gv = gv->nextGlobalValue)
- {
- String mangledName = gv->mangledName;
- if (mangledName == "")
- continue;
-
- RefPtr<IRSpecSymbol> sym = new IRSpecSymbol();
- sym->irGlobalValue = gv;
- RefPtr<IRSpecSymbol> prev;
- if (context->symbols.TryGetValue(mangledName, prev))
- {
- sym->nextWithSameName = prev->nextWithSameName;
- prev->nextWithSameName = sym;
- }
- else
- {
- context->symbols.Add(mangledName, sym);
- }
- }
-
- // Next, we want to optimize lookup over
+ // Next, we want to optimize lookup for layout infromation
+ // associated with global declarations, so that we can
+ // look things up based on the IR values (using mangled names)
auto globalStructLayout = getGlobalStructLayout(programLayout);
for (auto globalVarLayout : globalStructLayout->fields)
{
@@ -2659,8 +3028,230 @@ namespace Slang
break;
}
- return module;
+ return sharedContextStorage.module;
+ }
+
+ //
+
+ struct IRSharedGenericSpecContext : IRSharedSpecContext
+ {
+ // Non-generic functions to be processed
+ List<IRFunc*> workList;
+ };
+
+ struct IRGenericSpecContext : IRSpecContextBase
+ {
+ IRSharedGenericSpecContext* getShared() { return (IRSharedGenericSpecContext*) shared; }
+
+ // The substutions to apply
+ RefPtr<Substitutions> subst;
+
+ // Override the "maybe clone" logic so that we always clone
+ virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
+
+ virtual RefPtr<Type> maybeCloneType(Type* originalType) override;
+ };
+
+ IRValue* IRGenericSpecContext::maybeCloneValue(IRValue* originalVal)
+ {
+ switch( originalVal->op )
+ {
+ case kIROp_decl_ref:
+ {
+ auto declRefVal = (IRDeclRef*) originalVal;
+ int diff = 0;
+ auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff);
+ if(!diff)
+ return originalVal;
+
+ return builder->getDeclRefVal(substDeclRef);
+ }
+ break;
+
+ default:
+ return originalVal;
+ }
}
+ RefPtr<Type> IRGenericSpecContext::maybeCloneType(Type* originalType)
+ {
+ return originalType->Substitute(subst).As<Type>();
+ }
+
+
+ IRFunc* getSpecializedFunc(
+ IRSharedGenericSpecContext* sharedContext,
+ IRFunc* genericFunc,
+ DeclRef<Decl> specDeclRef)
+ {
+ // First, we want to see if an existing specialization
+ // has already been made. To do that we will need to
+ // compute the mangled name of the specialized function,
+ // so that we can look for existing declarations.
+
+ String specMangledName = getMangledName(specDeclRef);
+
+ // TODO: This is a terrible linear search, and we should
+ // avoid it by building a dictionary ahead of time,
+ // as is being done for the `IRSpecContext` used above.
+ // We can probalby use the same basic context, actually.
+ auto module = genericFunc->parentModule;
+ for(auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
+ {
+ if(gv->mangledName == specMangledName)
+ return (IRFunc*) gv;
+ }
+
+ // If we get to this point, then we need to construct a
+ // new `IRFunc` to represent the result of specialization.
+
+ // The substitutions we are applying might have been created
+ // using a different overload of a target-specific function,
+ // so we need to create a dummy substitution here, to make
+ // sure it used the correct generic.
+ RefPtr<Substitutions> newSubst = new Substitutions();
+ newSubst->genericDecl = genericFunc->genericDecl;
+ newSubst->args = specDeclRef.substitutions->args;
+
+ IRGenericSpecContext context;
+ context.shared = sharedContext;
+ context.builder = &sharedContext->builderStorage;
+ context.subst = newSubst;
+
+ // TODO: other initialization is needed here...
+
+ auto specFunc = cloneSimpleFunc(&context, genericFunc);
+
+ // Set up the clone to recognize that it is no longer generic
+ specFunc->mangledName = specMangledName;
+ specFunc->genericDecl = nullptr;
+
+ // Put the function into the global sequence right after
+ // the function it specializes.
+ //
+ // TODO: This shouldn't be needed, if we introduce a sorting
+ // step before we emit code.
+ specFunc->removeFromParent();
+ specFunc->insertAfter(genericFunc);
+
+ // At this point we've created a new non-generic function,
+ // which means we should add it to our work list for
+ // subsequent processing.
+ sharedContext->workList.Add(specFunc);
+
+ // We also need to make sure that we register this specialized
+ // function under its mangled name, so that later lookup
+ // steps will find it.
+ insertGlobalValueSymbol(sharedContext, specFunc);
+
+ return specFunc;
+ }
+
+ void specializeGenerics(
+ IRModule* module)
+ {
+ IRSharedGenericSpecContext sharedContextStorage;
+ auto sharedContext = &sharedContextStorage;
+
+ initializeSharedSpecContext(
+ sharedContext,
+ module->session,
+ module,
+ module);
+
+ // Our goal here is to find `specialize` instructions that
+ // can be replaced with references to a suitably sepcialized
+ // funciton. As a simplification, we will only consider `specialize`
+ // calls that are inside of non-generic functions, since we assume
+ // that these will allow us to fully specialize the referenced
+ // function.
+ //
+ // We start by building up a work list of non-generic functions.
+ for( auto gv = module->getFirstGlobalValue();
+ gv;
+ gv = gv->getNextValue() )
+ {
+ // Is it a function? If not, skip.
+ if(gv->op != kIROp_Func)
+ continue;
+ auto func = (IRFunc*) gv;
+
+ // Is it generic? If so, skip.
+ if(func->genericDecl)
+ continue;
+
+ sharedContext->workList.Add(func);
+ }
+
+ // Now that we have our work list, we are going to
+ // process it until it goes empty. Along the way
+ // we may specialize a function and thus create
+ // a new non-generic function, and in that case
+ // we will add the new function to the work list.
+ auto& workList = sharedContext->workList;
+ while( auto count = workList.Count() )
+ {
+ // We will process the last entry in the
+ // work list, which amounts to treating
+ // it like a stack when we have recursive
+ // specialization to perform.
+ auto func = workList[count-1];
+ workList.RemoveAt(count-1);
+
+ // We are going to go ahead and walk through
+ // all the instructions in this function,
+ // and look for `specialize` operations.
+ for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() )
+ {
+ // We need to be careful when iterating over the instructions,
+ // because we might end up removing the "current" instruction,
+ // so that accessing `ii->next` would crash.
+ IRInst* nextInst = nullptr;
+ for( auto ii = bb->getFirstInst(); ii; ii = nextInst )
+ {
+ nextInst = ii->nextInst;
+
+ // We only care about `specialize` instructions.
+ if(ii->op != kIROp_specialize)
+ continue;
+
+ IRSpecialize* specInst = (IRSpecialize*) ii;
+
+ // We need to check that the value being specialized is
+ // a generic function.
+ auto genericVal = specInst->genericVal.usedValue;
+ if(genericVal->op != kIROp_Func)
+ continue;
+ auto genericFunc = (IRFunc*) genericVal;
+ if(!genericFunc->genericDecl)
+ continue;
+
+ // Now we extract the specialized decl-ref that will
+ // tell us how to specialize things.
+ auto specDeclRefVal = (IRDeclRef*) specInst->specDeclRefVal.usedValue;
+ auto specDeclRef = specDeclRefVal->declRef;
+
+ // Okay, we have a candidate for specialization here.
+ //
+ // We will first find or construct a specialized version
+ // of the callee funciton/
+ auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef);
+ //
+ // Then we will replace the use sites for the `specialize`
+ // instruction with uses of the specialized function.
+ //
+ specInst->replaceUsesWith(specFunc);
+
+ specInst->removeAndDeallocate();
+ }
+ }
+ }
+
+ // Once the work list has gone dry, we should have the invariant
+ // that there are no `specialize` instructions inside of non-generic
+ // functions that in turn reference a generic function.
+ }
+
+ //
}
diff --git a/source/slang/ir.h b/source/slang/ir.h
index ecc77dbc4..2477c987f 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -12,6 +12,7 @@
namespace Slang {
class Decl;
+class GenericDecl;
class FuncType;
class Layout;
class Type;
@@ -98,6 +99,7 @@ enum IRDecorationOp : uint16_t
kIRDecorationOp_HighLevelDecl,
kIRDecorationOp_Layout,
kIRDecorationOp_LoopControl,
+ kIRDecorationOp_Target,
};
// A "decoration" that gets applied to an instruction.
@@ -197,6 +199,11 @@ struct IRInst : IRValue
// Remove this instruction from its parent block,
// and then destroy it (it had better have no uses!)
void removeAndDeallocate();
+
+ // Clear out the arguments of this instruction,
+ // so that we don't appear on the list of uses
+ // for those values.
+ void removeArguments();
};
typedef int64_t IRIntegerValue;
@@ -321,8 +328,10 @@ struct IRFunc : IRGlobalValue
// The type of the IR-level function
IRFuncType* getType() { return (IRFuncType*) type.Ptr(); }
- // Any generic parameters this function has
- List<RefPtr<Decl>> genericParams;
+ // If this function is generic, then we store a reference
+ // to the AST-level generic that defines its parameters
+ // and their constraints.
+ RefPtr<GenericDecl> genericDecl;
// Convenience accessors for working with the
// function's type.
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index d4551421f..a3d67b670 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -270,7 +270,7 @@ struct SharedIRGenContext
{
CompileRequest* compileRequest;
- Dictionary<DeclRef<Decl>, LoweredValInfo> declValues;
+ Dictionary<Decl*, LoweredValInfo> declValues;
// Arrays we keep around strictly for memory-management purposes:
@@ -294,9 +294,16 @@ struct IRGenContext
}
};
+// Ensure that a version of the given declaration has been emitted to the IR
LoweredValInfo ensureDecl(
- IRGenContext* context,
- DeclRef<Decl> const& declRef);
+ IRGenContext* context,
+ Decl* decl);
+
+// Emit code as needed to construct a reference to the given declaration with
+// any needed specializations in place.
+LoweredValInfo emitDeclRef(
+ IRGenContext* context,
+ DeclRef<Decl> declRef);
IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered);
@@ -564,7 +571,7 @@ LoweredValInfo emitCallToDeclRef(
}
// Fallback case is to emit an actual call.
- LoweredValInfo funcVal = ensureDecl(context, funcDeclRef);
+ LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef);
return emitCallToVal(context, type, funcVal, argCount, args);
}
@@ -750,6 +757,7 @@ RefPtr<IRFuncType> getFuncType(
IRType* resultType)
{
RefPtr<FuncType> funcType = new FuncType();
+ funcType->setSession(context->getSession());
funcType->resultType = resultType;
for (UInt pp = 0; pp < paramCount; ++pp)
{
@@ -810,43 +818,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitDeclRefType(DeclRefType* type)
{
-#if 1
// TODO: is there actually anything to be done at this point?
return LoweredTypeInfo(type);
-#else
- // We need to detect builtin/intrinsic types here, since they should map to custom modifiers
- // We need to catch builtin/intrinsic types here
- if( auto intrinsicTypeMod = type->declRef.getDecl()->FindModifier<IntrinsicTypeModifier>() )
- {
- auto builder = getBuilder();
- auto intType = getIntType(context);
- //
- List<IRValue*> irArgs;
- for( auto val : intrinsicTypeMod->irOperands )
- {
- irArgs.Add(builder->getIntValue(intType, val));
- }
-
- addGenericArgs(&irArgs, type->declRef);
-
- auto irType = getBuilder()->getIntrinsicType(IROp(intrinsicTypeMod->irOp), irArgs.Count(), irArgs.Buffer());
- return LoweredTypeInfo(irType);
- }
-
- // Catch-all for user-defined type references
- LoweredValInfo loweredDeclRef = ensureDecl(context, type->declRef);
-
- // TODO: make sure that the value is actually a type...
-
- switch (loweredDeclRef.flavor)
- {
- case LoweredValInfo::Flavor::Simple:
- return LoweredTypeInfo((IRType*)loweredDeclRef.val);
-
- default:
- SLANG_UNIMPLEMENTED_X("type lowering");
- }
-#endif
}
LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type)
@@ -956,7 +929,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitVarExpr(VarExpr* expr)
{
- LoweredValInfo info = ensureDecl(context, expr->declRef);
+ LoweredValInfo info = emitDeclRef(context, expr->declRef);
return info;
}
@@ -1431,7 +1404,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr)
{
- return ensureDecl(context, expr->declRef);
+ return emitDeclRef(context, expr->declRef);
}
LoweredValInfo visitSelectExpr(SelectExpr* expr)
@@ -2028,7 +2001,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (accessor->HasModifier<IntrinsicOpModifier>())
continue;
- ensureDecl(context, makeDeclRef(accessor.Ptr()));
+ ensureDecl(context, accessor);
}
// The subscript declaration itself won't correspond
@@ -2561,8 +2534,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irFunc->mangledName = mangledName;
}
-
- LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl)
+ LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
{
// Collect the parameter lists we will use for our new function.
ParameterLists parameterLists;
@@ -2610,35 +2582,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// We first need to walk the generic parameters (if any)
// because these will influence the declared type of
// the function.
- UInt genericParamCounter = 0;
- for( auto genericParamDecl : parameterLists.genericParams )
- {
- irFunc->genericParams.Add(genericParamDecl);
-
-#if 0
- UInt genericParamIndex = genericParamCounter++;
- if( auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(genericParamDecl) )
- {
- // In the logical type for the function, a generic
- // type parameter will be represented as a parameter of type `Type`
-
- IRType* irTypeType = context->irBuilder->getTypeType();
- paramTypes.Add(irTypeType);
-
- // Anywhere else in the parameter type list where this type parameter
- // is referenced, we'll need to substitute in a reference
- // to the appropriate generic parameter position.
- IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex);
- LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType);
- subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo;
- }
- else
+ for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl)
+ {
+ if(auto genericAncestor = dynamic_cast<GenericDecl*>(pp))
{
- // TODO: handle the other cases here.
- SLANG_UNEXPECTED("generic parameter kind");
+ irFunc->genericDecl = genericAncestor;
+ break;
}
-#endif
}
for( auto paramInfo : parameterLists.params )
@@ -2809,6 +2760,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
+ // If this declaration was marked as being an intrinsic for a particular
+ // target, then we should reflect that here.
+ for( auto targetMod : decl->GetModifiersOfType<SpecializedForTargetModifier>() )
+ {
+ // `targetMod` indicates that this particular declaration represents
+ // a specialized definition of the particular function for the given
+ // target, and we need to reflect that at the IR level.
+
+ auto decoration = getBuilder()->addDecoration<IRTargetDecoration>(irFunc);
+ decoration->targetName = targetMod->targetToken.Content;
+ }
+
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
@@ -2817,6 +2780,43 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo::simple(irFunc);
}
+
+
+ LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ // A function declaration may have multiple, target-specific
+ // overloads, and we need to emit an IR version of each of these.
+
+ // The front end will form a linked list of declaratiosn with
+ // the same signature, whenever there is any kind of redeclaration.
+ // We will look to see if that linked list has been formed.
+ auto primaryDecl = decl->primaryDecl;
+
+ if (!primaryDecl)
+ {
+ // If there is no linked list then we are in the ordinary
+ // case with a single declaration, and no special handling
+ // is needed.
+ return lowerFuncDecl(decl);
+ }
+
+ // Otherwise, we need to walk the linked list of declarations
+ // and make sure to emit IR code for any targets that need it.
+
+ // TODO: Need to be careful about how this is approached,
+ // to avoid emitting a bunch of extra definitions in the IR.
+
+ auto primaryFuncDecl = dynamic_cast<FunctionDeclBase*>(primaryDecl);
+ assert(primaryFuncDecl);
+ LoweredValInfo result = lowerFuncDecl(primaryFuncDecl);
+ for (auto dd = primaryDecl->nextDecl; dd; dd = dd->nextDecl)
+ {
+ auto funcDecl = dynamic_cast<FunctionDeclBase*>(dd);
+ assert(funcDecl);
+ lowerFuncDecl(funcDecl);
+ }
+ return result;
+ }
};
LoweredValInfo lowerDecl(
@@ -2828,20 +2828,17 @@ LoweredValInfo lowerDecl(
return visitor.dispatch(decl);
}
+// Ensure that a version of the given declaration has been emitted to the IR
LoweredValInfo ensureDecl(
- IRGenContext* context,
- DeclRef<Decl> const& declRef)
+ IRGenContext* context,
+ Decl* decl)
{
auto shared = context->shared;
LoweredValInfo result;
- if(shared->declValues.TryGetValue(declRef, result))
+ if(shared->declValues.TryGetValue(decl, result))
return result;
- // TODO: this is where we need to apply any specializations
- // from the declaration reference, so that they can be
- // applied correctly to the declaration itself...
-
IRBuilder subIRBuilder;
subIRBuilder.shared = context->irBuilder->shared;
@@ -2849,13 +2846,42 @@ LoweredValInfo ensureDecl(
subContext.irBuilder = &subIRBuilder;
- result = lowerDecl(&subContext, declRef.getDecl());
+ result = lowerDecl(&subContext, decl);
- shared->declValues[declRef] = result;
+ shared->declValues[decl] = result;
return result;
}
+LoweredValInfo emitDeclRef(
+ IRGenContext* context,
+ DeclRef<Decl> declRef)
+{
+ // First we need to construct an IR value representing the
+ // unspecialized declaration.
+ LoweredValInfo loweredDecl = ensureDecl(context, declRef.getDecl());
+
+ // If this declaration reference doesn't involve any specializations,
+ // then we are done at this point.
+ if(!declRef.substitutions)
+ return loweredDecl;
+
+ auto val = getSimpleVal(context, loweredDecl);
+
+ RefPtr<Type> type;
+ if(auto declType = val->getType())
+ {
+ type = declType->Substitute(declRef.substitutions).As<Type>();
+ }
+
+ // Otherwise, we need to construct a specialization of the
+ // given declaration.
+ return LoweredValInfo::simple(context->irBuilder->emitSpecializeInst(
+ type,
+ val,
+ declRef));
+}
+
static void lowerEntryPointToIR(
IRGenContext* context,
EntryPointRequest* entryPointRequest)
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 3e6ee9917..5708bab64 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -739,6 +739,7 @@ struct LoweringVisitor
RefPtr<Type> visitFuncType(FuncType* type)
{
RefPtr<FuncType> loweredType = new FuncType();
+ loweredType->setSession(getSession());
loweredType->resultType = lowerType(type->resultType);
for (auto paramType : type->paramTypes)
{
diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp
index ce36a97b6..71c0605a9 100644
--- a/source/slang/mangle.cpp
+++ b/source/slang/mangle.cpp
@@ -44,75 +44,193 @@ namespace Slang
context->sb.append(str);
}
+ void emitVal(
+ ManglingContext* context,
+ Val* val);
+
+ void emitQualifiedName(
+ ManglingContext* context,
+ DeclRef<Decl> declRef);
+
+ void emitSimpleIntVal(
+ ManglingContext* context,
+ Val* val)
+ {
+ if( auto constVal = dynamic_cast<ConstantIntVal*>(val) )
+ {
+ auto val = constVal->value;
+ if( val >= 0 && val <= 9 )
+ {
+ emit(context, (UInt) val);
+ return;
+ }
+ }
+
+ // Fallback:
+ emitVal(context, val);
+ }
+
void emitType(
ManglingContext* context,
Type* type)
{
// TODO: actually implement this bit...
+
+ if( auto basicType = dynamic_cast<BasicExpressionType*>(type) )
+ {
+ switch( basicType->baseType )
+ {
+ case BaseType::Void: emitRaw(context, "V"); break;
+ case BaseType::Bool: emitRaw(context, "b"); break;
+ case BaseType::Int: emitRaw(context, "i"); break;
+ case BaseType::UInt: emitRaw(context, "u"); break;
+ case BaseType::UInt64: emitRaw(context, "U"); break;
+ case BaseType::Half: emitRaw(context, "h"); break;
+ case BaseType::Float: emitRaw(context, "f"); break;
+ case BaseType::Double: emitRaw(context, "d"); break;
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unimplemented case in mangling");
+ break;
+ }
+ }
+ else if( auto vecType = dynamic_cast<VectorExpressionType*>(type) )
+ {
+ emitRaw(context, "v");
+ emitSimpleIntVal(context, vecType->elementCount);
+ emitType(context, vecType->elementType);
+ }
+ else if( auto matType = dynamic_cast<MatrixExpressionType*>(type) )
+ {
+ emitRaw(context, "m");
+ emitSimpleIntVal(context, matType->getRowCount());
+ emitRaw(context, "x");
+ emitSimpleIntVal(context, matType->getColumnCount());
+ emitType(context, matType->getElementType());
+ }
+ else if( auto namedType = dynamic_cast<NamedExpressionType*>(type) )
+ {
+ emitType(context, GetType(namedType->declRef));
+ }
+ else if( auto declRefType = dynamic_cast<DeclRefType*>(type) )
+ {
+ emitQualifiedName(context, declRefType->declRef);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unimplemented case in mangling");
+ }
+ }
+
+ void emitVal(
+ ManglingContext* context,
+ Val* val)
+ {
+ if( auto type = dynamic_cast<Type*>(val) )
+ {
+ emitType(context, type);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unimplemented case in mangling");
+ }
}
void emitQualifiedName(
ManglingContext* context,
- Decl* decl)
+ DeclRef<Decl> declRef)
{
- auto parentDecl = decl->ParentDecl;
- if( parentDecl )
+ auto parentDeclRef = declRef.GetParent();
+ auto parentGenericDeclRef = parentDeclRef.As<GenericDecl>();
+ if( parentDeclRef )
{
- emitQualifiedName(context, parentDecl);
+ // In certain cases we want to skip emitting the parent
+ if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() != declRef.getDecl()))
+ {
+ }
+ else if(parentDeclRef.As<FunctionDeclBase>())
+ {
+ }
+ else
+ {
+ emitQualifiedName(context, parentDeclRef);
+ }
}
// A generic declaration is kind of a pseudo-declaration
// as far as the user is concerned; so we don't want
// to emit its name.
- if( auto genericDecl = dynamic_cast<GenericDecl*>(decl) )
+ if(auto genericDeclRef = declRef.As<GenericDecl>())
{
return;
}
- emitName(context, decl->nameAndLoc.name);
+ emitName(context, declRef.GetName());
- if( auto parentGenericDecl = dynamic_cast<GenericDecl*>(parentDecl))
+ // Are we the "inner" declaration beneath a generic decl?
+ if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() == declRef.getDecl()))
{
- emitRaw(context, "g");
- UInt genericParameterCount = 0;
- for( auto mm : parentGenericDecl->Members )
+ // There are two cases here: either we have specializations
+ // in place for the parent generic declaration, or we don't.
+
+ auto subst = declRef.substitutions;
+ if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() )
{
- if(mm.As<GenericTypeParamDecl>())
- {
- genericParameterCount++;
- }
- else if(mm.As<GenericValueParamDecl>())
- {
- genericParameterCount++;
- }
- else if(mm.As<GenericTypeConstraintDecl>())
- {
- genericParameterCount++;
- }
- else
+ // This is the case where we *do* have substitutions.
+ emitRaw(context, "G");
+ UInt genericArgCount = subst->args.Count();
+ emit(context, genericArgCount);
+ for( auto aa : subst->args )
{
+ emitVal(context, aa);
}
}
-
- emit(context, genericParameterCount);
- for( auto mm : parentGenericDecl->Members )
+ else
{
- if(auto genericTypeParamDecl = mm.As<GenericTypeParamDecl>())
+ // We don't have substitutions, so we will emit
+ // information about the parameters of the generic here.
+ emitRaw(context, "g");
+ UInt genericParameterCount = 0;
+ for( auto mm : getMembers(parentGenericDeclRef) )
{
- emitRaw(context, "T");
+ if(mm.As<GenericTypeParamDecl>())
+ {
+ genericParameterCount++;
+ }
+ else if(mm.As<GenericValueParamDecl>())
+ {
+ genericParameterCount++;
+ }
+ else if(mm.As<GenericTypeConstraintDecl>())
+ {
+ genericParameterCount++;
+ }
+ else
+ {
+ }
}
- else if(auto genericValueParamDecl = mm.As<GenericValueParamDecl>())
- {
- emitRaw(context, "v");
- emitType(context, genericValueParamDecl->getType());
- }
- else if(mm.As<GenericTypeConstraintDecl>())
- {
- emitRaw(context, "C");
- // TODO: actually emit info about the constraint
- }
- else
+
+ emit(context, genericParameterCount);
+ for( auto mm : getMembers(parentGenericDeclRef) )
{
+ if(auto genericTypeParamDecl = mm.As<GenericTypeParamDecl>())
+ {
+ emitRaw(context, "T");
+ }
+ else if(auto genericValueParamDecl = mm.As<GenericValueParamDecl>())
+ {
+ emitRaw(context, "v");
+ emitType(context, GetType(genericValueParamDecl));
+ }
+ else if(mm.As<GenericTypeConstraintDecl>())
+ {
+ emitRaw(context, "C");
+ // TODO: actually emit info about the constraint
+ }
+ else
+ {
+ }
}
}
}
@@ -124,23 +242,25 @@ namespace Slang
// We'll also go ahead and emit the result type as well,
// just for completeness.
//
- if( auto callableDecl = dynamic_cast<CallableDecl*>(decl) )
+ if( auto callableDeclRef = declRef.As<CallableDecl>())
{
emitRaw(context, "p");
- UInt parameterCount = callableDecl->GetParameters().Count();
+
+ auto parameters = GetParameters(callableDeclRef);
+ UInt parameterCount = parameters.Count();
emit(context, parameterCount);
- for(auto pp : callableDecl->GetParameters())
+ for(auto paramDeclRef : parameters)
{
- emitType(context, pp->getType());
+ emitType(context, GetType(paramDeclRef));
}
- emitType(context, callableDecl->ReturnType);
+ emitType(context, GetResultType(callableDeclRef));
}
}
void mangleName(
ManglingContext* context,
- Decl* decl)
+ DeclRef<Decl> declRef)
{
// TODO: catch cases where the declaration should
// forward to something else? E.g., what if we
@@ -150,6 +270,8 @@ namespace Slang
// clashes with user-defined symbols:
emitRaw(context, "_S");
+ auto decl = declRef.getDecl();
+
// Next we will add a bit of info to register
// the *kind* of declaration we are dealing with.
//
@@ -174,17 +296,24 @@ namespace Slang
}
// Now we encode the qualified name of the decl.
- emitQualifiedName(context, decl);
+ emitQualifiedName(context, declRef);
}
-
-
- String getMangledName(Decl* decl)
+ String getMangledName(DeclRef<Decl> const& declRef)
{
ManglingContext context;
+ mangleName(&context, declRef);
+ return context.sb.ProduceString();
+ }
- mangleName(&context, decl);
+ String getMangledName(DeclRefBase const & declRef)
+ {
+ return getMangledName(
+ DeclRef<Decl>(declRef.decl, declRef.substitutions));
+ }
- return context.sb.ProduceString();
+ String getMangledName(Decl* decl)
+ {
+ return getMangledName(makeDeclRef(decl));
}
}
diff --git a/source/slang/mangle.h b/source/slang/mangle.h
index 286e2c2c3..11196f496 100644
--- a/source/slang/mangle.h
+++ b/source/slang/mangle.h
@@ -4,12 +4,13 @@
// This file implements the name mangling scheme for the Slang language.
#include "../core/basic.h"
+#include "syntax.h"
namespace Slang
{
- class Decl;
-
String getMangledName(Decl* decl);
+ String getMangledName(DeclRef<Decl> const & declRef);
+ String getMangledName(DeclRefBase const & declRef);
}
#endif \ No newline at end of file
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index b863cb707..54a4a79b6 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -628,6 +628,7 @@ void Type::accept(IValVisitor* visitor, void* extra)
{
SLANG_UNEXPECTED("expected a declaration reference type");
}
+ declRefType->session = session;
declRefType->declRef = declRef;
return declRefType;
}
@@ -800,9 +801,52 @@ void Type::accept(IValVisitor* visitor, void* extra)
return false;
}
+ RefPtr<Val> FuncType::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ int diff = 0;
+
+ // result type
+ RefPtr<Type> substResultType = resultType->SubstituteImpl(subst, &diff).As<Type>();
+
+ // parameter types
+ List<RefPtr<Type>> substParamTypes;
+ for( auto pp : paramTypes )
+ {
+ substParamTypes.Add(pp->SubstituteImpl(subst, &diff).As<Type>());
+ }
+
+ // early exit for no change...
+ if(!diff)
+ return this;
+
+ (*ioDiff)++;
+ RefPtr<FuncType> substType = new FuncType();
+ substType->session = session;
+ substType->resultType = substResultType;
+ substType->paramTypes = substParamTypes;
+ return substType;
+ }
+
Type* FuncType::CreateCanonicalType()
{
- return this;
+ // result type
+ RefPtr<Type> canResultType = resultType->GetCanonicalType();
+
+ // parameter types
+ List<RefPtr<Type>> canParamTypes;
+ for( auto pp : paramTypes )
+ {
+ canParamTypes.Add(pp->GetCanonicalType());
+ }
+
+ RefPtr<FuncType> canType = new FuncType();
+ canType->session = session;
+ canType->resultType = resultType;
+ canType->paramTypes = canParamTypes;
+
+ session->canonicalTypes.Add(canType);
+
+ return canType;
}
int FuncType::GetHashCode()
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index fc3b651b4..e928efb65 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -461,6 +461,7 @@ RAW(
virtual String ToString() override;
protected:
+ virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
virtual bool EqualsImpl(Type * type) override;
virtual Type* CreateCanonicalType() override;
virtual int GetHashCode() override;
diff --git a/tests/ir/loop.slang.expected b/tests/ir/loop.slang.expected
index a9122c094..390fd80e0 100644
--- a/tests/ir/loop.slang.expected
+++ b/tests/ir/loop.slang.expected
@@ -7,7 +7,7 @@ ir_global_var %2 : Ptr<StructuredBuffer<vector<float,4>>>;
ir_global_var %3 : Ptr<RWStructuredBuffer<vector<float,4>>>;
-ir_func @_S04mainp3 : (uint, uint, uint) -> void
+ir_func @_S04mainp3uuuV : (uint, uint, uint) -> void
{
block %4(
param %5 : uint,
diff --git a/tools/slang-test/main.cpp b/tools/slang-test/main.cpp
index 83d512cc1..5c0c81392 100644
--- a/tools/slang-test/main.cpp
+++ b/tools/slang-test/main.cpp
@@ -1205,9 +1205,10 @@ TestResult doImageComparison(String const& filePath)
continue;
}
+ float relativeDiff = 0.0f;
if( expectedVal != 0 )
{
- float relativeDiff = fabsf(float(actualVal) - float(expectedVal)) / float(expectedVal);
+ relativeDiff = fabsf(float(actualVal) - float(expectedVal)) / float(expectedVal);
if( relativeDiff < kRelativeDiffCutoff )
{
@@ -1220,6 +1221,13 @@ TestResult doImageComparison(String const& filePath)
// cases where vertex shader results lead to rendering that is off
// by one pixel...
+ fprintf(stderr, "image compare failure at (%d,%d) channel %d. expected %d got %d (absolute error: %d, relative error: %f)\n",
+ x, y, n,
+ expectedVal,
+ actualVal,
+ absoluteDiff,
+ relativeDiff);
+
// There was a difference we couldn't excuse!
return kTestResult_Fail;
}