From a12480fe49d5ba7c0a9c2ac63363dc76b599ddbd Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Wed, 18 Oct 2017 11:08:47 -0700 Subject: Work on IR-based cross-compilation (#222) There are two big changes here: - Add logic during the initial IR cloning pass for an entry point + target that tries to pick the best possible version of any target-overloaded function. This allows us to pick the intrinsic version of `saturate()` when compiling for HLSL output, but then pick the non-intrinsic version (that is implemented in terms of `clamp()`) when targeting GLSL. - Add an initial specialization pass that tries to deal with generics. This required some fixing work to IR generation, so that we correctly generate explicit operations to specialize a generic for specific types (this is currently implemented as a `specialize` instruction that takes the generic to specialize plus a declaration-reference that represents the specialized form). With that work in place, we can scan for `specialize` instructions inside of non-generic functions, and use them to trigger generation of specialized code. We rely on the name-mangling scheme to help us find pre-existing specializations when possible. There are also a bunch of cleanups encountered along the way: - Don't use the explicit `layout(offset=...)` for uniforms, because it isn't supported by all current drivers. For now we will just assume that our layout rules compute the same values that the driver would for un-marked-up code. We can come back later and try to implement a workaround in the cases where this doesn't apply (e.g., by re-running the layout logic as part of emission, and dropping layout modifiers from variables that don't need explicit layout). - Fix some issues in IR dump printing so that we print function declarations more nicely. 
- Testing: print out failing pixel when image-diff fails --- source/slang/emit.cpp | 72 +++- source/slang/hlsl.meta.slang | 6 +- source/slang/hlsl.meta.slang.h | 6 +- source/slang/ir-inst-defs.h | 2 + source/slang/ir-insts.h | 35 +- source/slang/ir.cpp | 761 ++++++++++++++++++++++++++++++++++++----- source/slang/ir.h | 13 +- source/slang/lower-to-ir.cpp | 184 +++++----- source/slang/lower.cpp | 1 + source/slang/mangle.cpp | 231 ++++++++++--- source/slang/mangle.h | 5 +- source/slang/syntax.cpp | 46 ++- source/slang/type-defs.h | 1 + tests/ir/loop.slang.expected | 2 +- tools/slang-test/main.cpp | 10 +- 15 files changed, 1134 insertions(+), 241 deletions(-) diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index f9c0b5f09..d8d9e13ed 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -3574,17 +3574,34 @@ struct EmitVisitor switch(info.kind) { case LayoutResourceKind::Uniform: - // Explicit offsets require a GLSL extension. - // - // TODO: We really need to fix this so that we - // only output an explicit offset for things - // that are layed out differently than they - // would normally be... - requireGLSLExtension("GL_ARB_enhanced_layouts"); + { + // Explicit offsets require a GLSL extension (which + // is not universally supported, it seems) or a new + // enough GLSL version (which we don't want to + // universall require), so for right now we + // won't actually output explicit offsets for uniform + // shader parameters. + // + // TODO: We should fix this so that we skip any + // extra work for parameters that are laid out as + // expected by the default rules, but do *something* + // for parameters that need non-default layout. + // + // Using the `GL_ARB_enhanced_layouts` feature is one + // option, but we should also be able to do some + // things by introducing padding into the declaration + // (padding insertion would probably be best done at + // the IR level). 
+ bool useExplicitOffsets = false; + if (useExplicitOffsets) + { + requireGLSLExtension("GL_ARB_enhanced_layouts"); - Emit("layout(offset = "); - Emit(info.index); - Emit(")\n"); + Emit("layout(offset = "); + Emit(info.index); + Emit(")\n"); + } + } break; case LayoutResourceKind::VertexInput: @@ -4073,7 +4090,11 @@ emitDeclImpl(decl, nullptr); { case kIROp_global_var: case kIROp_Func: - return ((IRGlobalValue*)inst)->mangledName; + { + auto& mangledName = ((IRGlobalValue*)inst)->mangledName; + if(mangledName.Length() != 0) + return mangledName; + } break; default: @@ -4396,6 +4417,7 @@ emitDeclImpl(decl, nullptr); case kIROp_boolConst: case kIROp_FieldAddress: case kIROp_getElementPtr: + case kIROp_specialize: return true; } @@ -4937,6 +4959,12 @@ emitDeclImpl(decl, nullptr); } break; + case kIROp_specialize: + { + emitIROperand(context, inst->getArg(0)); + } + break; + default: emit("/* unhandled */"); break; @@ -5579,6 +5607,11 @@ emitDeclImpl(decl, nullptr); if(!value) return nullptr; + if(value->op == kIROp_specialize) + { + value = ((IRSpecialize*) value)->genericVal.usedValue; + } + if(value->op != kIROp_Func) return nullptr; @@ -5608,6 +5641,14 @@ emitDeclImpl(decl, nullptr); } else #endif + if(func->genericDecl) + { + Emit("/* "); + emitIRFuncDecl(context, func); + Emit(" */"); + return; + } + if(!isDefinition(func)) { // This is just a function declaration, @@ -6339,6 +6380,13 @@ String emitEntryPoint( // TODO: we should apply some guaranteed transformations here, // to eliminate constructs that aren't legal downstream (e.g. generics). + + specializeGenerics(lowered); + +// fprintf(stderr, "###\n"); +// dumpIR(lowered); +// fprintf(stderr, "###\n"); + // // TODO: Need to decide whether to do these before or after // target-specific legalization steps. Currently I've folded @@ -6348,6 +6396,8 @@ String emitEntryPoint( // IR back into AST for emission? 
visitor.emitIRModule(&context, lowered); + + // TODO: need to clean up the IR module here } else if(!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING ) || translationUnit->compileRequest->loadedModulesList.Count() != 0) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 81e9931e8..dc1d4d8e8 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -326,9 +326,9 @@ __generic __intrinsic_ __intrinsic_op bool CheckAccessFullyMapped(uint status); // Clamp (HLSL SM 1.0) -__generic __intrinsic_op T clamp(T x, T min, T max); -__generic __intrinsic_op vector clamp(vector x, vector min, vector max); -__generic __intrinsic_op matrix clamp(matrix x, matrix min, matrix max); +__generic T clamp(T x, T min, T max); +__generic vector clamp(vector x, vector min, vector max); +__generic matrix clamp(matrix x, matrix min, matrix max); // Clip (discard) fragment conditionally __generic __intrinsic_op void clip(T x); diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h index eccb12f8d..dfbdbe57b 100644 --- a/source/slang/hlsl.meta.slang.h +++ b/source/slang/hlsl.meta.slang.h @@ -328,9 +328,9 @@ sb << "// Check access status to tiled resource\n"; sb << "__intrinsic_op bool CheckAccessFullyMapped(uint status);\n"; sb << "\n"; sb << "// Clamp (HLSL SM 1.0)\n"; -sb << "__generic __intrinsic_op T clamp(T x, T min, T max);\n"; -sb << "__generic __intrinsic_op vector clamp(vector x, vector min, vector max);\n"; -sb << "__generic __intrinsic_op matrix clamp(matrix x, matrix min, matrix max);\n"; +sb << "__generic T clamp(T x, T min, T max);\n"; +sb << "__generic vector clamp(vector x, vector min, vector max);\n"; +sb << "__generic matrix clamp(matrix x, matrix min, matrix max);\n"; sb << "\n"; sb << "// Clip (discard) fragment conditionally\n"; sb << "__generic __intrinsic_op void clip(T x);\n"; diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index c11d66571..636eeec16 100644 --- 
a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -96,6 +96,8 @@ INST(IntLit, integer_constant, 0, 0) INST(FloatLit, float_constant, 0, 0) INST(decl_ref, decl_ref, 0, 0) +INST(specialize, specialize, 2, 0) + INST(Construct, construct, 0, 0) INST(Call, call, 1, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 9ac79413f..50577b2a3 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -47,13 +47,31 @@ struct IRLoopControlDecoration : IRDecoration IRLoopControl mode; }; +struct IRTargetDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_Target }; + + // TODO: have a more structured representation of target specifiers + String targetName; +}; + // // An IR node to represent a reference to an AST-level // declaration. struct IRDeclRef : IRValue { - DeclRefBase declRef; + DeclRef declRef; +}; + +// An instruction that specializes another IR value +// (representing a generic) to a particular set of +// generic arguments (encoded via an `IRDeclRef`) +// +struct IRSpecialize : IRInst +{ + IRUse genericVal; + IRUse specDeclRefVal; }; // @@ -304,6 +322,16 @@ struct IRBuilder IRValue* getDeclRefVal( DeclRefBase const& declRef); + IRValue* emitSpecializeInst( + IRType* type, + IRValue* genericVal, + IRValue* specDeclRef); + + IRValue* emitSpecializeInst( + IRType* type, + IRValue* genericVal, + DeclRef specDeclRef); + IRInst* emitCallInst( IRType* type, IRValue* func, @@ -452,12 +480,15 @@ struct IRBuilder // Generate a clone of an IR module that is specialized for // a particular entry point, target, etc. - IRModule* specializeIRForEntryPoint( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target); +// Find suitable uses of the `specialize` instruction that +// can be replaced with references to specialized functions. 
+void specializeGenerics( + IRModule* module); } diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 31a35cd08..fb6013bc3 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -531,10 +531,41 @@ namespace Slang this, kIROp_decl_ref, nullptr); - irValue->declRef = declRef; + irValue->declRef = DeclRef(declRef.decl, declRef.substitutions); return irValue; } + IRValue* IRBuilder::emitSpecializeInst( + Type* type, + IRValue* genericVal, + IRValue* specDeclRef) + { + auto inst = createInst( + this, + kIROp_specialize, + type, + genericVal, + specDeclRef); + addInst(inst); + return inst; + } + + IRValue* IRBuilder::emitSpecializeInst( + Type* type, + IRValue* genericVal, + DeclRef specDeclRef) + { + auto specDeclRefVal = getDeclRefVal(specDeclRef); + auto inst = createInst( + this, + kIROp_specialize, + type, + genericVal, + specDeclRefVal); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitCallInst( IRType* type, IRValue* func, @@ -1352,15 +1383,26 @@ namespace Slang auto parentDeclRef = declRef.GetParent(); auto genericParentDeclRef = parentDeclRef.As(); - if(genericParentDeclRef) + if (genericParentDeclRef) { - parentDeclRef = genericParentDeclRef.GetParent(); + if (genericParentDeclRef.getDecl()->inner.Ptr() == decl) + { + parentDeclRef = genericParentDeclRef.GetParent(); + } + else + { + genericParentDeclRef = DeclRef(); + } } if(parentDeclRef.As()) { parentDeclRef = DeclRef(); } + else if(parentDeclRef.As()) + { + parentDeclRef = DeclRef(); + } if(parentDeclRef) { @@ -1709,15 +1751,97 @@ namespace Slang dump(context, "\n"); } + void dumpGenericSignature( + IRDumpContext* context, + GenericDecl* genericDecl) + { + for( auto pp = genericDecl->ParentDecl; pp; pp = pp->ParentDecl ) + { + if( auto genericAncestor = dynamic_cast(pp) ) + { + dumpGenericSignature(context, genericAncestor); + break; + } + } + + dump(context, " <"); + bool first = true; + for (auto mm : genericDecl->Members) + { + if (!first) dump(context, ", "); + + if( auto 
typeParamDecl = mm.As() ) + { + dumpDeclRef(context, makeDeclRef(typeParamDecl.Ptr())); + first = false; + } + else if( auto valueParamDecl = mm.As() ) + { + dumpDeclRef(context, makeDeclRef(valueParamDecl.Ptr())); + first = false; + } + } + first = true; + for (auto mm : genericDecl->Members) + { + if (!first) dump(context, ", "); + else dump(context, " where "); + + if( auto constraintDecl = mm.As() ) + { + dumpType(context, constraintDecl->sub); + dump(context, " : "); + dumpType(context, constraintDecl->sup); + first = false; + } + } + dump(context, ">"); + } + void dumpIRFunc( IRDumpContext* context, IRFunc* func) { + + for( auto dd = func->firstDecoration; dd; dd = dd->next ) + { + switch( dd->op ) + { + case kIRDecorationOp_Target: + { + auto decoration = (IRTargetDecoration*) dd; + + dump(context, "\n"); + dumpIndent(context); + dump(context, "[target("); + dump(context, decoration->targetName.Buffer()); + dump(context, ")]"); + } + break; + + } + } + dump(context, "\n"); dumpIndent(context); dump(context, "ir_func "); dumpID(context, func); + + if (func->genericDecl) + { + dump(context, " "); + dumpGenericSignature(context, func->genericDecl); + } + dumpInstTypeClause(context, func->getType()); + + if (!func->getFirstBlock()) + { + // Just a declaration. + dump(context, ";\n"); + return; + } + dump(context, "\n"); dumpIndent(context); @@ -1941,11 +2065,30 @@ namespace Slang parentBlock = nullptr; } + void IRInst::removeArguments() + { + UInt argCount = this->argCount; + for( UInt aa = 0; aa < argCount; ++aa ) + { + IRUse& use = getArgs()[aa]; + + if(!use.usedValue) + continue; + + // Need to unlink this use from the appropriate linked list. + use.usedValue = nullptr; + *use.prevLink = use.nextUse; + use.prevLink = nullptr; + use.nextUse = nullptr; + } + } + // Remove this instruction from its parent block, // and then destroy it (it had better have no uses!) 
void IRInst::removeAndDeallocate() { removeFromParent(); + removeArguments(); deallocate(); } @@ -2211,6 +2354,7 @@ namespace Slang // because it is no longer accurate. auto voidFuncType = new FuncType(); + voidFuncType->setSession(session); voidFuncType->resultType = session->getVoidType(); func->type = voidFuncType; @@ -2233,7 +2377,7 @@ namespace Slang RefPtr nextWithSameName; }; - struct IRSpecContext + struct IRSharedSpecContext { // The specialized module we are building IRModule* module; @@ -2241,33 +2385,63 @@ namespace Slang // The original, unspecialized module we are copying IRModule* originalModule; - // The IR builder to use for creating nodes - IRBuilder* builder; - // A map from mangled symbol names to zero or // more global IR values that have that name, // in the *original* module. - Dictionary> symbols; - - // A map from the mangled name of a global variable - // to the layout to use for it. - Dictionary globalVarLayouts; + typedef Dictionary> SymbolDictionary; + SymbolDictionary symbols; // A map from values in the original IR module // to their equivalent in the cloned module. - Dictionary clonedValues; + typedef Dictionary ClonedValueDictionary; + ClonedValueDictionary clonedValues; + + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + }; + + struct IRSpecContextBase + { + IRSharedSpecContext* shared; + + IRSharedSpecContext* getShared() { return shared; } + + IRModule* getModule() { return getShared()->module; } + + IRModule* getOriginalModule() { return getShared()->originalModule; } + + IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } + + IRSharedSpecContext::ClonedValueDictionary& getClonedValues() { return getShared()->clonedValues; } + + // The IR builder to use for creating nodes + IRBuilder* builder; + + // A callback to be used when a value that is not registerd in `clonedValues` + // is needed during cloning. 
This gives the subtype a chance to intercept + // the operation and clone (or not) as needed. + virtual IRValue* maybeCloneValue(IRValue* originalVal) + { + return originalVal; + } + + // A callback used to clone (or not) types. + virtual RefPtr maybeCloneType(Type* originalType) + { + return originalType; + } }; void registerClonedValue( - IRSpecContext* context, + IRSpecContextBase* context, IRValue* clonedValue, IRValue* originalValue) { - context->clonedValues.Add(originalValue, clonedValue); + context->getClonedValues().Add(originalValue, clonedValue); } void cloneDecorations( - IRSpecContext* context, + IRSpecContextBase* context, IRValue* clonedValue, IRValue* originalValue) { @@ -2292,31 +2466,39 @@ namespace Slang // TODO: implement this } + struct IRSpecContext : IRSpecContextBase + { + // The code-generation target in use + CodeGenTarget target; + + // A map from the mangled name of a global variable + // to the layout to use for it. + Dictionary globalVarLayouts; + + // Override the "maybe clone" logic so that we always clone + virtual IRValue* maybeCloneValue(IRValue* originalVal) override; + }; + + IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar); IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc); - IRValue* cloneValue( - IRSpecContext* context, - IRValue* originalValue) + IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue) { - IRValue* clonedValue = nullptr; - if (context->clonedValues.TryGetValue(originalValue, clonedValue)) - return clonedValue; - switch (originalValue->op) { case kIROp_global_var: - return cloneGlobalVar(context, (IRGlobalVar*)originalValue); + return cloneGlobalVar(this, (IRGlobalVar*)originalValue); break; case kIROp_Func: - return cloneFunc(context, (IRFunc*)originalValue); + return cloneFunc(this, (IRFunc*)originalValue); break; case kIROp_boolConst: { IRConstant* c = (IRConstant*)originalValue; - return context->builder->getBoolValue(c->u.intVal != 0); + return 
builder->getBoolValue(c->u.intVal != 0); } break; @@ -2324,21 +2506,21 @@ namespace Slang case kIROp_IntLit: { IRConstant* c = (IRConstant*)originalValue; - return context->builder->getIntValue(c->type, c->u.intVal); + return builder->getIntValue(c->type, c->u.intVal); } break; case kIROp_FloatLit: { IRConstant* c = (IRConstant*)originalValue; - return context->builder->getFloatValue(c->type, c->u.floatVal); + return builder->getFloatValue(c->type, c->u.floatVal); } break; case kIROp_decl_ref: { IRDeclRef* od = (IRDeclRef*)originalValue; - return context->builder->getDeclRefVal(od->declRef); + return builder->getDeclRefVal(od->declRef); } break; @@ -2348,8 +2530,19 @@ namespace Slang } } + IRValue* cloneValue( + IRSpecContextBase* context, + IRValue* originalValue) + { + IRValue* clonedValue = nullptr; + if (context->getClonedValues().TryGetValue(originalValue, clonedValue)) + return clonedValue; + + return context->maybeCloneValue(originalValue); + } + void cloneInst( - IRSpecContext* context, + IRSpecContextBase* context, IRBuilder* builder, IRInst* originalInst) { @@ -2366,7 +2559,8 @@ namespace Slang // it, and then add it to the sequence. UInt argCount = originalInst->getArgCount(); IRInst* clonedInst = createInstWithTrailingArgs( - builder, originalInst->op, originalInst->type, + builder, originalInst->op, + context->maybeCloneType(originalInst->type), 0, nullptr, argCount, nullptr); builder->addInst(clonedInst); @@ -2410,14 +2604,14 @@ namespace Slang } void cloneFunctionCommon( - IRSpecContext* context, + IRSpecContextBase* context, IRFunc* clonedFunc, IRFunc* originalFunc) { // First clone all the simple properties. 
clonedFunc->mangledName = originalFunc->mangledName; - clonedFunc->genericParams = originalFunc->genericParams; - clonedFunc->type = originalFunc->type; + clonedFunc->genericDecl = originalFunc->genericDecl; + clonedFunc->type = context->maybeCloneType(originalFunc->type); cloneDecorations(context, clonedFunc, originalFunc); @@ -2445,7 +2639,9 @@ namespace Slang originalParam; originalParam = originalParam->getNextParam()) { - IRParam* clonedParam = builder->emitParam(originalParam->getType()); + IRParam* clonedParam = builder->emitParam( + context->maybeCloneType( + originalParam->getType())); registerClonedValue(context, clonedParam, originalParam); } } @@ -2475,7 +2671,7 @@ namespace Slang // // TODO: This isn't really a good requirement to place on the IR... clonedFunc->removeFromParent(); - clonedFunc->insertAtEnd(context->module); + clonedFunc->insertAtEnd(context->getModule()); } IRFunc* specializeIRForEntryPoint( @@ -2486,7 +2682,7 @@ namespace Slang // Look up the IR symbol by name String mangledName = getMangledName(entryPointRequest->decl); RefPtr sym; - if (!context->symbols.TryGetValue(mangledName, sym)) + if (!context->getSymbols().TryGetValue(mangledName, sym)) { SLANG_UNEXPECTED("no matching IR symbol"); return nullptr; @@ -2534,23 +2730,224 @@ namespace Slang return clonedFunc; } - // The case for functions that are not the entry point is - // strictly simpler, so that is nice. - IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc) + IRFunc* cloneSimpleFunc(IRSpecContextBase* context, IRFunc* originalFunc) { - // TODO: We really need to scan through all the various - // global function symbols that have the same mangled name, - // and pick the correct one to lower for the target. 
- auto clonedFunc = context->builder->createFunc(); registerClonedValue(context, clonedFunc, originalFunc); cloneFunctionCommon(context, clonedFunc, originalFunc); return clonedFunc; } + // Get a string form of the target so that we can + // use it to match against target-specialization modifiers + // + // TODO: We shouldn't be using strings for this. + String getTargetName(IRSpecContext* context) + { + switch( context->target ) + { + case CodeGenTarget::HLSL: + return "hlsl"; + + case CodeGenTarget::GLSL: + return "glsl"; + + default: + SLANG_UNEXPECTED("unhandled case"); + return "unknown"; + } + } + + // How specialized is a given declaration for the chosen target? + enum class TargetSpecializationLevel + { + specializedForOtherTarget = 0, + notSpecialized, + specializedForTarget, + }; + + TargetSpecializationLevel getTargetSpecialiationLevel( + IRGlobalValue* val, + String const& targetName) + { + TargetSpecializationLevel result = TargetSpecializationLevel::notSpecialized; + for( auto dd = val->firstDecoration; dd; dd = dd->next ) + { + if(dd->op != kIRDecorationOp_Target) + continue; + + auto decoration = (IRTargetDecoration*) dd; + if(decoration->targetName == targetName) + return TargetSpecializationLevel::specializedForTarget; + + result = TargetSpecializationLevel::specializedForOtherTarget; + } + + return result; + } + + // Is `newVal` marked as being a better match for our + // chosen code-generation target? + // + // TODO: there is a missing step here where we need + // to check if things are even available in the first place... + bool isBetterForTarget( + IRSpecContext* context, + IRGlobalValue* newVal, + IRGlobalValue* oldVal) + { + String targetName = getTargetName(context); + + // For right now every declaration might have zero or more + // modifiers, representing the targets for which it is specialized. + // Each modifier has a single string "tag" to represent a target. 
+ // We thus decide that a declaration is "more specialized" by: + // + // - Does it have a modifier with a tag with the string for the current target? + // If yes, it is the most specialized it can be. + // + // - Does it have a no tags? Then it is "unspecialized" and that is okay. + // + // - Does it have a modifier with a tag for a *different* target? + // If yes, then it shouldn't even be usable on this target. + // + // Longer term a better approach is to think of this in terms + // of a "disjunction of conjunctions" that is: + // + // (A and B and C) or (A and D) or (E) or (F and G) ... + // + // A code generation target would then consist of a + // conjunction of invidual tags: + // + // (HLSL and SM_4_0 and Vertex and ...) + // + // A declaration is *applicable* on a target if one of + // its conjunctions of tags is a subset of the target's. + // + // One declaration is *better* than another on a target + // if it is applicable and its tags are a superset + // of the other's. + + auto newLevel = getTargetSpecialiationLevel(newVal, targetName); + auto oldLevel = getTargetSpecialiationLevel(oldVal, targetName); + return UInt(newLevel) > UInt(oldLevel); + } + + IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc) + { + // We are being asked to clone a particular function, but in + // the IR that comes out of the front-end there could still + // be multiple, target-specific, declarations of any given + // function, all of which share the same mangled name. + auto mangledName = originalFunc->mangledName; + + if(mangledName.Length() == 0) + { + return cloneSimpleFunc(context, originalFunc); + } + + // + // We will scan through all of the available function declarations + // with the same mangled name as `originalFunc` and try + // to pick the "best" one for our target. + + RefPtr sym; + if( !context->getSymbols().TryGetValue(originalFunc->mangledName, sym) ) + { + // This shouldn't happen! 
+ SLANG_UNEXPECTED("no matching function registered"); + return cloneSimpleFunc(context, originalFunc); + } + + // We will try to track the "best" definition we can find. + IRFunc* bestFunc = (IRFunc*) sym->irGlobalValue; + + for( auto ss = sym->nextWithSameName; ss; ss = ss->nextWithSameName ) + { + IRFunc* newFunc = (IRFunc*) ss->irGlobalValue; + if(isBetterForTarget(context, newFunc, bestFunc)) + bestFunc = newFunc; + } + + // All right, we are now in a position to clone the "best" + // definition that was found. + auto clonedFunc = context->builder->createFunc(); + + // The resulting function will be used as the cloned version + // of every declaration/definition in the original IR. + for( auto ss = sym; ss; ss = ss->nextWithSameName ) + { + registerClonedValue(context, clonedFunc, ss->irGlobalValue); + } + + // Clone the "best" definition into our context + cloneFunctionCommon(context, clonedFunc, bestFunc); + + return clonedFunc; + } + StructTypeLayout* getGlobalStructLayout( ProgramLayout* programLayout); + void insertGlobalValueSymbol( + IRSharedSpecContext* sharedContext, + IRGlobalValue* gv) + { + String mangledName = gv->mangledName; + + // Don't try to register a symbol for global values + // with no mangled name, since these represent symbols + // that shouldn't get "linkage" + if (mangledName == "") + return; + + RefPtr sym = new IRSpecSymbol(); + sym->irGlobalValue = gv; + + RefPtr prev; + if (sharedContext->symbols.TryGetValue(mangledName, prev)) + { + sym->nextWithSameName = prev->nextWithSameName; + prev->nextWithSameName = sym; + } + else + { + sharedContext->symbols.Add(mangledName, sym); + } + } + + void initializeSharedSpecContext( + IRSharedSpecContext* sharedContext, + Session* session, + IRModule* module, + IRModule* originalModule) + { + + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder* builder = &sharedContext->builderStorage; + 
builder->shared = sharedBuilder; + + if( !module ) + { + module = builder->createModule(); + sharedBuilder->module = module; + } + + sharedContext->module = module; + sharedContext->originalModule = originalModule; + + // First, we will populate a map with all of the IR values + // that use the same mangled name, to make lookup easier + // in other steps. + for (auto gv = originalModule->firstGlobalValue; gv; gv = gv->nextGlobalValue) + { + insertGlobalValueSymbol(sharedContext, gv); + } + } + IRModule* specializeIRForEntryPoint( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, @@ -2580,52 +2977,24 @@ namespace Slang // we need to pick the "best" one for the chosen code generation target. // - SharedIRBuilder sharedBuilderStorage; - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = nullptr; - sharedBuilder->session = compileRequest->mSession; - - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->shared = sharedBuilder; - - IRModule* module = builder->createModule(); - sharedBuilder->module = module; + IRSharedSpecContext sharedContextStorage; - // + initializeSharedSpecContext( + &sharedContextStorage, + compileRequest->mSession, + nullptr, + originalIRModule); IRSpecContext contextStorage; IRSpecContext* context = &contextStorage; + context->shared = &sharedContextStorage; + context->builder = &sharedContextStorage.builderStorage; + context->target = target; - context->builder = builder; - context->module = module; - context->originalModule = originalIRModule; - - // First, we will populate a map with all of the IR values - // that use the same mangled name, to make lookup easier - // in other steps. 
- for (auto gv = originalIRModule->firstGlobalValue; gv; gv = gv->nextGlobalValue) - { - String mangledName = gv->mangledName; - if (mangledName == "") - continue; - - RefPtr sym = new IRSpecSymbol(); - sym->irGlobalValue = gv; - RefPtr prev; - if (context->symbols.TryGetValue(mangledName, prev)) - { - sym->nextWithSameName = prev->nextWithSameName; - prev->nextWithSameName = sym; - } - else - { - context->symbols.Add(mangledName, sym); - } - } - - // Next, we want to optimize lookup over + // Next, we want to optimize lookup for layout infromation + // associated with global declarations, so that we can + // look things up based on the IR values (using mangled names) auto globalStructLayout = getGlobalStructLayout(programLayout); for (auto globalVarLayout : globalStructLayout->fields) { @@ -2659,8 +3028,230 @@ namespace Slang break; } - return module; + return sharedContextStorage.module; + } + + // + + struct IRSharedGenericSpecContext : IRSharedSpecContext + { + // Non-generic functions to be processed + List workList; + }; + + struct IRGenericSpecContext : IRSpecContextBase + { + IRSharedGenericSpecContext* getShared() { return (IRSharedGenericSpecContext*) shared; } + + // The substutions to apply + RefPtr subst; + + // Override the "maybe clone" logic so that we always clone + virtual IRValue* maybeCloneValue(IRValue* originalVal) override; + + virtual RefPtr maybeCloneType(Type* originalType) override; + }; + + IRValue* IRGenericSpecContext::maybeCloneValue(IRValue* originalVal) + { + switch( originalVal->op ) + { + case kIROp_decl_ref: + { + auto declRefVal = (IRDeclRef*) originalVal; + int diff = 0; + auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff); + if(!diff) + return originalVal; + + return builder->getDeclRefVal(substDeclRef); + } + break; + + default: + return originalVal; + } } + RefPtr IRGenericSpecContext::maybeCloneType(Type* originalType) + { + return originalType->Substitute(subst).As(); + } + + + IRFunc* getSpecializedFunc( 
+ IRSharedGenericSpecContext* sharedContext, + IRFunc* genericFunc, + DeclRef specDeclRef) + { + // First, we want to see if an existing specialization + // has already been made. To do that we will need to + // compute the mangled name of the specialized function, + // so that we can look for existing declarations. + + String specMangledName = getMangledName(specDeclRef); + + // TODO: This is a terrible linear search, and we should + // avoid it by building a dictionary ahead of time, + // as is being done for the `IRSpecContext` used above. + // We can probalby use the same basic context, actually. + auto module = genericFunc->parentModule; + for(auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + { + if(gv->mangledName == specMangledName) + return (IRFunc*) gv; + } + + // If we get to this point, then we need to construct a + // new `IRFunc` to represent the result of specialization. + + // The substitutions we are applying might have been created + // using a different overload of a target-specific function, + // so we need to create a dummy substitution here, to make + // sure it used the correct generic. + RefPtr newSubst = new Substitutions(); + newSubst->genericDecl = genericFunc->genericDecl; + newSubst->args = specDeclRef.substitutions->args; + + IRGenericSpecContext context; + context.shared = sharedContext; + context.builder = &sharedContext->builderStorage; + context.subst = newSubst; + + // TODO: other initialization is needed here... + + auto specFunc = cloneSimpleFunc(&context, genericFunc); + + // Set up the clone to recognize that it is no longer generic + specFunc->mangledName = specMangledName; + specFunc->genericDecl = nullptr; + + // Put the function into the global sequence right after + // the function it specializes. + // + // TODO: This shouldn't be needed, if we introduce a sorting + // step before we emit code. 
+ specFunc->removeFromParent(); + specFunc->insertAfter(genericFunc); + + // At this point we've created a new non-generic function, + // which means we should add it to our work list for + // subsequent processing. + sharedContext->workList.Add(specFunc); + + // We also need to make sure that we register this specialized + // function under its mangled name, so that later lookup + // steps will find it. + insertGlobalValueSymbol(sharedContext, specFunc); + + return specFunc; + } + + void specializeGenerics( + IRModule* module) + { + IRSharedGenericSpecContext sharedContextStorage; + auto sharedContext = &sharedContextStorage; + + initializeSharedSpecContext( + sharedContext, + module->session, + module, + module); + + // Our goal here is to find `specialize` instructions that + // can be replaced with references to a suitably sepcialized + // funciton. As a simplification, we will only consider `specialize` + // calls that are inside of non-generic functions, since we assume + // that these will allow us to fully specialize the referenced + // function. + // + // We start by building up a work list of non-generic functions. + for( auto gv = module->getFirstGlobalValue(); + gv; + gv = gv->getNextValue() ) + { + // Is it a function? If not, skip. + if(gv->op != kIROp_Func) + continue; + auto func = (IRFunc*) gv; + + // Is it generic? If so, skip. + if(func->genericDecl) + continue; + + sharedContext->workList.Add(func); + } + + // Now that we have our work list, we are going to + // process it until it goes empty. Along the way + // we may specialize a function and thus create + // a new non-generic function, and in that case + // we will add the new function to the work list. + auto& workList = sharedContext->workList; + while( auto count = workList.Count() ) + { + // We will process the last entry in the + // work list, which amounts to treating + // it like a stack when we have recursive + // specialization to perform. 
+ auto func = workList[count-1]; + workList.RemoveAt(count-1); + + // We are going to go ahead and walk through + // all the instructions in this function, + // and look for `specialize` operations. + for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + // We need to be careful when iterating over the instructions, + // because we might end up removing the "current" instruction, + // so that accessing `ii->next` would crash. + IRInst* nextInst = nullptr; + for( auto ii = bb->getFirstInst(); ii; ii = nextInst ) + { + nextInst = ii->nextInst; + + // We only care about `specialize` instructions. + if(ii->op != kIROp_specialize) + continue; + + IRSpecialize* specInst = (IRSpecialize*) ii; + + // We need to check that the value being specialized is + // a generic function. + auto genericVal = specInst->genericVal.usedValue; + if(genericVal->op != kIROp_Func) + continue; + auto genericFunc = (IRFunc*) genericVal; + if(!genericFunc->genericDecl) + continue; + + // Now we extract the specialized decl-ref that will + // tell us how to specialize things. + auto specDeclRefVal = (IRDeclRef*) specInst->specDeclRefVal.usedValue; + auto specDeclRef = specDeclRefVal->declRef; + + // Okay, we have a candidate for specialization here. + // + // We will first find or construct a specialized version + // of the callee funciton/ + auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef); + // + // Then we will replace the use sites for the `specialize` + // instruction with uses of the specialized function. + // + specInst->replaceUsesWith(specFunc); + + specInst->removeAndDeallocate(); + } + } + } + + // Once the work list has gone dry, we should have the invariant + // that there are no `specialize` instructions inside of non-generic + // functions that in turn reference a generic function. 
+ } + + // } diff --git a/source/slang/ir.h b/source/slang/ir.h index ecc77dbc4..2477c987f 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -12,6 +12,7 @@ namespace Slang { class Decl; +class GenericDecl; class FuncType; class Layout; class Type; @@ -98,6 +99,7 @@ enum IRDecorationOp : uint16_t kIRDecorationOp_HighLevelDecl, kIRDecorationOp_Layout, kIRDecorationOp_LoopControl, + kIRDecorationOp_Target, }; // A "decoration" that gets applied to an instruction. @@ -197,6 +199,11 @@ struct IRInst : IRValue // Remove this instruction from its parent block, // and then destroy it (it had better have no uses!) void removeAndDeallocate(); + + // Clear out the arguments of this instruction, + // so that we don't appear on the list of uses + // for those values. + void removeArguments(); }; typedef int64_t IRIntegerValue; @@ -321,8 +328,10 @@ struct IRFunc : IRGlobalValue // The type of the IR-level function IRFuncType* getType() { return (IRFuncType*) type.Ptr(); } - // Any generic parameters this function has - List> genericParams; + // If this function is generic, then we store a reference + // to the AST-level generic that defines its parameters + // and their constraints. + RefPtr genericDecl; // Convenience accessors for working with the // function's type. 
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index d4551421f..a3d67b670 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -270,7 +270,7 @@ struct SharedIRGenContext { CompileRequest* compileRequest; - Dictionary, LoweredValInfo> declValues; + Dictionary declValues; // Arrays we keep around strictly for memory-management purposes: @@ -294,9 +294,16 @@ struct IRGenContext } }; +// Ensure that a version of the given declaration has been emitted to the IR LoweredValInfo ensureDecl( - IRGenContext* context, - DeclRef const& declRef); + IRGenContext* context, + Decl* decl); + +// Emit code as needed to construct a reference to the given declaration with +// any needed specializations in place. +LoweredValInfo emitDeclRef( + IRGenContext* context, + DeclRef declRef); IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered); @@ -564,7 +571,7 @@ LoweredValInfo emitCallToDeclRef( } // Fallback case is to emit an actual call. - LoweredValInfo funcVal = ensureDecl(context, funcDeclRef); + LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef); return emitCallToVal(context, type, funcVal, argCount, args); } @@ -750,6 +757,7 @@ RefPtr getFuncType( IRType* resultType) { RefPtr funcType = new FuncType(); + funcType->setSession(context->getSession()); funcType->resultType = resultType; for (UInt pp = 0; pp < paramCount; ++pp) { @@ -810,43 +818,8 @@ struct ValLoweringVisitor : ValVisitordeclRef.getDecl()->FindModifier() ) - { - auto builder = getBuilder(); - auto intType = getIntType(context); - // - List irArgs; - for( auto val : intrinsicTypeMod->irOperands ) - { - irArgs.Add(builder->getIntValue(intType, val)); - } - - addGenericArgs(&irArgs, type->declRef); - - auto irType = getBuilder()->getIntrinsicType(IROp(intrinsicTypeMod->irOp), irArgs.Count(), irArgs.Buffer()); - return LoweredTypeInfo(irType); - } - - // Catch-all for user-defined type references - LoweredValInfo loweredDeclRef = 
ensureDecl(context, type->declRef); - - // TODO: make sure that the value is actually a type... - - switch (loweredDeclRef.flavor) - { - case LoweredValInfo::Flavor::Simple: - return LoweredTypeInfo((IRType*)loweredDeclRef.val); - - default: - SLANG_UNIMPLEMENTED_X("type lowering"); - } -#endif } LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type) @@ -956,7 +929,7 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitVarExpr(VarExpr* expr) { - LoweredValInfo info = ensureDecl(context, expr->declRef); + LoweredValInfo info = emitDeclRef(context, expr->declRef); return info; } @@ -1431,7 +1404,7 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr) { - return ensureDecl(context, expr->declRef); + return emitDeclRef(context, expr->declRef); } LoweredValInfo visitSelectExpr(SelectExpr* expr) @@ -2028,7 +2001,7 @@ struct DeclLoweringVisitor : DeclVisitor if (accessor->HasModifier()) continue; - ensureDecl(context, makeDeclRef(accessor.Ptr())); + ensureDecl(context, accessor); } // The subscript declaration itself won't correspond @@ -2561,8 +2534,7 @@ struct DeclLoweringVisitor : DeclVisitor irFunc->mangledName = mangledName; } - - LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) { // Collect the parameter lists we will use for our new function. ParameterLists parameterLists; @@ -2610,35 +2582,14 @@ struct DeclLoweringVisitor : DeclVisitor // We first need to walk the generic parameters (if any) // because these will influence the declared type of // the function. 
- UInt genericParamCounter = 0; - for( auto genericParamDecl : parameterLists.genericParams ) - { - irFunc->genericParams.Add(genericParamDecl); - -#if 0 - UInt genericParamIndex = genericParamCounter++; - if( auto genericTypeParamDecl = dynamic_cast(genericParamDecl) ) - { - // In the logical type for the function, a generic - // type parameter will be represented as a parameter of type `Type` - - IRType* irTypeType = context->irBuilder->getTypeType(); - paramTypes.Add(irTypeType); - - // Anywhere else in the parameter type list where this type parameter - // is referenced, we'll need to substitute in a reference - // to the appropriate generic parameter position. - IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex); - LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType); - subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo; - } - else + for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) + { + if(auto genericAncestor = dynamic_cast(pp)) { - // TODO: handle the other cases here. - SLANG_UNEXPECTED("generic parameter kind"); + irFunc->genericDecl = genericAncestor; + break; } -#endif } for( auto paramInfo : parameterLists.params ) @@ -2809,6 +2760,18 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addHighLevelDeclDecoration(irFunc, decl); + // If this declaration was marked as being an intrinsic for a particular + // target, then we should reflect that here. + for( auto targetMod : decl->GetModifiersOfType() ) + { + // `targetMod` indicates that this particular declaration represents + // a specialized definition of the particular function for the given + // target, and we need to reflect that at the IR level. 
+ + auto decoration = getBuilder()->addDecoration(irFunc); + decoration->targetName = targetMod->targetToken.Content; + } + // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -2817,6 +2780,43 @@ struct DeclLoweringVisitor : DeclVisitor return LoweredValInfo::simple(irFunc); } + + + LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) + { + // A function declaration may have multiple, target-specific + // overloads, and we need to emit an IR version of each of these. + + // The front end will form a linked list of declaratiosn with + // the same signature, whenever there is any kind of redeclaration. + // We will look to see if that linked list has been formed. + auto primaryDecl = decl->primaryDecl; + + if (!primaryDecl) + { + // If there is no linked list then we are in the ordinary + // case with a single declaration, and no special handling + // is needed. + return lowerFuncDecl(decl); + } + + // Otherwise, we need to walk the linked list of declarations + // and make sure to emit IR code for any targets that need it. + + // TODO: Need to be careful about how this is approached, + // to avoid emitting a bunch of extra definitions in the IR. 
+ + auto primaryFuncDecl = dynamic_cast(primaryDecl); + assert(primaryFuncDecl); + LoweredValInfo result = lowerFuncDecl(primaryFuncDecl); + for (auto dd = primaryDecl->nextDecl; dd; dd = dd->nextDecl) + { + auto funcDecl = dynamic_cast(dd); + assert(funcDecl); + lowerFuncDecl(funcDecl); + } + return result; + } }; LoweredValInfo lowerDecl( @@ -2828,20 +2828,17 @@ LoweredValInfo lowerDecl( return visitor.dispatch(decl); } +// Ensure that a version of the given declaration has been emitted to the IR LoweredValInfo ensureDecl( - IRGenContext* context, - DeclRef const& declRef) + IRGenContext* context, + Decl* decl) { auto shared = context->shared; LoweredValInfo result; - if(shared->declValues.TryGetValue(declRef, result)) + if(shared->declValues.TryGetValue(decl, result)) return result; - // TODO: this is where we need to apply any specializations - // from the declaration reference, so that they can be - // applied correctly to the declaration itself... - IRBuilder subIRBuilder; subIRBuilder.shared = context->irBuilder->shared; @@ -2849,13 +2846,42 @@ LoweredValInfo ensureDecl( subContext.irBuilder = &subIRBuilder; - result = lowerDecl(&subContext, declRef.getDecl()); + result = lowerDecl(&subContext, decl); - shared->declValues[declRef] = result; + shared->declValues[decl] = result; return result; } +LoweredValInfo emitDeclRef( + IRGenContext* context, + DeclRef declRef) +{ + // First we need to construct an IR value representing the + // unspecialized declaration. + LoweredValInfo loweredDecl = ensureDecl(context, declRef.getDecl()); + + // If this declaration reference doesn't involve any specializations, + // then we are done at this point. + if(!declRef.substitutions) + return loweredDecl; + + auto val = getSimpleVal(context, loweredDecl); + + RefPtr type; + if(auto declType = val->getType()) + { + type = declType->Substitute(declRef.substitutions).As(); + } + + // Otherwise, we need to construct a specialization of the + // given declaration. 
+ return LoweredValInfo::simple(context->irBuilder->emitSpecializeInst( + type, + val, + declRef)); +} + static void lowerEntryPointToIR( IRGenContext* context, EntryPointRequest* entryPointRequest) diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index 3e6ee9917..5708bab64 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -739,6 +739,7 @@ struct LoweringVisitor RefPtr visitFuncType(FuncType* type) { RefPtr loweredType = new FuncType(); + loweredType->setSession(getSession()); loweredType->resultType = lowerType(type->resultType); for (auto paramType : type->paramTypes) { diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index ce36a97b6..71c0605a9 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -44,75 +44,193 @@ namespace Slang context->sb.append(str); } + void emitVal( + ManglingContext* context, + Val* val); + + void emitQualifiedName( + ManglingContext* context, + DeclRef declRef); + + void emitSimpleIntVal( + ManglingContext* context, + Val* val) + { + if( auto constVal = dynamic_cast(val) ) + { + auto val = constVal->value; + if( val >= 0 && val <= 9 ) + { + emit(context, (UInt) val); + return; + } + } + + // Fallback: + emitVal(context, val); + } + void emitType( ManglingContext* context, Type* type) { // TODO: actually implement this bit... 
+ + if( auto basicType = dynamic_cast(type) ) + { + switch( basicType->baseType ) + { + case BaseType::Void: emitRaw(context, "V"); break; + case BaseType::Bool: emitRaw(context, "b"); break; + case BaseType::Int: emitRaw(context, "i"); break; + case BaseType::UInt: emitRaw(context, "u"); break; + case BaseType::UInt64: emitRaw(context, "U"); break; + case BaseType::Half: emitRaw(context, "h"); break; + case BaseType::Float: emitRaw(context, "f"); break; + case BaseType::Double: emitRaw(context, "d"); break; + break; + + default: + SLANG_UNEXPECTED("unimplemented case in mangling"); + break; + } + } + else if( auto vecType = dynamic_cast(type) ) + { + emitRaw(context, "v"); + emitSimpleIntVal(context, vecType->elementCount); + emitType(context, vecType->elementType); + } + else if( auto matType = dynamic_cast(type) ) + { + emitRaw(context, "m"); + emitSimpleIntVal(context, matType->getRowCount()); + emitRaw(context, "x"); + emitSimpleIntVal(context, matType->getColumnCount()); + emitType(context, matType->getElementType()); + } + else if( auto namedType = dynamic_cast(type) ) + { + emitType(context, GetType(namedType->declRef)); + } + else if( auto declRefType = dynamic_cast(type) ) + { + emitQualifiedName(context, declRefType->declRef); + } + else + { + SLANG_UNEXPECTED("unimplemented case in mangling"); + } + } + + void emitVal( + ManglingContext* context, + Val* val) + { + if( auto type = dynamic_cast(val) ) + { + emitType(context, type); + } + else + { + SLANG_UNEXPECTED("unimplemented case in mangling"); + } } void emitQualifiedName( ManglingContext* context, - Decl* decl) + DeclRef declRef) { - auto parentDecl = decl->ParentDecl; - if( parentDecl ) + auto parentDeclRef = declRef.GetParent(); + auto parentGenericDeclRef = parentDeclRef.As(); + if( parentDeclRef ) { - emitQualifiedName(context, parentDecl); + // In certain cases we want to skip emitting the parent + if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() != declRef.getDecl())) + 
{ + } + else if(parentDeclRef.As()) + { + } + else + { + emitQualifiedName(context, parentDeclRef); + } } // A generic declaration is kind of a pseudo-declaration // as far as the user is concerned; so we don't want // to emit its name. - if( auto genericDecl = dynamic_cast(decl) ) + if(auto genericDeclRef = declRef.As()) { return; } - emitName(context, decl->nameAndLoc.name); + emitName(context, declRef.GetName()); - if( auto parentGenericDecl = dynamic_cast(parentDecl)) + // Are we the "inner" declaration beneath a generic decl? + if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() == declRef.getDecl())) { - emitRaw(context, "g"); - UInt genericParameterCount = 0; - for( auto mm : parentGenericDecl->Members ) + // There are two cases here: either we have specializations + // in place for the parent generic declaration, or we don't. + + auto subst = declRef.substitutions; + if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() ) { - if(mm.As()) - { - genericParameterCount++; - } - else if(mm.As()) - { - genericParameterCount++; - } - else if(mm.As()) - { - genericParameterCount++; - } - else + // This is the case where we *do* have substitutions. + emitRaw(context, "G"); + UInt genericArgCount = subst->args.Count(); + emit(context, genericArgCount); + for( auto aa : subst->args ) { + emitVal(context, aa); } } - - emit(context, genericParameterCount); - for( auto mm : parentGenericDecl->Members ) + else { - if(auto genericTypeParamDecl = mm.As()) + // We don't have substitutions, so we will emit + // information about the parameters of the generic here. 
+ emitRaw(context, "g"); + UInt genericParameterCount = 0; + for( auto mm : getMembers(parentGenericDeclRef) ) { - emitRaw(context, "T"); + if(mm.As()) + { + genericParameterCount++; + } + else if(mm.As()) + { + genericParameterCount++; + } + else if(mm.As()) + { + genericParameterCount++; + } + else + { + } } - else if(auto genericValueParamDecl = mm.As()) - { - emitRaw(context, "v"); - emitType(context, genericValueParamDecl->getType()); - } - else if(mm.As()) - { - emitRaw(context, "C"); - // TODO: actually emit info about the constraint - } - else + + emit(context, genericParameterCount); + for( auto mm : getMembers(parentGenericDeclRef) ) { + if(auto genericTypeParamDecl = mm.As()) + { + emitRaw(context, "T"); + } + else if(auto genericValueParamDecl = mm.As()) + { + emitRaw(context, "v"); + emitType(context, GetType(genericValueParamDecl)); + } + else if(mm.As()) + { + emitRaw(context, "C"); + // TODO: actually emit info about the constraint + } + else + { + } } } } @@ -124,23 +242,25 @@ namespace Slang // We'll also go ahead and emit the result type as well, // just for completeness. // - if( auto callableDecl = dynamic_cast(decl) ) + if( auto callableDeclRef = declRef.As()) { emitRaw(context, "p"); - UInt parameterCount = callableDecl->GetParameters().Count(); + + auto parameters = GetParameters(callableDeclRef); + UInt parameterCount = parameters.Count(); emit(context, parameterCount); - for(auto pp : callableDecl->GetParameters()) + for(auto paramDeclRef : parameters) { - emitType(context, pp->getType()); + emitType(context, GetType(paramDeclRef)); } - emitType(context, callableDecl->ReturnType); + emitType(context, GetResultType(callableDeclRef)); } } void mangleName( ManglingContext* context, - Decl* decl) + DeclRef declRef) { // TODO: catch cases where the declaration should // forward to something else? 
E.g., what if we @@ -150,6 +270,8 @@ namespace Slang // clashes with user-defined symbols: emitRaw(context, "_S"); + auto decl = declRef.getDecl(); + // Next we will add a bit of info to register // the *kind* of declaration we are dealing with. // @@ -174,17 +296,24 @@ namespace Slang } // Now we encode the qualified name of the decl. - emitQualifiedName(context, decl); + emitQualifiedName(context, declRef); } - - - String getMangledName(Decl* decl) + String getMangledName(DeclRef const& declRef) { ManglingContext context; + mangleName(&context, declRef); + return context.sb.ProduceString(); + } - mangleName(&context, decl); + String getMangledName(DeclRefBase const & declRef) + { + return getMangledName( + DeclRef(declRef.decl, declRef.substitutions)); + } - return context.sb.ProduceString(); + String getMangledName(Decl* decl) + { + return getMangledName(makeDeclRef(decl)); } } diff --git a/source/slang/mangle.h b/source/slang/mangle.h index 286e2c2c3..11196f496 100644 --- a/source/slang/mangle.h +++ b/source/slang/mangle.h @@ -4,12 +4,13 @@ // This file implements the name mangling scheme for the Slang language. 
#include "../core/basic.h" +#include "syntax.h" namespace Slang { - class Decl; - String getMangledName(Decl* decl); + String getMangledName(DeclRef const & declRef); + String getMangledName(DeclRefBase const & declRef); } #endif \ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index b863cb707..54a4a79b6 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -628,6 +628,7 @@ void Type::accept(IValVisitor* visitor, void* extra) { SLANG_UNEXPECTED("expected a declaration reference type"); } + declRefType->session = session; declRefType->declRef = declRef; return declRefType; } @@ -800,9 +801,52 @@ void Type::accept(IValVisitor* visitor, void* extra) return false; } + RefPtr FuncType::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + int diff = 0; + + // result type + RefPtr substResultType = resultType->SubstituteImpl(subst, &diff).As(); + + // parameter types + List> substParamTypes; + for( auto pp : paramTypes ) + { + substParamTypes.Add(pp->SubstituteImpl(subst, &diff).As()); + } + + // early exit for no change... 
+ if(!diff) + return this; + + (*ioDiff)++; + RefPtr substType = new FuncType(); + substType->session = session; + substType->resultType = substResultType; + substType->paramTypes = substParamTypes; + return substType; + } + Type* FuncType::CreateCanonicalType() { - return this; + // result type + RefPtr canResultType = resultType->GetCanonicalType(); + + // parameter types + List> canParamTypes; + for( auto pp : paramTypes ) + { + canParamTypes.Add(pp->GetCanonicalType()); + } + + RefPtr canType = new FuncType(); + canType->session = session; + canType->resultType = resultType; + canType->paramTypes = canParamTypes; + + session->canonicalTypes.Add(canType); + + return canType; } int FuncType::GetHashCode() diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index fc3b651b4..e928efb65 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -461,6 +461,7 @@ RAW( virtual String ToString() override; protected: + virtual RefPtr SubstituteImpl(Substitutions* subst, int* ioDiff) override; virtual bool EqualsImpl(Type * type) override; virtual Type* CreateCanonicalType() override; virtual int GetHashCode() override; diff --git a/tests/ir/loop.slang.expected b/tests/ir/loop.slang.expected index a9122c094..390fd80e0 100644 --- a/tests/ir/loop.slang.expected +++ b/tests/ir/loop.slang.expected @@ -7,7 +7,7 @@ ir_global_var %2 : Ptr>>; ir_global_var %3 : Ptr>>; -ir_func @_S04mainp3 : (uint, uint, uint) -> void +ir_func @_S04mainp3uuuV : (uint, uint, uint) -> void { block %4( param %5 : uint, diff --git a/tools/slang-test/main.cpp b/tools/slang-test/main.cpp index 83d512cc1..5c0c81392 100644 --- a/tools/slang-test/main.cpp +++ b/tools/slang-test/main.cpp @@ -1205,9 +1205,10 @@ TestResult doImageComparison(String const& filePath) continue; } + float relativeDiff = 0.0f; if( expectedVal != 0 ) { - float relativeDiff = fabsf(float(actualVal) - float(expectedVal)) / float(expectedVal); + relativeDiff = fabsf(float(actualVal) - float(expectedVal)) / 
float(expectedVal); if( relativeDiff < kRelativeDiffCutoff ) { @@ -1220,6 +1221,13 @@ TestResult doImageComparison(String const& filePath) // cases where vertex shader results lead to rendering that is off // by one pixel... + fprintf(stderr, "image compare failure at (%d,%d) channel %d. expected %d got %d (absolute error: %d, relative error: %f)\n", + x, y, n, + expectedVal, + actualVal, + absoluteDiff, + relativeDiff); + // There was a difference we couldn't excuse! return kTestResult_Fail; } -- cgit v1.2.3